diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
index ff4c89b4d..9f2cf02cb 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
@@ -70,20 +70,19 @@ end
 
 RLBase.players(::TicTacToeEnv) = (Player(:Cross), Player(:Nought))
 
-RLBase.state(env::TicTacToeEnv, ::Observation, ::DefaultPlayer) = state(env, Observation{Int}(), Player(:Any))
-RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, player) = env.board
-RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), Player(1))
-RLBase.state(env::TicTacToeEnv, ::Observation{Int}, player::Player) =
+RLBase.state(env::TicTacToeEnv, o::Observation, ::RLBase.AbstractPlayer) = state(env, o)
+RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}())
+RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}) = env.board
+RLBase.state(env::TicTacToeEnv, ::Observation{Int}) =
     get_tic_tac_toe_state_info()[env].index
 
-RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, player::Player) = ArrayProductDomain(fill(false:true, 3, 3, 3))
-RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, player::Player) =
+RLBase.state_space(env::TicTacToeEnv, o::Observation, ::RLBase.AbstractPlayer) = state_space(env, o)
+RLBase.state_space(::TicTacToeEnv, ::Observation{BitArray{3}}) = ArrayProductDomain(fill(false:true, 3, 3, 3))
+RLBase.state_space(::TicTacToeEnv, ::Observation{Int}) =
     Base.OneTo(length(get_tic_tac_toe_state_info()))
-RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, player::Player) = fullspace(String)
+RLBase.state_space(::TicTacToeEnv, ::Observation{String}) = fullspace(String)
 
-RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), Player(1))
-
-function RLBase.state(env::TicTacToeEnv, ::Observation{String}, player::Player)
+function RLBase.state(env::TicTacToeEnv, ::Observation{String})
     buff = IOBuffer()
     for i in 1:3
         for j in 1:3
diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
index 0eca516ff..f5b15f289 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/examples/tic_tac_toe.jl
@@ -3,15 +3,15 @@
     using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore
 
     trajectory_1 = Trajectory(
-        CircularArraySARTSTraces(; capacity = 1),
+        CircularArraySARTSTraces(; capacity=1),
         BatchSampler(1),
-        InsertSampleRatioController(n_inserted = -1),
+        InsertSampleRatioController(n_inserted=-1),
     )
 
     trajectory_2 = Trajectory(
-        CircularArraySARTSTraces(; capacity = 1),
+        CircularArraySARTSTraces(; capacity=1),
         BatchSampler(1),
-        InsertSampleRatioController(n_inserted = -1),
+        InsertSampleRatioController(n_inserted=-1),
     )
 
     multiagent_policy = MultiAgentPolicy(PlayerTuple(
@@ -30,6 +30,7 @@
     @test length(state_space(env, Observation{Int}())) == 5478
 
     @test RLBase.state(env, Observation{BitArray{3}}(), Player(:Cross)) == env.board
+    @test RLBase.state(env, Observation{BitArray{3}}()) == env.board
     @test RLBase.state_space(env, Observation{BitArray{3}}(), Player(:Cross)) isa ArrayProductDomain
     @test RLBase.state_space(env, Observation{String}(), Player(:Cross)) isa DomainSets.FullSpace{String}
     @test RLBase.state(env, Observation{String}(), Player(:Cross)) isa String
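Note (illustrative sketch, not part of the patch): after this change the player argument to state/state_space on TicTacToeEnv is optional. A TicTacToe board looks the same to both players, so the player-qualified methods now accept any RLBase.AbstractPlayer and simply forward to the player-free methods. A minimal usage sketch, assuming the standard TicTacToeEnv() constructor:

using ReinforcementLearningEnvironments, ReinforcementLearningBase, ReinforcementLearningCore

env = TicTacToeEnv()

# Both calls dispatch to the same player-free method, so they must agree.
s = RLBase.state(env, Observation{Int}())   # compact state index
@assert s == RLBase.state(env, Observation{Int}(), Player(:Cross))

# The BitArray{3} observation is the raw 3x3x3 board encoding itself.
@assert RLBase.state(env, Observation{BitArray{3}}()) == env.board

# The Int state space enumerates all reachable board configurations (5478, per the test above).
@assert RLBase.state_space(env, Observation{Int}()) == Base.OneTo(5478)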