
Here is a minimal example that reproduces the problem:

```python
from keras.models import Sequential
from keras.layers import Dense, Flatten, LeakyReLU
from keras.regularizers import l1
from rl.agents.dqn import DQNAgent  # unused here, but this import alone triggers the error

reg = l1(1e-5)     # L1 penalty applied to the Dense kernels
relu_alpha = 0.01  # negative slope of LeakyReLU

model = Sequential()

model.add(Flatten(input_shape=(128, 10, 20)))
model.add(Dense(16, kernel_regularizer=reg))
model.add(LeakyReLU(alpha=relu_alpha))
model.add(Dense(3, activation="linear", kernel_regularizer=reg))

model.compile(loss="mse", jit_compile=True)
```

The result is the following error:

```text
  File "C:\Users\lbosc\anaconda3\envs\ml_env3\lib\site-packages\keras\engine\training_v1.py", line 306, in compile
    raise TypeError(
TypeError: Invalid keyword argument(s) in `compile`: {'jit_compile'}
```

I found that the problem is the line `from rl.agents.dqn import DQNAgent`, which effectively replaces the `compile` method of `Sequential` with a different one: if I delete that import, compilation finishes correctly. The `compile` method of `keras.models.Sequential` is defined in the library file `training.py` and accepts the `jit_compile` parameter, whereas the path used by `rl.agents.dqn.DQNAgent` ends up calling the Keras `compile` method defined in `training_v1.py`, which is slightly different and does not accept `jit_compile`.
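To check which of the two implementations a model actually ends up with, one can inspect its bound `compile` method. The sketch below is plain Python using the standard `inspect` module; `describe_compile` is a hypothetical helper, and the two classes are stand-ins for the `training.py` and `training_v1.py` models, not the real Keras classes. It treats `jit_compile` as accepted only when it is an explicit parameter, because a `**kwargs` catch-all can still reject it at call time, as the traceback above shows.

```python
import inspect


def describe_compile(model):
    """Report where model.compile is defined and whether it names jit_compile."""
    method = model.compile
    func = getattr(method, "__func__", method)  # unwrap a bound method
    params = inspect.signature(method).parameters  # excludes `self`
    return {
        "module": func.__module__,
        "qualname": func.__qualname__,
        # Only an explicit parameter counts; **kwargs may still raise TypeError.
        "accepts_jit_compile": "jit_compile" in params,
    }


# Hypothetical stand-ins for the two Keras implementations (not the real classes).
class ModelFromTraining:
    def compile(self, optimizer="rmsprop", loss=None, jit_compile=None, **kwargs):
        return "compiled"


class ModelFromTrainingV1:
    def compile(self, optimizer="rmsprop", loss=None, **kwargs):
        if kwargs:
            raise TypeError(f"Invalid keyword argument(s) in `compile`: {set(kwargs)}")
        return "compiled"


for model in (ModelFromTraining(), ModelFromTrainingV1()):
    info = describe_compile(model)
    print(info["qualname"], "accepts jit_compile:", info["accepts_jit_compile"])
# ModelFromTraining.compile accepts jit_compile: True
# ModelFromTrainingV1.compile accepts jit_compile: False
```

On the Keras version from the traceback, calling `describe_compile(model)` on the real model should report a module such as `keras.engine.training_v1` when `DQNAgent` has been imported and `keras.engine.training` when it has not; that expectation is based only on the file paths in the traceback.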

Since I need `DQNAgent` later in the code, my question is: how do I force keras-rl2 to use `training.py` instead of `training_v1.py`? Alternatively, how do I tell my code which `compile` method to use?

Luca

1 Answer


I'm posting the solution I eventually found, even though I don't fully understand why it works. The `__init__.py` file of `rl.agents` imports `rl.agents.ddpg`; in that module, replace the following call:

```python
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
```

with:

```python
from tensorflow.python.framework.ops import disable_eager_execution
disable_eager_execution()
```

and that's it: keras-rl starts using TensorFlow's v2 functions instead of the v1 compatibility ones, and the rest of my code, which is v2-compatible, no longer runs into conflicts.

Luca