I have been trying to write a custom loss function in Keras to include reward in Reinforcement Learning.
The model takes current state image and previous action as input. The previous action is concatenated at a later stage in the model. Custom_CE_Loss is used to add reward component to the CCE loss.
def create_model(input_img, episode_reward):
model = Conv2D(filters=16, kernel_size=(3, 3), activation='relu')(input_img)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = BatchNormalization()(model)
model = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(model)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = BatchNormalization()(model)
model = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(model)
model = MaxPooling2D(pool_size=(2, 2))(model)
model = BatchNormalization()(model)
model = Flatten()(model)
model = Concatenate()([model, episode_reward])
model = BatchNormalization()(model)
model = Dense(units=2049, activation='relu')(model)
model = Dropout(0.1)(model)
model = Dense(units=128, activation='relu')(model)
output = Dense(units=3, activation='softmax')(model)
return output
class Custom_CE_Loss(tf.keras.losses.Loss):
def __init__(self, reward_list):
super().__init__()
self.reward_list = reward_list
def call(self, y_true, y_pred):
log_y_pred = tf.math.log(y_pred)
elements = -tf.math.multiply_no_nan(x=log_y_pred, y=y_true)
elements = tf.math.multiply(elements, self.reward_list)
return tf.reduce_mean(tf.reduce_sum(elements,axis=1))
input_img = Input(shape=(100,100,1), name = 'screen_diff')
episode_reward = Input(shape=(1,),name = 'episode_reward')
policy_network_train = create_model(input_img, episode_reward)
policy_network_train = keras.models.Model(inputs=[input_img, episode_reward], outputs=policy_network_train)
policy_network_train.compile(optimizer='Adam',loss=Custom_CE_Loss(episode_reward))
model_train.fit(x = [state_list, reward_list_processed], y = action_list, batch_size=32, verbose=2)
I get the following error
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/training.py", line 1051, in train_function *
return step_function(self, iterator)
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/training.py", line 1040, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/training.py", line 1030, in run_step **
outputs = model.train_step(data)
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/training.py", line 890, in train_step
loss = self.compute_loss(x, y, y_pred, sample_weight)
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/training.py", line 948, in compute_loss
return self.compiled_loss(
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/compile_utils.py", line 239, in __call__
self._loss_metric.update_state(
File "/opt/miniconda3/lib/python3.9/site-packages/keras/utils/metrics_utils.py", line 70, in decorated
update_op = update_state_fn(*args, **kwargs)
File "/opt/miniconda3/lib/python3.9/site-packages/keras/metrics/base_metric.py", line 140, in update_state_fn
return ag_update_state(*args, **kwargs)
File "/opt/miniconda3/lib/python3.9/site-packages/keras/metrics/base_metric.py", line 449, in update_state **
sample_weight = tf.__internal__.ops.broadcast_weights(
File "/opt/miniconda3/lib/python3.9/site-packages/keras/engine/keras_tensor.py", line 254, in __array__
raise TypeError(
TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.float32, name=None), name='Placeholder:0', description="created by layer 'tf.cast_2'"), an intermediate Keras symbolic input/output, to a TF API that does not allow registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras Functional model construction only supports TF API calls that *do* support dispatching, such as `tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs. You can work around this limitation by putting the operation in a custom Keras layer `call` and calling that layer on this symbolic input/output.