0

I am learning Reinforcement Learning in Python with Stable Baselines 3 referencing a tutorial by sentdex. The problem when I run check the code using check_env()I get an error AssertionError: The observation returned by the reset() method does not match the given observation space. Clearly I don't know what wrong with return in the reset method.

Here is the code:

import gym
from gym import spaces
import numpy as np
import cv2
import random
import time
from collections import deque

SNAKE_LEN_GOAL = 30


def collision_with_apple(apple_position, score):
    apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
    score += 1
    return apple_position, score


def collision_with_boundaries(snake_head):
    if snake_head[0] >= 500 or snake_head[0] < 0 or snake_head[1] >= 500 or snake_head[1] < 0:
        return 1
    else:
        return 0


def collision_with_self(snake_position):
    snake_head = snake_position[0]
    if snake_head in snake_position[1:]:
        return 1
    else:
        return 0


class SnekEnv(gym.Env):

    def __init__(self):
        super(SnekEnv, self).__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)  # however long we aspire the snake to be
        self.action_space = spaces.Discrete(4)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=-500, high=500,
                                            shape=(5 + SNAKE_LEN_GOAL,), dtype=np.float32)

    def step(self, action):
        self.prev_actions.append(action)
        cv2.imshow('a', self.img)
        cv2.waitKey(1)
        self.img = np.zeros((500, 500, 3), dtype='uint8')
        # Display Apple
        cv2.rectangle(self.img, (self.apple_position[0], self.apple_position[1]),
                      (self.apple_position[0] + 10, self.apple_position[1] + 10), (0, 0, 255), 3)
        # Display Snake
        for position in self.snake_position:
            cv2.rectangle(self.img, (position[0], position[1]), (position[0] + 10, position[1] + 10), (0, 255, 0), 3)

        # Takes step after fixed time
        t_end = time.time() + 0.05
        k = -1
        while time.time() < t_end:
            if k == -1:
                k = cv2.waitKey(1)
            else:
                continue

        button_direction = action
        # Change the head position based on the button direction
        if button_direction == 1:
            self.snake_head[0] += 10
        elif button_direction == 0:
            self.snake_head[0] -= 10
        elif button_direction == 2:
            self.snake_head[1] += 10
        elif button_direction == 3:
            self.snake_head[1] -= 10

        # Increase Snake length on eating apple
        if self.snake_head == self.apple_position:
            self.apple_position, self.score = collision_with_apple(self.apple_position, self.score)
            self.snake_position.insert(0, list(self.snake_head))

        else:
            self.snake_position.insert(0, list(self.snake_head))
            self.snake_position.pop()

        # On collision kill the snake and print the score
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            font = cv2.FONT_HERSHEY_SIMPLEX
            self.img = np.zeros((500, 500, 3), dtype='uint8')
            cv2.putText(self.img, 'Your Score is {}'.format(self.score), (140, 250), font, 1, (255, 255, 255), 2,
                        cv2.LINE_AA)
            cv2.imshow('a', self.img)
            self.done = True

        #self.total_reward = len(self.snake_position) - 3  # default length is 3
        #self.reward = self.total_reward - self.prev_reward
        #self.prev_reward = self.total_reward

        if self.done:
            self.reward = -10
        else:
            self.reward = self.score

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]

        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y

        snake_length = len(self.snake_position)

        self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)
        for _ in range(SNAKE_LEN_GOAL):
            self.prev_actions(-1)

        # create observation:
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation)

        info = {}
        return observation, self.reward, self.done, info

    def reset(self):
        self.img = np.zeros((500, 500, 3), dtype='uint8')
        # Initial Snake and Apple position
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
        self.score = 0
        self.prev_button_direction = 1
        self.button_direction = 1
        self.snake_head = [250, 250]

        self.prev_reward = 0

        self.done = False

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]

        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y
        snake_length = len(self.snake_position)

        for i in range(SNAKE_LEN_GOAL):
            self.prev_actions.append(-1)  # to create history

        # create observation:
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation)

        return observation


Checking the environment.

from stable_baselines3.common.env_checker import check_env
from snake_python_game_Env import SnekEnv


env = SnekEnv()
# It will check your custom environment and output additional warnings if needed
check_env(env)

Error.

Traceback (most recent call last):
  File "C:\Users\This PC\PycharmProjects\pythonProject\snake_python_game_agent.py", line 7, in <module>
    check_env(env)
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 302, in check_env
    _check_returned_values(env, observation_space, action_space)
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 159, in _check_returned_values
    _check_obs(obs, observation_space, "reset")
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 112, in _check_obs
    assert observation_space.contains(
AssertionError: The observation returned by the `reset()` method does not match the given observation space

The code according to the tutorial is supposed to run but on my side it's not.

john mugi
  • 13
  • 4

1 Answers1

0

I think you should change the line where you define your observation space:

self.observation_space = spaces.Box(low=-500, high=500,
                                        shape=(5 + SNAKE_LEN_GOAL,), dtype=int)

Here I change the data type into which the observation space will have, which seems to be integer valued arrays in your case. When I tried this locally it gave a different error though: TypeError: 'collections.deque' object is not callable which is thrown from step function of the environment. Hope this helps.