Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: gym v26 migration #224

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,7 @@ venv.bak/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
dmypy.json

# gym results
results/
2 changes: 1 addition & 1 deletion games/abstract_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class AbstractGame(ABC):
"""

@abstractmethod
def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
pass

@abstractmethod
Expand Down
12 changes: 7 additions & 5 deletions games/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
self.env = gym.make("Breakout-v4")
def __init__(self, seed=None, render_mode=None):
self.env = gym.make("Breakout-v4", render_mode=render_mode)
if seed is not None:
self.env.seed(seed)
self.env.reset(seed=seed)

def step(self, action):
"""
Expand All @@ -153,7 +153,9 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated

observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA)
observation = numpy.asarray(observation, dtype="float32") / 255.0
observation = numpy.moveaxis(observation, -1, 0)
Expand All @@ -179,7 +181,7 @@ def reset(self):
Returns:
Initial observation of the game.
"""
observation = self.env.reset()
observation, _ = self.env.reset()
observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA)
observation = numpy.asarray(observation, dtype="float32") / 255.0
observation = numpy.moveaxis(observation, -1, 0)
Expand Down
12 changes: 7 additions & 5 deletions games/breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
self.env = gym.make("Breakout-v4")
def __init__(self, seed=None, render_mode=None):
self.env = gym.make("Breakout-v4", render_mode=render_mode)
if seed is not None:
self.env.seed(seed)
self.env.reset(seed=seed)

def step(self, action):
"""
Expand All @@ -153,7 +153,9 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated

observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA)
observation = numpy.asarray(observation, dtype="float32") / 255.0
observation = numpy.moveaxis(observation, -1, 0)
Expand All @@ -179,7 +181,7 @@ def reset(self):
Returns:
Initial observation of the game.
"""
observation = self.env.reset()
observation, _ = self.env.reset()
observation = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA)
observation = numpy.asarray(observation, dtype="float32") / 255.0
observation = numpy.moveaxis(observation, -1, 0)
Expand Down
12 changes: 7 additions & 5 deletions games/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
self.env = gym.make("CartPole-v1")
def __init__(self, seed=None, render_mode=None):
self.env = gym.make("CartPole-v1", render_mode=render_mode)
if seed is not None:
self.env.seed(seed)
self.env.reset(seed=seed)

def step(self, action):
"""
Expand All @@ -148,7 +148,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return numpy.array([[observation]]), reward, done

def legal_actions(self):
Expand All @@ -171,7 +172,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return numpy.array([[self.env.reset()]])
observation, _ = self.env.reset()
return numpy.array([[observation]])

def close(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/connect4.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = Connect4()

def step(self, action):
Expand All @@ -140,7 +140,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return observation, reward * 10, done

def to_play(self):
Expand Down Expand Up @@ -172,7 +173,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return self.env.reset()
observation, _ = self.env.reset()
return observation

def render(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/gomoku.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = Gomoku()

def step(self, action):
Expand All @@ -146,7 +146,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return observation, reward, done

def to_play(self):
Expand Down Expand Up @@ -178,7 +179,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return self.env.reset()
observation, _ = self.env.reset()
return observation

def close(self):
"""
Expand Down
12 changes: 7 additions & 5 deletions games/gridworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
self.env = gym.make("MiniGrid-Empty-Random-6x6-v0")
def __init__(self, seed=None, render_mode=None):
self.env = gym.make("MiniGrid-Empty-Random-6x6-v0", render_mode=render_mode)
self.env = gym_minigrid.wrappers.ImgObsWrapper(self.env)
if seed is not None:
self.env.seed(seed)
self.env.reset(seed=seed)

def step(self, action):
"""
Expand All @@ -154,7 +154,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return numpy.array(observation), reward, done

def legal_actions(self):
Expand All @@ -177,7 +178,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return numpy.array(self.env.reset())
observation, _ = self.env.reset()
return numpy.array(observation)

def close(self):
"""
Expand Down
10 changes: 6 additions & 4 deletions games/lunarlander.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = DeterministicLunarLander()
# self.env = gym.make("LunarLander-v2")
if seed is not None:
self.env.seed(seed)
self.env.reset(seed=seed)

def step(self, action):
"""
Expand All @@ -145,7 +145,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done, _ = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return numpy.array([[observation]]), reward / 3, done

def legal_actions(self):
Expand All @@ -168,7 +169,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return numpy.array([[self.env.reset()]])
observation, _ = self.env.reset()
return numpy.array([[observation]])

def close(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/simple_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = GridEnv()

def step(self, action):
Expand All @@ -140,7 +140,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return [[observation]], reward * 10, done

def legal_actions(self):
Expand All @@ -163,7 +164,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return [[self.env.reset()]]
observation, _ = self.env.reset()
return [[observation]]

def render(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/spiel.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = Spiel()

def step(self, action):
Expand All @@ -158,7 +158,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return observation, reward * 20, done

def to_play(self):
Expand Down Expand Up @@ -190,7 +191,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return self.env.reset()
observation, _ = self.env.reset()
return observation

def render(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = TicTacToe()

def step(self, action):
Expand All @@ -140,7 +140,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return observation, reward * 20, done

def to_play(self):
Expand Down Expand Up @@ -172,7 +173,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return self.env.reset()
observation, _ = self.env.reset()
return observation

def render(self):
"""
Expand Down
8 changes: 5 additions & 3 deletions games/twentyone.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Game(AbstractGame):
Game wrapper.
"""

def __init__(self, seed=None):
def __init__(self, seed=None, render_mode=None):
self.env = TwentyOne(seed)

def step(self, action):
Expand All @@ -152,7 +152,8 @@ def step(self, action):
Returns:
The new observation, the reward and a boolean if the game has ended.
"""
observation, reward, done = self.env.step(action)
observation, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
return observation, reward * 10, done

def to_play(self):
Expand Down Expand Up @@ -184,7 +185,8 @@ def reset(self):
Returns:
Initial observation of the game.
"""
return self.env.reset()
observation, _ = self.env.reset()
return observation

def render(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def load_model(self, checkpoint_path=None, replay_buffer_path=None):
"""
# Load checkpoint
if checkpoint_path:
checkpoint_path = pathlib.Path(checkpoint_path)
checkpoint_path = pathlib.Path(checkpoint_path).absolute()
self.checkpoint = torch.load(checkpoint_path)
print(f"\nUsing checkpoint from {checkpoint_path}")

Expand Down