class Dataset(object):
    """Iterable that serves (data, labels) minibatches from two arrays."""

    def __init__(self, X, y, batch_size, shuffle=False):
        """
        Construct a Dataset object to iterate over data X and labels y

        Inputs:
        - X: Numpy array of data, of any shape
        - y: Numpy array of labels, of any shape but with the same first
          dimension as X
        - batch_size: Integer giving number of elements per minibatch
        - shuffle: (optional) Boolean, whether to shuffle the data each time
          the dataset is iterated
        """
        assert X.shape[0] == y.shape[0], 'Got different numbers of data and labels'
        self.X, self.y = X, y
        self.batch_size, self.shuffle = batch_size, shuffle

    def __iter__(self):
        """
        Return an iterator over (X_batch, y_batch) tuples.

        Each batch holds up to batch_size elements; the last batch is smaller
        when the number of elements is not a multiple of batch_size. When
        shuffle is set, a new random order is drawn on every call, and the
        same order is applied to both X and y so pairs stay aligned.
        """
        N, B = self.X.shape[0], self.batch_size
        starts = range(0, N, B)

        if not self.shuffle:
            # Sequential order: plain slices, exactly as before.
            return iter((self.X[i:i + B], self.y[i:i + B]) for i in starts)

        idxs = np.arange(N)
        np.random.shuffle(idxs)
        # Select rows through the permuted indices; slicing X and y directly
        # would ignore the permutation and silently yield unshuffled batches.
        return iter(
            (self.X[idxs[i:i + B]], self.y[idxs[i:i + B]]) for i in starts
        )
Asked
Active
Viewed 101 times
0

barbsan
- 3,418
- 11
- 21
- 28

user8801679
- 1
- 1
-
What's the problem with this code? Do you get any error? – barbsan Jan 31 '19 at 07:28
-
I don't understand how the shuffling of data occurs in __iter__ – user8801679 Jan 31 '19 at 07:55