0

Here is my custom gym env..

class PricePredictor(gym.Env):

    def __init__(self):

        ...

        self.action_space = gym.spaces.Discrete(3,start=-1)

        self.observation_space = gym.spaces.Dict({
            'on_trade_price':gym.spaces.Box(low=0,high=1,shape=(1,)),
            'state':gym.spaces.Discrete(3,start=0),
            'prices':gym.spaces.Box(low=0,high=1,shape=(30,))
        })

      ...    
    def reset(self):
        
        ....

        return obs


    def act(self,action):
        ....


    def step(self,action):
       ....

        return self.get_obs(),self.rewards,False,{}


    def get_obs(self):
        obs = {}
        obs['prices'] = self.data
        obs['on_trade_price'] = self.on_trade_price
        obs['state'] = self.state
        return obs

And here is the neural network and agent:

tf.compat.v1.experimental.output_all_intermediates(True)

def build_model():

    dense = keras.layers.Dense
    conv = keras.layers.Conv1D
    maxpool = keras.layers.MaxPool1D
    dropout = keras.layers.Dropout
    flatten = keras.layers.Flatten
    lstm = keras.layers.LSTM

    prices_input = keras.layers.Input(shape=(1,33),name='prices')
    state_input = keras.layers.Input(shape=(1,),name='state')
    p_trade_input = keras.layers.Input(shape=(1,),name='on_trade_price')

    state_trade_input = keras.layers.concatenate([
    keras.layers.Reshape((1, 1))(state_input),
    keras.layers.Reshape((1, 1))(p_trade_input)], axis=-1)

    prices = conv(64,strides=1,kernel_size=2,padding='same',activation='relu')(prices_input)
    prices = conv(256,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = conv(128,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = maxpool(padding='same')(prices)
    prices = conv(128,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = conv(128,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = maxpool(padding='same')(prices)
    prices = conv(64,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = conv(64,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = conv(64,strides=1,kernel_size=2,padding='same',activation='relu')(prices)
    prices = maxpool(padding='same')(prices)
    prices = flatten()(prices)

    state_trade = lstm(8,return_sequences=True)(state_trade_input)
    state_trade = lstm(16,return_sequences=True)(state_trade)
    state_trade = lstm(16,return_sequences=True)(state_trade)
    state_trade = lstm(4,return_sequences=False)(state_trade)
    state_trade = flatten()(state_trade)

    main = keras.layers.concatenate([state_trade,prices])
    main = dense(64,activation='relu')(main)
    main = dense(64,activation='relu')(main)
    main = dense(32,activation='relu')(main)
    main = dense(16,activation='relu')(main)

    main = dense(9,activation='sigmoid')(main)
    output = dense(3,activation='sigmoid')(main)

    model = keras.models.Model(inputs=[prices_input,state_input,p_trade_input],outputs=output)

    return model




model = build_model()
memory = SequentialMemory(limit=100000,window_length=1)

agent = DQNAgent(model=model,memory=memory,policy=BoltzmannQPolicy(),nb_actions=3,nb_steps_warmup=10)
agent.compile(optimizer='sgd',metrics=['accuracy'])
agent.fit(env=env,nb_steps=10000,visualize=False)

The get obs and reset functions will return 3 inputs, 1 discrete and 2 box spaces as you can see.

But whenever I execute the code, I get this error:

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), for inputs ['prices', 'state', 'on_trade_price'] but instead got the following list of 1 arrays: [array([[{'prices': [1.0, 0.41346304723049343, 0.3527381139181216, 0.2611015220461468, 0.21434175427589253, 0.21449866624823244, 0.21449866624823244, 0.21449866624823244, 0.4032637690256138, 0.3387729...

I understood why I am getting this error. I want to know how to fix this!

James Z
  • 12,209
  • 10
  • 24
  • 44

0 Answers0