I have a text generation task that learns to predict the next word with an LSTM network that has multiple output layers. After the generation of a sentence has finished, I calculate a reward for the whole sentence and try to update only the output layers that participated in the generation (contributing layers get the calculated reward value, the others get 0). My problem is that even though I update only the selected output layers, it seems that a different layer's weights get updated instead.
Here is a minimal example with dummy data that demonstrates the problem:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding
from tensorflow.keras.models import Model
# Policy-gradient loss: weight the log-probabilities by the reward passed
# in as y_true (float() works here because the model runs eagerly).
def policy_gradient_loss(y_true, y_pred):
    return tf.reduce_mean(tf.math.log(y_pred) * float(y_true))
# Define the model with 3 output layers (named 'a', 'b' and 'c').
input_layer = Input(shape=(4,))
embedding_layer = Embedding(input_dim=10, output_dim=4)(input_layer)
lstm_layer = LSTM(4)(embedding_layer)
output_layers = [Dense(3, activation='softmax', name=name)(lstm_layer) for name in ['a', 'b', 'c']]
model = Model(inputs=input_layer, outputs=output_layers)
model.compile(loss=[policy_gradient_loss] * 3, optimizer='adam', run_eagerly=True)
# Dummy input data.
input_data = np.array([[2, 3, 4, 5]])
# Create target data to reward only the 'b' output layer.
target_data = [np.array([0]) for _ in range(len(model.outputs))]
target_data[1] = np.array([10])
# Save initial weights.
initial_weights = model.get_weights()
model.train_on_batch(input_data, y=target_data)
# Save weights after the learning.
updated_weights = model.get_weights()
# Compare the before-after weights.
for layer_idx, (layer_name, initial_w, updated_w) in enumerate(zip([layer.name for layer in model.layers], initial_weights, updated_weights)):
    if not tf.math.reduce_all(tf.equal(initial_w, updated_w)):
        print(f'The weights in layer {layer_idx} ({layer_name}) have changed.')
Result:
The weights in layer 0 (input_1) have changed.
The weights in layer 1 (embedding) have changed.
The weights in layer 2 (lstm) have changed.
The weights in layer 3 (a) have changed.
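To rule out the loss function itself, here is a quick sanity check (a sketch reusing the policy_gradient_loss defined above): with a reward of 0 the loss is identically zero, so a head that receives y_true == 0 should get no gradient at all.
probs = tf.constant([[0.2, 0.3, 0.5]])
with tf.GradientTape() as tape:
    tape.watch(probs)
    # Zero reward, as the 'a' and 'c' heads receive in the example above.
    loss = policy_gradient_loss(np.array([0]), probs)
print(tape.gradient(loss, probs))  # tf.Tensor([[0. 0. 0.]], shape=(1, 3), dtype=float32)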
My expectation would be that layer 4 (output layer 'b') is updated instead of layer 'a' (or at least alongside 'a').
What am I missing? Is my expectation wrong, or my implementation? (Or both...?)
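For reference, here is a per-layer variant of the before/after comparison (a sketch; it reads each layer's own arrays with layer.get_weights(), so the printed name is always paired with that layer's weights rather than with an index into the flat list returned by model.get_weights()):
initial_per_layer = [layer.get_weights() for layer in model.layers]
model.train_on_batch(input_data, y=target_data)
for layer, initial in zip(model.layers, initial_per_layer):
    updated = layer.get_weights()
    if any(not np.array_equal(i, u) for i, u in zip(initial, updated)):
        print(f'The weights in layer {layer.name} have changed.')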