Source code for eazygrad.data.dataloader
import random
import numpy as np
[docs]
class Dataloader:
    """
    Very simple dataloader with no multiprocessing (mostly for MNIST which is already loaded in RAM).

    Iterating over the loader yields ``(data, targets)`` tuples. Each tuple is
    built by indexing ``dataset.data`` and ``dataset.targets`` with a list of
    integer positions, so both containers must support that kind of indexing
    (e.g. NumPy arrays).

    Args:
        dataset: Object exposing ``data`` and ``targets`` attributes;
            ``len(dataset.data)`` is taken as the number of samples.
        batch_size: Number of samples per batch; must be at least 1.
        shuffle: When true, the sample order is reshuffled in place (using the
            global ``random`` module) at the start of every iteration.
        drop_last: When true, a trailing batch smaller than ``batch_size`` is
            skipped; otherwise it is yielded as a final, shorter batch.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=True):
        # Reject non-positive sizes up front: 0 would otherwise surface as a
        # ZeroDivisionError below, and a negative value would silently produce
        # a loader that never yields anything.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        size = len(dataset.data)
        self.indices = list(range(size))
        # Number of full batches, plus one partial batch when it is kept.
        self.num_batch, remainder = divmod(size, batch_size)
        if remainder and not drop_last:
            self.num_batch += 1

    def __len__(self):
        """Return the number of batches produced by one pass over the loader."""
        return self.num_batch

    def __iter__(self):
        """Yield ``(data, targets)`` batches, reshuffling first if enabled."""
        if self.shuffle:
            # In-place shuffle: the new order persists on ``self.indices``.
            random.shuffle(self.indices)
        for i in range(self.num_batch):
            # Slicing past the end is safe, so the last (possibly shorter)
            # batch needs no special casing.
            batch_idx = self.indices[i * self.batch_size:(i + 1) * self.batch_size]
            yield (self.dataset.data[batch_idx], self.dataset.targets[batch_idx])