I am trying to solve FizzBuzz using Keras, and it works quite well for numbers between 1 and 10,000 (a 90-100% win rate and a loss close to 0). However, if I try a larger range, numbers between 1 and 100,000, it performs poorly (a win rate of ~50% and a loss of ~0.3), and I have no idea what I can do to fix this. So far I am using a very simple neural net architecture with 3 hidden layers:
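from keras.models import Sequential
from keras.layers import Dense

# state_size = number of input bits; num_actions = number of output classes (presumably 4)
model = Sequential()
model.add(Dense(2000, input_dim=state_size, activation="relu"))
model.add(Dense(1000, activation="relu"))
model.add(Dense(500, activation="relu"))
model.add(Dense(num_actions, activation="softmax"))  # one probability per class
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=["accuracy"])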
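For reference, here is a minimal sketch of how the four FizzBuzz classes could be encoded as one-hot targets for the softmax output. It assumes `num_actions = 4` and the class order number/Fizz/Buzz/FizzBuzz; the helper names `fizzbuzz_class` and `one_hot_target` are hypothetical and not taken from the lab code.

```python
import numpy as np

NUM_CLASSES = 4  # assumed order: 0 = number, 1 = Fizz, 2 = Buzz, 3 = FizzBuzz


def fizzbuzz_class(n):
    """Return the class index of n under the assumed label order."""
    if n % 15 == 0:
        return 3
    if n % 5 == 0:
        return 2
    if n % 3 == 0:
        return 1
    return 0


def one_hot_target(n):
    """One-hot target vector for a softmax output trained with categorical_crossentropy."""
    y = np.zeros(NUM_CLASSES, dtype=np.float32)
    y[fizzbuzz_class(n)] = 1.0
    return y


print([fizzbuzz_class(n) for n in range(1, 16)])
# → [0, 0, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1, 0, 0, 3]
```

Checking 15 before 5 and 3 matters: otherwise multiples of 15 would be labeled Fizz or Buzz.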
I found that the more neurons I have, the better it performs, at least for numbers below 10,000.
I am training my neural net in a step-wise fashion: instead of computing all inputs and targets beforehand, I generate one example at a time and train the network step by step. Again, this works quite well for the smaller range, and it shouldn't make a difference, right? Here's the main loop:
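For comparison, the alternative of computing all inputs and targets beforehand could look like the minimal sketch below. It assumes a least-significant-bit-first binary encoding and the four-class order number/Fizz/Buzz/FizzBuzz; `build_dataset` is a hypothetical helper.

```python
import numpy as np


def build_dataset(low, high, width):
    """Precompute bit-vector inputs and one-hot FizzBuzz targets for low..high-1."""
    numbers = np.arange(low, high)
    # Bit i of each number becomes column i (least significant bit first).
    X = ((numbers[:, None] >> np.arange(width)) & 1).astype(np.float32)
    # Assumed class order: 0 = number, 1 = Fizz, 2 = Buzz, 3 = FizzBuzz.
    # Later assignments overwrite earlier ones, so multiples of 15 end up as 3.
    labels = np.zeros(len(numbers), dtype=np.int64)
    labels[numbers % 3 == 0] = 1
    labels[numbers % 5 == 0] = 2
    labels[numbers % 15 == 0] = 3
    Y = np.eye(4, dtype=np.float32)[labels]  # one-hot rows
    return X, Y


X, Y = build_dataset(1, 100001, width=17)
print(X.shape, Y.shape)  # → (100000, 17) (100000, 4)
```

The two arrays could then be passed to `model.fit(X, Y, ...)`, which makes it easy to shuffle, hold out a validation range, and compare against the step-wise loop.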
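import random
import numpy as np

# model, memory (a list), batch_size, np_epochs, wins and the helpers
# random_number, to_binary and check_prediction are defined elsewhere.
for epoch in range(np_epochs):
    action = random_number()          # draw a random integer to classify
    x_raw = to_binary(action)         # encode it as a fixed-length bit vector
    x = np.expand_dims(x_raw, 0)      # add a batch dimension for predict()
    prediction = model.predict(x)
    # y is the target vector; victory is True if the prediction was correct
    y, victory, _, _ = check_prediction(action, prediction)
    memory.append((x_raw, y))
    # train on a random mini-batch drawn from everything seen so far
    curr_batch_size = min(batch_size, len(memory))
    batch = random.sample(memory, curr_batch_size)
    inputs = []
    targets = []
    for i, t in batch:
        inputs.append(i)
        targets.append(t)
    if victory:
        wins += 1
    loss, accuracy = model.train_on_batch(np.array(inputs), np.array(targets))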
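Because `wins` counts correct predictions on randomly drawn numbers while the network is still changing, it can help to also measure accuracy over the whole range in one batched pass. A minimal sketch, assuming a `predict_fn` that maps an (N, width) array of bit vectors to (N, 4) class probabilities (for a Keras model this would be `model.predict`); `range_accuracy` is a hypothetical helper, and the LSB-first encoding and class order are assumptions.

```python
import numpy as np


def encode(numbers, width):
    """LSB-first bit vectors, shape (len(numbers), width)."""
    return ((numbers[:, None] >> np.arange(width)) & 1).astype(np.float32)


def true_classes(numbers):
    """Assumed class order: 0 = number, 1 = Fizz, 2 = Buzz, 3 = FizzBuzz."""
    labels = np.zeros(len(numbers), dtype=np.int64)
    labels[numbers % 3 == 0] = 1
    labels[numbers % 5 == 0] = 2
    labels[numbers % 15 == 0] = 3
    return labels


def range_accuracy(predict_fn, low, high, width):
    """Fraction of numbers in low..high-1 whose argmax prediction is correct."""
    numbers = np.arange(low, high)
    probs = predict_fn(encode(numbers, width))
    return float(np.mean(np.argmax(probs, axis=1) == true_classes(numbers)))


def always_number(X):
    """Baseline that always predicts class 0 ("number")."""
    return np.tile([1.0, 0.0, 0.0, 0.0], (len(X), 1))


print(range_accuracy(always_number, 1, 100001, width=17))  # → 0.53333
```

Since this trivial baseline already scores about 53% on 1 to 100,000, a win rate around 50% may mean the network has learned little beyond predicting the most common class.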
As you can see, I do not feed the network decimal numbers; I convert each number to binary first.
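One detail worth checking with a binary encoding is that the input width (`state_size`) covers the largest number: 10,000 fits in 14 bits, but 100,000 needs 17. A minimal sketch of a fixed-width encoder, assuming LSB-first bit order; `to_bits` is a hypothetical stand-in for `to_binary`, whose implementation is not shown here.

```python
import numpy as np


def to_bits(n, width):
    """Encode non-negative integer n as a fixed-width, LSB-first float vector."""
    if n >= 2 ** width:
        # Refuse silently truncating numbers that need more bits than width.
        raise ValueError(f"{n} does not fit in {width} bits")
    return np.array([(n >> i) & 1 for i in range(width)], dtype=np.float32)


print((10_000).bit_length(), (100_000).bit_length())  # → 14 17
print(to_bits(6, 4))  # → [0. 1. 1. 0.]
```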
Another thing to mention is that I am using a memory to make it more like a supervised problem. I thought it might perform better if it also trains on numbers it has already been trained on. It doesn't seem to make any difference at all.
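The memory plus random mini-batch sampling can be written more compactly. This is a minimal sketch, not the lab code: it assumes a bounded `collections.deque` as the memory (the capacity is an arbitrary illustrative value) and that each stored pair is `(bit_vector, target_vector)`; `ReplayMemory` is a hypothetical class name.

```python
import random
from collections import deque

import numpy as np


class ReplayMemory:
    """Fixed-capacity store of (input, target) pairs; the oldest pairs are evicted first."""

    def __init__(self, capacity=50_000):
        self.buffer = deque(maxlen=capacity)

    def append(self, x, y):
        self.buffer.append((x, y))

    def sample(self, batch_size):
        """Return (inputs, targets) arrays for a random mini-batch drawn without replacement."""
        k = min(batch_size, len(self.buffer))
        batch = random.sample(list(self.buffer), k)
        inputs, targets = zip(*batch)  # split pairs into two tuples
        return np.stack(inputs), np.stack(targets)


memory = ReplayMemory(capacity=3)
for n in range(5):
    memory.append(np.array([n, n], dtype=np.float32), np.eye(4, dtype=np.float32)[n % 4])
X, Y = memory.sample(8)
print(X.shape, Y.shape)  # → (3, 2) (3, 4)
```

With a class like this, the loop body could reduce to `memory.append(x_raw, y)` followed by `model.train_on_batch(*memory.sample(batch_size))`.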
Is there anything I can do to solve this particular problem with a neural net? Is it really that hard for a function approximator to figure out the simple arithmetic behind FizzBuzz? Am I doing something wrong? Would you suggest a different architecture?
See my code on MachineLabs. You can simply fork my lab and experiment with it if you want. To view the code, click the 'Editor' tab at the top.