Just as the title says, I keep running into an error when following a tutorial to make a reinforcement learning agent using keras RL. The code of which is below:
import gym
import random
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Convolution2D
from tensorflow.keras.optimizers import Adam
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
def build_model(height, width, channels, actions):
model = Sequential()
model.add(Convolution2D(32, (8,8), strides=(4,4), activation='relu', input_shape=(3,height, width, channels)))
model.add(Convolution2D(64, (4,4), strides=(2,2), activation='relu'))
model.add(Convolution2D(64, (3,3), activation='relu'))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dense(actions, activation='linear'))
return model
def build_agent(model, actions):
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.2, nb_steps=10000)
memory = SequentialMemory(limit=1000, window_length=3)
dqn = DQNAgent(model=model, memory=memory, policy=policy,
enable_dueling_network=True, dueling_type='avg',
nb_actions=actions, nb_steps_warmup=1000
)
return dqn
env = gym.make('SpaceInvaders-v4')
height, width, channels = env.observation_space.shape
actions = env.action_space.n
env.unwrapped.get_action_meanings()
model = build_model(height, width, channels, actions)
model.summary()
dqn = build_agent(model, actions)
dqn.compile(Adam(learning_rate=1e-4))
dqn.fit(env, 10000, visualize=False, verbose=2)
The code seems to be fine up until the fit code, which gives the following error: AttributeError: 'int' object has no attribute 'shape'
Which the code for which is stored within keras-RL itself I believe, being:
def zeroed_observation(observation):
"""Return an array of zeros with same shape as given observation
# Argument
observation (list): List of observation
# Return
A np.ndarray of zeros with observation.shape
"""
if hasattr(observation, 'shape'):
return np.zeros(observation.shape)
if isinstance(observation, dict):
keys = observation.keys()
obs = dict()
for key in keys:
obs[key] = np.zeros(observation[key].shape) <--- This line is the problem
return obs
elif hasattr(observation, '__iter__'):
out = []
for x in observation:
out.append(zeroed_observation(x))
return out
else:
return 0.
Im not entirely sure what it is Im doing wrong, as Im following said tutorial to a T and nothing is working. Im using python 3.10, tensorflow 2.11.0, keras-rl2 1.0.4 and gym[atari]==0.18.0
Ive tried modifying the dqn.fit line to contain more possibly missing atributes, but that doesnt seem to work dqn.fit(env, 10000, 1, None, 1, False, 0, None, 10000,None) <--- The attempted fix, im aware not all the values are the same but still, same error no matter changing them
Aside from this, Im not sure what Im doing wrong, as Im a newcomer to keras-rl, and its self taught for the most part.
Full error:
Traceback (most recent call last):
File "C:\Users\Thomas Burns\Desktop\Python File Attempt\test Bank Heist Model.py", line 43, in <module>
dqn.fit(env, nb_steps = 10000, visualize=False, verbose=0)
File "C:\Users\Thomas Burns\AppData\Roaming\Python\Python310\site-packages\rl\core.py", line 168, in fit
action = self.forward(observation)
File "C:\Users\Thomas Burns\AppData\Roaming\Python\Python310\site-packages\rl\agents\dqn.py", line 223, in forward
state = self.memory.get_recent_state(observation)
File "C:\Users\Thomas Burns\AppData\Roaming\Python\Python310\site-packages\rl\memory.py", line 107, in get_recent_state
state.insert(0, zeroed_observation(state[0]))
File "C:\Users\Thomas Burns\AppData\Roaming\Python\Python310\site-packages\rl\memory.py", line 63, in zeroed_observation
out.append(zeroed_observation(x))
File "C:\Users\Thomas Burns\AppData\Roaming\Python\Python310\site-packages\rl\memory.py", line 58, in zeroed_observation
obs[key] = np.zeros(observation[key].shape)
AttributeError: 'int' object has no attribute 'shape'