
I am trying to use the keras-rl2 DQNAgent to solve the Taxi problem in OpenAI Gym. For a quick refresher, see the Gym documentation: https://www.gymlibrary.dev/environments/toy_text/taxi/

Here is my process:
0. Create the Taxi-v3 environment from gym.
1. Build the deep learning model with the Keras Sequential API, using Embedding and Dense layers.
2. Import the epsilon-greedy policy (EpsGreedyQPolicy) and the sequential replay memory (SequentialMemory) from keras-rl2's rl package.
3. Pass the model, policy, and memory into rl.agents.DQNAgent and compile the agent.

But when I fit the agent, this error appears:

Training for 1000000 steps ...
Interval 1 (0 steps performed)

---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

<ipython-input-180-908ee27d8389> in <module>
      1 agent.compile(Adam(lr=0.001),metrics=['mae'])
----> 2 agent.fit(env, nb_steps=1000000, visualize=False, verbose=1, nb_max_episode_steps=99, log_interval=100000)

/usr/local/lib/python3.8/dist-packages/rl/core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
    179                         observation, r, done, info = self.processor.process_step(observation, r, done, info)
    180                     for key, value in info.items():
--> 181                         if not np.isreal(value):
    182                             continue
    183                         if key not in accumulated_info:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

When I run the same code on the CartPole problem, no error occurs. I wonder whether this is because the Taxi state is a single scalar (one of 500 discrete states), whereas the CartPole state is an array of 4 elements. Any help or advice would be much appreciated. It would also help if you could show me how to extend episodes beyond 200 steps (I tried env._max_episode_steps=5000).
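For reference, the Taxi-v3 observation is indeed a single integer rather than an array. According to the Gym Taxi documentation, it packs the taxi row and column (5 x 5 grid), the passenger location (5 values, including "in the taxi") and the destination (4 values) into ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination, which gives 5 * 5 * 5 * 4 = 500 states. Below is a minimal plain-Python sketch of that encoding and its inverse; the helper names taxi_encode and taxi_decode are hypothetical and not part of gym.

```python
def taxi_encode(taxi_row: int, taxi_col: int, passenger_location: int, destination: int) -> int:
    """Pack the four Taxi state components into one integer in range(500)."""
    return ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination


def taxi_decode(state: int) -> tuple:
    """Invert taxi_encode by peeling off each component, last one first."""
    destination = state % 4
    state //= 4
    passenger_location = state % 5
    state //= 5
    taxi_col = state % 5
    taxi_row = state // 5
    return taxi_row, taxi_col, passenger_location, destination


state = taxi_encode(3, 1, 2, 0)
print(state)               # 328
print(taxi_decode(state))  # (3, 1, 2, 0)
```

A scalar state like this is what the Embedding(states, 10, input_length=1) layer below is meant to consume, and the traceback points at the info loop rather than at the model input, so the state shape by itself does not appear to be the cause.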

#install dependencies (notebook shell commands; run these before the imports)
!pip install gym[classic_control]
!pip install gym
!pip install keras
!pip install keras-rl2

#import environment and visualization
import gym
from gym import wrappers

#import Deep Learning api
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Flatten, Input, Embedding, Reshape
from tensorflow.keras.optimizers import Adam

#data manipulation
import numpy as np
import pandas as pd
import random

#0 create the environment: Taxi-v3 has 500 discrete states and 6 discrete actions
env = gym.make('Taxi-v3')
env.reset()
actions = env.action_space.n
states = env.observation_space.n

#1 embed the integer state into a 10-dim vector, then map it to one Q-value per action
def build_model(states, actions):
  model = Sequential()
  model.add(Embedding(states, 10, input_length=1))
  model.add(Reshape((10,)))
  model.add(Dense(32, activation='relu'))
  model.add(Dense(32, activation='relu'))
  model.add(Dense(actions, activation='linear'))
  return model

model1 = build_model(states, actions)

#2 epsilon-greedy exploration and a replay memory holding the last 100000 transitions
import rl
from rl.agents import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=100000, window_length=1)

#3 wire everything into the DQN agent, compile, and train
agent = DQNAgent(model=model1, memory=memory, policy=policy, nb_actions=actions, nb_steps_warmup=500, target_model_update=1e-2)
agent.compile(Adam(learning_rate=0.001), metrics=['mae'])

agent.fit(env, nb_steps=1000000, visualize=False, verbose=1, nb_max_episode_steps=99, log_interval=100000)

Answer


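This ValueError comes from the way Keras RL handles the info dict returned by the environment. As you can see at https://github.com/keras-rl/keras-rl/blob/v0.4.2/rl/core.py#L181, it loops over each item of the info dict and evaluates not np.isreal(value). When value is a NumPy array, np.isreal returns an element-wise boolean array, and Python cannot reduce an array with more than one element to a single True/False, hence the error.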

Quoting the Gym documentation for Taxi:

In v0.25.0, info["action_mask"] contains a np.ndarray for each of the action specifying if the action will change the state.
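The failure can be reproduced with NumPy alone. The sketch below is a simplified stand-in for the keras-rl loop, not its actual code: accumulate_info is a hypothetical name, and the Taxi-style info dict is illustrative (the action_mask key comes from the documentation quoted above; the other entry and all values are made up). It also hints at why CartPole works: CartPole returns an empty info dict, so the loop body never runs.

```python
import numpy as np


def accumulate_info(info, accumulated):
    """Simplified sketch of the keras-rl loop around rl/core.py line 181.

    Real scalars are summed per key; anything np.isreal rejects is skipped.
    """
    for key, value in info.items():
        # For a scalar, np.isreal returns one bool. For an ndarray it returns
        # an element-wise bool array, and `not <array>` raises ValueError.
        if not np.isreal(value):
            continue
        accumulated[key] = accumulated.get(key, 0) + value
    return accumulated


# CartPole-style info: an empty dict, so the loop body never runs.
print(accumulate_info({}, {}))

# Illustrative Taxi-style info (values made up): a scalar plus a 6-element action mask.
taxi_like_info = {"prob": 1.0, "action_mask": np.array([1, 1, 0, 0, 0, 0], dtype=np.int8)}
try:
    accumulate_info(taxi_like_info, {})
except ValueError as err:
    print("ValueError:", err)  # same message as in the traceback
```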

Run gym.__version__ to check whether your version is greater than or equal to 0.25.0.
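A minimal sketch of that check, assuming simple dotted version strings such as "0.24.1" or "0.26.2". The helper name has_taxi_action_mask is hypothetical, and packaging.version would be a more robust parser for pre-release tags.

```python
import re


def has_taxi_action_mask(gym_version: str) -> bool:
    """Return True for gym >= 0.25.0, where Taxi-v3 puts an ndarray in info["action_mask"]."""
    # Take the first three numeric fields, e.g. "0.24.1" -> [0, 24, 1].
    parts = [int(p) for p in re.findall(r"\d+", gym_version)[:3]]
    parts += [0] * (3 - len(parts))  # pad "0.25" to [0, 25, 0]
    return tuple(parts) >= (0, 25, 0)


print(has_taxi_action_mask("0.24.1"))  # False
print(has_taxi_action_mask("0.26.2"))  # True
```

In the notebook, has_taxi_action_mask(gym.__version__) returning True means the keras-rl loop will hit the array-valued action_mask.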

To use the current Keras RL library (up to 0.4.2), install a gym version older than 0.25.0, for example with pip install "gym<0.25.0". Alternatively, you could submit a PR to keras-rl so that it handles np.ndarray values in info without raising an error.
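If downgrading gym is not an option, another workaround (untested against keras-rl2, so treat it as an assumption) is to strip array-valued entries from info before keras-rl inspects them. The traceback shows that fit() passes every step through self.processor.process_step(...). In keras-rl, the base rl.core.Processor class forwards info to a process_info(info) hook, so overriding that hook in a Processor subclass and passing an instance via DQNAgent(..., processor=...) should avoid the error. The filtering logic itself is plain NumPy; drop_array_info is a hypothetical name.

```python
import numpy as np


def drop_array_info(info):
    """Keep only 0-d (scalar) info entries, so np.isreal(value) yields a single bool."""
    return {key: value for key, value in info.items() if np.ndim(value) == 0}


# Intended use inside keras-rl (assumed API, left as comments so this block runs without keras-rl):
# from rl.core import Processor
# class DropArrayInfoProcessor(Processor):
#     def process_info(self, info):
#         return drop_array_info(info)
# agent = DQNAgent(..., processor=DropArrayInfoProcessor())

info = {"prob": 1.0, "action_mask": np.array([1, 1, 0, 0, 0, 0], dtype=np.int8)}
print(drop_array_info(info))
```

Dropping the mask loses nothing for training here, because keras-rl only reads info to accumulate logging metrics.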

tacon