Theano MiniBatch Iterator not working
I wrote a minibatch iterator to get predictions from my neural network. However, after running some tests, I noticed that it returns wrong results.
Basically:
If batch_size > number of inputs, the returned predictions are garbage instead of the expected values.
I wrote a script that reproduces this bug in my code. It is shown below:
import numpy as np

def minibatch_iterator_predictor(inputs, batch_size):
    assert len(inputs) > 0
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]

def test(x, batch_size):
    prediction = np.empty((x.shape[0], 2), dtype=np.float32)
    index = 0
    for batch in minibatch_iterator_predictor(inputs=x, batch_size=batch_size):
        inputs = batch
        # y = self.predict_function(inputs)
        y = inputs
        prediction[index * batch_size:batch_size * (index + 1), :] = y[:]
        index += 1
    return prediction

######################################
#TEST SCRIPT
######################################
#Input
arr = np.zeros(shape=(10, 2))
arr[0] = [1, 0]
arr[1] = [2, 0]
arr[2] = [3, 0]
arr[3] = [4, 0]
arr[4] = [5, 0]
arr[5] = [6, 0]
arr[6] = [7, 0]
arr[7] = [8, 0]
arr[8] = [9, 0]
arr[9] = [10, 0]
###############################################
batch_size = 5
print "\nBatch_size ", batch_size
r = test(x=arr, batch_size=batch_size)
#Debug
for k in xrange(r.shape[0]):
    print str(k) + " : " + str(r[k])
##Assert
assert arr.shape[0] == r.shape[0]
for k in xrange(0, r.shape[0]):
    print r[k] == arr[k]
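The reason the script prints odd tiny numbers instead of raising an exception is the `np.empty` call in `test`: it allocates the array without initializing it, so any row the loop never writes keeps whatever bytes happened to be in that memory. A small debugging sketch, assuming you only want to detect unwritten rows, is to allocate with NaN instead (the 10x2 shape and the "first 5 rows written" scenario are made up for illustration):

```python
import numpy as np

# np.empty leaves memory uninitialised; np.full with NaN marks every row
# as "not yet written" until the loop overwrites it.
prediction = np.full((10, 2), np.nan, dtype=np.float32)

# Simulate a loop that only managed to write the first 5 rows.
prediction[0:5, :] = 1.0

# Rows that still contain NaN were never filled in by the loop.
missing_rows = np.where(np.isnan(prediction).any(axis=1))[0]
print(missing_rows)  # → [5 6 7 8 9]
```

With this allocation, a batch size that skips rows shows up as NaN in the output rather than as plausible-looking garbage.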
Here are the test results (with batch_size changed in the script each time):
For batch_size = 10 :
Batch_size 10
0 : [ 1. 0.]
1 : [ 2. 0.]
2 : [ 3. 0.]
3 : [ 4. 0.]
4 : [ 5. 0.]
5 : [ 6. 0.]
6 : [ 7. 0.]
7 : [ 8. 0.]
8 : [ 9. 0.]
9 : [ 10. 0.]
For batch_size = 11 :
0 : [ 1.13876845e-37 0.00000000e+00]
1 : [ 1.14048027e-37 0.00000000e+00]
2 : [ 1.14048745e-37 0.00000000e+00]
3 : [ 9.65151604e-38 0.00000000e+00]
4 : [ 1.14002468e-37 0.00000000e+00]
5 : [ 1.14340036e-37 0.00000000e+00]
6 : [ 1.14343264e-37 0.00000000e+00]
7 : [ 8.02794698e-38 0.00000000e+00]
8 : [ 8.02794698e-38 0.00000000e+00]
9 : [ 8.02794698e-38 0.00000000e+00]
For batch_size = 12 :
0 : [ 1.13876845e-37 0.00000000e+00]
1 : [ 1.14048027e-37 0.00000000e+00]
2 : [ 1.14048745e-37 0.00000000e+00]
3 : [ 9.65151604e-38 0.00000000e+00]
4 : [ 1.14002468e-37 0.00000000e+00]
5 : [ 1.14340036e-37 0.00000000e+00]
6 : [ 1.14343264e-37 0.00000000e+00]
7 : [ 8.10141537e-38 0.00000000e+00]
8 : [ 8.10141537e-38 0.00000000e+00]
9 : [ 8.10141537e-38 0.00000000e+00]
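The garbage does not come from the network: the iterator yields no batches at all. The `stop` argument of `range()` is `len(inputs) - batch_size + 1`, which is 0 or negative whenever `batch_size > len(inputs)`, so the loop body never runs. The sketch below copies the iterator unchanged and prints the start indices for a few batch sizes with 10 inputs; `rows_covered` is a hypothetical helper added only for this illustration:

```python
def minibatch_iterator_predictor(inputs, batch_size):
    # Unchanged copy of the iterator from the question
    assert len(inputs) > 0
    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]


def rows_covered(n_inputs, batch_size):
    """Count how many input rows the iterator actually yields."""
    data = list(range(n_inputs))
    return sum(len(batch) for batch in minibatch_iterator_predictor(data, batch_size))


n_inputs = 10
for batch_size in (3, 5, 10, 11, 12):
    stop = n_inputs - batch_size + 1  # the `stop` argument passed to range()
    starts = list(range(0, stop, batch_size))
    print("batch_size=%d stop=%d starts=%s rows covered=%d"
          % (batch_size, stop, starts, rows_covered(n_inputs, batch_size)))
```

For batch_size 11 and 12 the stop value is 0 and -1, so `starts` is empty, zero rows are covered, and every row of the `np.empty` result is left uninitialized. The batch_size=3 line shows a second, quieter problem: starts are [0, 3, 6], so only 9 of the 10 rows are covered and the last partial batch is silently dropped.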
How can I fix this?
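A minimal sketch of one possible fix, assuming `predict_function` accepts batches with a variable number of rows: iterate over every start index with `range(0, len(inputs), batch_size)` and let slicing return a shorter final batch, then write each result at an offset advanced by the real batch length instead of `index * batch_size`. The function `predict_all` is the question's `test` renamed, and its `predict` parameter (defaulting to identity, like the test script) is a hypothetical stand-in for `self.predict_function`:

```python
import numpy as np


def minibatch_iterator_predictor(inputs, batch_size):
    """Yield consecutive slices of `inputs`; the last one may be shorter."""
    assert len(inputs) > 0
    assert batch_size > 0
    # Visit every start index. Slicing past the end is safe for NumPy
    # arrays and lists, so the final batch just holds the leftover rows,
    # and batch_size > len(inputs) yields one batch with all the rows.
    for start_idx in range(0, len(inputs), batch_size):
        yield inputs[start_idx:start_idx + batch_size]


def predict_all(x, batch_size, predict=lambda batch: batch):
    # `predict` is a hypothetical stand-in for self.predict_function;
    # the identity default mirrors `y = inputs` in the question's script.
    prediction = np.empty((x.shape[0], 2), dtype=np.float32)  # 2 outputs, as in the question
    start = 0
    for batch in minibatch_iterator_predictor(inputs=x, batch_size=batch_size):
        y = predict(batch)
        # Advance by the real batch length, not batch_size, so a short
        # final batch is written to the correct rows.
        prediction[start:start + len(y), :] = y
        start += len(y)
    # Fail loudly instead of returning uninitialised rows.
    assert start == x.shape[0], "not every input row was predicted"
    return prediction
```

If the network was compiled with a fixed batch dimension, a shorter last batch may be rejected; in that case one option is to pad the last batch to `batch_size` and keep only the first `len(batch)` outputs.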