I am learning Reinforcement Learning in Python with Stable Baselines 3, following a tutorial by sentdex. The problem is that when I check my custom environment with check_env(), I get this error:

AssertionError: The observation returned by the `reset()` method does not match the given observation space

Clearly I don't understand what is wrong with the return value of the reset() method.
Here is the code:
import gym
from gym import spaces
import numpy as np
import cv2
import random
import time
from collections import deque
SNAKE_LEN_GOAL = 30

def collision_with_apple(apple_position, score):
    apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
    score += 1
    return apple_position, score

def collision_with_boundaries(snake_head):
    if snake_head[0] >= 500 or snake_head[0] < 0 or snake_head[1] >= 500 or snake_head[1] < 0:
        return 1
    else:
        return 0

def collision_with_self(snake_position):
    snake_head = snake_position[0]
    if snake_head in snake_position[1:]:
        return 1
    else:
        return 0

class SnekEnv(gym.Env):
    def __init__(self):
        super(SnekEnv, self).__init__()
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions:
        self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)  # however long we aspire the snake to be
        self.action_space = spaces.Discrete(4)
        # Example for using image as input (channel-first; channel-last also works):
        self.observation_space = spaces.Box(low=-500, high=500,
                                            shape=(5 + SNAKE_LEN_GOAL,), dtype=np.float32)

    def step(self, action):
        self.prev_actions.append(action)
        cv2.imshow('a', self.img)
        cv2.waitKey(1)
        self.img = np.zeros((500, 500, 3), dtype='uint8')
        # Display Apple
        cv2.rectangle(self.img, (self.apple_position[0], self.apple_position[1]),
                      (self.apple_position[0] + 10, self.apple_position[1] + 10), (0, 0, 255), 3)
        # Display Snake
        for position in self.snake_position:
            cv2.rectangle(self.img, (position[0], position[1]), (position[0] + 10, position[1] + 10), (0, 255, 0), 3)

        # Takes step after fixed time
        t_end = time.time() + 0.05
        k = -1
        while time.time() < t_end:
            if k == -1:
                k = cv2.waitKey(1)
            else:
                continue

        button_direction = action
        # Change the head position based on the button direction
        if button_direction == 1:
            self.snake_head[0] += 10
        elif button_direction == 0:
            self.snake_head[0] -= 10
        elif button_direction == 2:
            self.snake_head[1] += 10
        elif button_direction == 3:
            self.snake_head[1] -= 10

        # Increase Snake length on eating apple
        if self.snake_head == self.apple_position:
            self.apple_position, self.score = collision_with_apple(self.apple_position, self.score)
            self.snake_position.insert(0, list(self.snake_head))
        else:
            self.snake_position.insert(0, list(self.snake_head))
            self.snake_position.pop()

        # On collision kill the snake and print the score
        if collision_with_boundaries(self.snake_head) == 1 or collision_with_self(self.snake_position) == 1:
            font = cv2.FONT_HERSHEY_SIMPLEX
            self.img = np.zeros((500, 500, 3), dtype='uint8')
            cv2.putText(self.img, 'Your Score is {}'.format(self.score), (140, 250), font, 1, (255, 255, 255), 2,
                        cv2.LINE_AA)
            cv2.imshow('a', self.img)
            self.done = True

        # self.total_reward = len(self.snake_position) - 3  # default length is 3
        # self.reward = self.total_reward - self.prev_reward
        # self.prev_reward = self.total_reward

        if self.done:
            self.reward = -10
        else:
            self.reward = self.score

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]
        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y
        snake_length = len(self.snake_position)
        self.prev_actions = deque(maxlen=SNAKE_LEN_GOAL)
        for _ in range(SNAKE_LEN_GOAL):
            self.prev_actions.append(-1)
        # create observation:
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation)

        info = {}
        return observation, self.reward, self.done, info

    def reset(self):
        self.img = np.zeros((500, 500, 3), dtype='uint8')
        # Initial Snake and Apple position
        self.snake_position = [[250, 250], [240, 250], [230, 250]]
        self.apple_position = [random.randrange(1, 50) * 10, random.randrange(1, 50) * 10]
        self.score = 0
        self.prev_button_direction = 1
        self.button_direction = 1
        self.snake_head = [250, 250]

        self.prev_reward = 0
        self.done = False

        head_x = self.snake_head[0]
        head_y = self.snake_head[1]
        apple_delta_x = self.apple_position[0] - head_x
        apple_delta_y = self.apple_position[1] - head_y
        snake_length = len(self.snake_position)

        for i in range(SNAKE_LEN_GOAL):
            self.prev_actions.append(-1)  # to create history

        # create observation:
        observation = [head_x, head_y, apple_delta_x, apple_delta_y, snake_length] + list(self.prev_actions)
        observation = np.array(observation)
        return observation
Checking the environment:
from stable_baselines3.common.env_checker import check_env
from snake_python_game_Env import SnekEnv
env = SnekEnv()
# It will check your custom environment and output additional warnings if needed
check_env(env)
The error:
Traceback (most recent call last):
  File "C:\Users\This PC\PycharmProjects\pythonProject\snake_python_game_agent.py", line 7, in <module>
    check_env(env)
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 302, in check_env
    _check_returned_values(env, observation_space, action_space)
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 159, in _check_returned_values
    _check_obs(obs, observation_space, "reset")
  File "C:\Users\This PC\AppData\Local\Programs\Python\Python38\lib\site-packages\stable_baselines3\common\env_checker.py", line 112, in _check_obs
    assert observation_space.contains(
AssertionError: The observation returned by the `reset()` method does not match the given observation space
According to the tutorial this code is supposed to run, but on my side it does not. What is wrong with the observation returned by reset()?
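
In case it helps, this is a small sanity check I ran next to check_env() to compare what reset() actually returns with the space declared in __init__ (just my own quick diagnostic, not part of the tutorial; contains() is the same call the assertion in the traceback comes from):

from snake_python_game_Env import SnekEnv

env = SnekEnv()
obs = env.reset()

print(env.observation_space)                # the Box declared in __init__
print(obs.shape, obs.dtype)                 # shape and dtype of what reset() returns
print(env.observation_space.contains(obs))  # the same check the env_checker asserts on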