For the past day I've been trying to deal with an error in the `DQNAgent.fit` function. I get the following error:

ValueError: Error when checking input: expected dense_input to have 2 dimensions, but got array with shape (1, 3, 4)

I'm using a custom Flappy Bird env with my own custom state to train a `DQNAgent`. The error is raised in

DQN_flappy\venv\lib\site-packages\rl\core.py", line 168, in fit
    action = self.forward(observation)

in case that means something to someone here. It looks like keras-rl adds another dimension somewhere, and I can't figure out why; I assume the mistake is somewhere in my own code. Here's the env code (`Game()` is the game itself; I tested it manually and also built a NEAT project with it, so I'm almost certain it's not the problem):

import numpy as np
from gym import Env
from gym.spaces import Box, Discrete
# Game comes from my game module


class flappy_env(Env):

    def __init__(self):
        self.game = Game()
        # 4 normalized features: pipe x-distance, bottom/top pipe gaps, bird velocity
        self.observation_space = Box(low=np.array([-0.4, -2.0, -1.0, -1.0], dtype=np.float32),
                                     high=np.array([1.0, 2.0, 1.0, 0.5], dtype=np.float32))
        self.action_space = Discrete(2)

    def step(self, action):
        done, score, reward = self.game.play_step(action)
        state = self.game.get_state()
        info = {}
        return state, reward, done, info

    def render(self):
        pass

    def reset(self):
        self.game.reset_game()
        return self.game.get_state()

and the game.get_state():

def get_state(self):
    if len(self.pipe_group) > 0:
        bird_y_loc = self.flappy.rect.y
        x_dist_pipe_bird = self.pipe_group.sprites()[0].rect.left - self.flappy.rect.right
        bot_pipe_y_loc = self.pipe_group.sprites()[0].rect.top - bird_y_loc
        top_pipe_y_loc = self.pipe_group.sprites()[1].rect.bottom - bird_y_loc
        return np.array([x_dist_pipe_bird / 500, 10 * bot_pipe_y_loc / screen_height,
                         5 * top_pipe_y_loc / screen_height, self.flappy.vel / 35], dtype=np.float32)
    # shouldn't get here
    return None
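
For reference, a quick shape check on the env (a throwaway debug snippet, sketched here; it just samples a random action) shows the raw states are 1-D:

env = flappy_env()
print(env.reset().shape)  # (4,) - matches the observation_space above

state, reward, done, info = env.step(env.action_space.sample())
print(state.shape)  # (4,) here too, so the env itself never produces a 3-D state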

here's the model code in case it helps:

from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from env import flappy_env


def build_model():
    model = Sequential()
    model.add(Dense(16, input_shape=(4,), activation='relu'))
    model.add(Dense(16, activation='relu'))
    model.add(Dense(2, activation='linear'))
    return model


def build_agent(sequential_model):
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.2, nb_steps=10000)
    memory = SequentialMemory(limit=1000, window_length=3)
    dqn = DQNAgent(model=sequential_model, memory=memory, policy=policy,
                   enable_dueling_network=True, dueling_type='avg',
                   nb_actions=2, nb_steps_warmup=1000)
    return dqn


env = flappy_env()
model = build_model()
model.summary()
dqn = build_agent(model)
dqn.compile(Adam(learning_rate=1e-4))
dqn.fit(env, nb_steps=10000, visualize=False)
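
On its own, the model accepts a single batched state just fine (a quick check with a made-up dummy input, sketched below):

import numpy as np

dummy_state = np.zeros((1, 4), dtype=np.float32)  # a batch of one 4-feature state
print(model.predict_on_batch(dummy_state).shape)  # (1, 2): one Q-value per action

So the mismatch only appears once the agent is in the loop.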

I've been debugging this for many hours, and the only lead I've found is a line in `dqn.py` (from keras-rl2), `q_values = self.compute_batch_q_values([state]).flatten()`, which adds a new dimension, as you can see. I also noticed that the 3 in the `(1, 3, 4)` shape is my `window_length` from the agent memory (I sketch that suspicion right after the traceback below). I'm trying to include as much as needed, so here's the whole console output as well:

Training for 10000 steps ...
Interval 1 (0 steps performed)
Traceback (most recent call last):
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\model.py", line 33, in <module>
    dqn.fit(env, nb_steps=10000, visualize=False)
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\rl\core.py", line 168, in fit
    action = self.forward(observation)
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\rl\agents\dqn.py", line 224, in forward
    q_values = self.compute_q_values(state)
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\rl\agents\dqn.py", line 68, in compute_q_values
    q_values = self.compute_batch_q_values([state]).flatten()
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\rl\agents\dqn.py", line 63, in compute_batch_q_values
    q_values = self.model.predict_on_batch(batch)
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\keras\engine\training_v1.py", line 1305, in predict_on_batch
    inputs, _, _ = self._standardize_user_data(
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\keras\engine\training_v1.py", line 2652, in _standardize_user_data
    return self._standardize_tensors(
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\keras\engine\training_v1.py", line 2693, in _standardize_tensors
    x = training_utils_v1.standardize_input_data(
  File "C:\Users\kfir\PycharmProjects\DQN_flappy\venv\lib\site-packages\keras\engine\training_utils_v1.py", line 712, in standardize_input_data
    raise ValueError(
ValueError: Error when checking input: expected dense_input to have 2 dimensions, but got array with shape (1, 3, 4)
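
To illustrate my `window_length` suspicion (a minimal sketch with made-up values; I don't know keras-rl's internals for certain):

import numpy as np

obs = np.zeros(4, dtype=np.float32)     # one state from get_state(): shape (4,)
window = np.stack([obs, obs, obs])      # window_length=3 stacks the last 3 -> (3, 4)
batch = np.expand_dims(window, axis=0)  # the [state] batching -> (1, 3, 4)
print(batch.shape)                      # (1, 3, 4), exactly the shape in the error

If that's what's happening, the memory's windowed observation can never fit a model built with input_shape=(4,), which expects (None, 4); but I don't know which side I'm supposed to change.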

BTW, if more files are needed that I didn't supply, I'll add them. Please help me find the problem; I've tried everything I could think of. Huge thanks in advance to everyone trying to help or reading the question!

I expected that no extra dimensions would be added to the observations.

  • in the step function, can you check the shape of the `state` parameter? – Toyo Apr 25 '23 at 01:01
  • @Toyo Hey, the code doesn't reach the step function; I get those errors before it. I did check the state shape in the reset method, however, and got (4,), which is what it should be – kfir Apr 25 '23 at 08:07
  • could you post the complete source code so the error can be reproduced, please? – Toyo Apr 25 '23 at 08:25
  • sure, here's the full source code: [link](https://github.com/kfir7755/DQN_flappy_keras.git) thanks for the help! – kfir Apr 25 '23 at 12:03
  • that link is 404. – Toyo Apr 26 '23 at 00:38
  • sorry for that. I made the project public, so now you will be able to see it – kfir Apr 26 '23 at 00:44
