diff --git a/games/spiel.py b/games/spiel.py
index 7bb3a194..9607f3e6 100644
--- a/games/spiel.py
+++ b/games/spiel.py
@@ -1,8 +1,10 @@
 import datetime
 import pathlib
+from typing import Tuple
 
 import numpy
 import torch
+from pyspiel import Game, SpielError
 
 from .abstract_game import AbstractGame
 
@@ -19,8 +21,23 @@
         "You need to install open_spiel by running pip install open_spiel. For a full documentation, see: https://github.com/deepmind/open_spiel/blob/master/docs/install.md"
     )
 
+
+def get_observation_tensor_shape(game: Game) -> Tuple:
+    # Dimensions of the game observation, must be 3D (channel, height, width). Shorter shapes are left-padded with 1s, e.g. a 1D array of length n becomes (1, 1, n).
+    # If the spiel game does not implement observation_tensor_shape(), a default shape of (1, 1, 1) is returned.
+    try:
+        shape = game.observation_tensor_shape()
+    except SpielError:
+        print("ObservationTensorShape unimplemented. Returning default tensor shape (1, 1, 1)")
+        return 1, 1, 1
+
+    for _ in range(3 - len(shape)):
+        shape.insert(0, 1)
+    return tuple(shape)
+
+
 # The game you want to run. See https://github.com/deepmind/open_spiel/blob/master/docs/games.md for a list of games
-game = pyspiel.load_game("tic_tac_toe")
+game = pyspiel.load_game("backgammon")
 
 
 class MuZeroConfig:
@@ -34,9 +51,8 @@ def __init__(self):
         self.max_num_gpus = None  # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
 
 
-        ### Game
-        self.observation_shape = tuple(self.game.observation_tensor_shape())  # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
+        self.observation_shape = get_observation_tensor_shape(self.game)
         self.action_space = list(range(self.game.policy_tensor_shape()[0]))  # Fixed list of all possible actions. You should only edit the length
         self.players = list(range(self.game.num_players()))  # List of players. You should only edit the length
         self.stacked_observations = 0  # Number of previous observations and previous actions to add to the current observation
 
@@ -270,14 +286,17 @@ def get_observation(self):
         else:
             current_player = 0
         return numpy.array(self.board.observation_tensor(current_player)).reshape(
-            self.game.observation_tensor_shape()
+            get_observation_tensor_shape(self.game)
         )
 
     def legal_actions(self):
         return self.board.legal_actions()
 
     def have_winner(self):
-        rewards = self.board.rewards()
+        try:
+            rewards = self.board.rewards()
+        except SpielError:
+            return False
         if self.player == 1:
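
Note: a minimal sketch of the normalization the new get_observation_tensor_shape helper performs, assuming a game whose observation_tensor_shape() returns a 1D shape (the length 200 below is an arbitrary placeholder, not taken from a real game):

    # Illustrative only; mirrors the left-padding loop added in the patch above.
    shape = [200]                        # e.g. a 1D observation tensor shape
    for _ in range(3 - len(shape)):
        shape.insert(0, 1)               # pad missing leading dimensions with 1s
    assert tuple(shape) == (1, 1, 200)   # now 3D as (channel, height, width)

If observation_tensor_shape() itself is unimplemented and raises SpielError, the helper instead falls back to the default shape (1, 1, 1), as the patch shows.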