diff --git a/src/ReinforcementLearningCore/Project.toml b/src/ReinforcementLearningCore/Project.toml
index 33515edb0..3c04fbbe2 100644
--- a/src/ReinforcementLearningCore/Project.toml
+++ b/src/ReinforcementLearningCore/Project.toml
@@ -51,10 +51,10 @@ cuDNN = "1"
 julia = "1.9"
 
 [extras]
+CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["DomainSets", "Test", "Random", "ReinforcementLearningEnvironments"]
+test = ["CommonRLInterface", "DomainSets", "Test", "Random"]
diff --git a/src/ReinforcementLearningCore/test/environments/randomwalk1D.jl b/src/ReinforcementLearningCore/test/environments/randomwalk1D.jl
new file mode 100644
index 000000000..7ee51fefc
--- /dev/null
+++ b/src/ReinforcementLearningCore/test/environments/randomwalk1D.jl
@@ -0,0 +1,66 @@
+import CommonRLInterface
+
+"""
+    RandomWalk1D(;rewards=-1.0 => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])
+
+An agent is placed at `start_pos` and can move left or right (the stride is
+defined in `actions`). The game terminates when the agent reaches either end,
+where it receives the corresponding reward.
+
+Compared to the [`MultiArmBanditsEnv`](@ref):
+
+1. The state space is more complicated (well, not that complicated though).
+1. It's a sequential game with multiple action steps.
+1. It's a deterministic game instead of a stochastic one.
+"""
+Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
+    rewards::Pair{Float64,Float64} = -1.0 => 1.0
+    N::Int = 7
+    actions::Vector{Int} = [-1, 1]
+    start_pos::Int = (N + 1) ÷ 2
+    pos::Int = start_pos
+
+    action_space::Base.OneTo = Base.OneTo(length(actions))
+    state_space::Base.OneTo = Base.OneTo(N)
+end
+
+RLBase.action_space(env::RandomWalk1D) = env.action_space
+
+function RLBase.act!(env::RandomWalk1D, action::Int)
+    env.pos += env.actions[action]
+    if env.pos > env.N
+        env.pos = env.N
+    elseif env.pos < 1
+        env.pos = 1
+    end
+    return
+end
+
+RLBase.state(env::RandomWalk1D) = env.pos
+RLBase.state_space(env::RandomWalk1D) = env.state_space
+RLBase.is_terminated(env::RandomWalk1D) = env.pos == 1 || env.pos == env.N
+RLBase.reset!(env::RandomWalk1D) = env.pos = env.start_pos
+
+RLBase.reward(env::RandomWalk1D) = random_walk_reward(env.pos, env.rewards, env.N)
+
+function random_walk_reward(pos::Int, rewards::Pair{Float64,Float64}, N::Int)
+    if pos == 1
+        return random_walk_reward_first(rewards)
+    elseif pos == N
+        return random_walk_reward_last(rewards)
+    else
+        return 0.0
+    end
+end
+
+random_walk_reward_first(rewards::Pair{Float64,Float64}) = first(rewards)
+random_walk_reward_last(rewards::Pair{Float64,Float64}) = last(rewards)
+
+RLBase.NumAgentStyle(::RandomWalk1D) = SINGLE_AGENT
+RLBase.DynamicStyle(::RandomWalk1D) = SEQUENTIAL
+RLBase.ActionStyle(::RandomWalk1D) = MINIMAL_ACTION_SET
+RLBase.InformationStyle(::RandomWalk1D) = PERFECT_INFORMATION
+RLBase.StateStyle(::RandomWalk1D) = Observation{Int}()
+RLBase.RewardStyle(::RandomWalk1D) = TERMINAL_REWARD
+RLBase.UtilityStyle(::RandomWalk1D) = GENERAL_SUM
+RLBase.ChanceStyle(::RandomWalk1D) = DETERMINISTIC
\ No newline at end of file
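Reviewer note (not part of the patch): a minimal sketch of how the `RandomWalk1D` environment above can be exercised through the RLBase API, assuming `using ReinforcementLearningBase` is in effect (as in `runtests.jl`), which also provides the `RLBase` alias used below.

```julia
env = RandomWalk1D()                       # defaults: N = 7, start_pos = 4
@assert RLBase.state(env) == 4
while !RLBase.is_terminated(env)
    a = rand(RLBase.action_space(env))     # 1 => step left, 2 => step right
    RLBase.act!(env, a)
end
@assert RLBase.reward(env) in (-1.0, 1.0)  # terminal reward at either end
RLBase.reset!(env)                         # back to start_pos
```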
diff --git a/src/ReinforcementLearningCore/test/environments/rockpaperscissors.jl b/src/ReinforcementLearningCore/test/environments/rockpaperscissors.jl
new file mode 100644
index 000000000..fb8853ab2
--- /dev/null
+++ b/src/ReinforcementLearningCore/test/environments/rockpaperscissors.jl
@@ -0,0 +1,63 @@
+"""
+    RockPaperScissorsEnv()
+
+[Rock Paper Scissors](https://en.wikipedia.org/wiki/Rock_paper_scissors) is a
+simultaneous, zero-sum game.
+"""
+Base.@kwdef mutable struct RockPaperScissorsEnv <: AbstractEnv
+    reward::NamedTuple{(Symbol(1), Symbol(2)), Tuple{Int64, Int64}} = (; Symbol(1) => 0, Symbol(2) => 0)
+    is_done::Bool = false
+end
+
+RLBase.players(::RockPaperScissorsEnv) = (Symbol(1), Symbol(2))
+
+"""
+Note that although this is a two-player game, the current player is always a
+dummy simultaneous player.
+"""
+RLBase.current_player(::RockPaperScissorsEnv) = SIMULTANEOUS_PLAYER
+
+RLBase.action_space(::RockPaperScissorsEnv, ::Symbol) = ('💎', '📃', '✂')
+
+RLBase.action_space(::RockPaperScissorsEnv, ::SimultaneousPlayer) =
+    Tuple((i, j) for i in ('💎', '📃', '✂') for j in ('💎', '📃', '✂'))
+
+RLBase.action_space(env::RockPaperScissorsEnv) = action_space(env, SIMULTANEOUS_PLAYER)
+
+RLBase.legal_action_space(env::RockPaperScissorsEnv, p) =
+    is_terminated(env) ? () : action_space(env, p)
+
+"Since it's a one-shot game, the state space doesn't have much meaning."
+RLBase.state_space(::RockPaperScissorsEnv, ::Observation, p) = Base.OneTo(1)
+
+"""
+For multi-agent environments, we usually implement the most detailed state representation.
+"""
+RLBase.state(::RockPaperScissorsEnv, ::Observation, p) = 1
+
+RLBase.reward(env::RockPaperScissorsEnv) = env.is_done ? env.reward : (; Symbol(1) => 0, Symbol(2) => 0)
+RLBase.reward(env::RockPaperScissorsEnv, p::Symbol) = reward(env)[p]
+
+RLBase.is_terminated(env::RockPaperScissorsEnv) = env.is_done
+RLBase.reset!(env::RockPaperScissorsEnv) = env.is_done = false
+
+# TODO: Consider using CRL.all_act! and adjusting run function accordingly
+function RLBase.act!(env::RockPaperScissorsEnv, (x, y))
+    if x == y
+        env.reward = (; Symbol(1) => 0, Symbol(2) => 0)
+    elseif x == '💎' && y == '✂' || x == '✂' && y == '📃' || x == '📃' && y == '💎'
+        env.reward = (; Symbol(1) => 1, Symbol(2) => -1)
+    else
+        env.reward = (; Symbol(1) => -1, Symbol(2) => 1)
+    end
+    env.is_done = true
+end
+
+RLBase.NumAgentStyle(::RockPaperScissorsEnv) = MultiAgent(2)
+RLBase.DynamicStyle(::RockPaperScissorsEnv) = SIMULTANEOUS
+RLBase.ActionStyle(::RockPaperScissorsEnv) = MINIMAL_ACTION_SET
+RLBase.InformationStyle(::RockPaperScissorsEnv) = IMPERFECT_INFORMATION
+RLBase.StateStyle(::RockPaperScissorsEnv) = Observation{Int}()
+RLBase.RewardStyle(::RockPaperScissorsEnv) = TERMINAL_REWARD
+RLBase.UtilityStyle(::RockPaperScissorsEnv) = ZERO_SUM
+RLBase.ChanceStyle(::RockPaperScissorsEnv) = DETERMINISTIC
\ No newline at end of file
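Reviewer note (not part of the patch): a sketch of one round of the simultaneous game defined above, assuming `RLBase` is in scope; both players' actions are submitted together as a tuple.

```julia
env = RockPaperScissorsEnv()
RLBase.act!(env, ('💎', '✂'))              # rock beats scissors
@assert RLBase.is_terminated(env)
@assert RLBase.reward(env, Symbol(1)) == 1
@assert RLBase.reward(env, Symbol(2)) == -1
RLBase.reset!(env)                         # ready for another round
```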
diff --git a/src/ReinforcementLearningCore/test/environments/tictactoe.jl b/src/ReinforcementLearningCore/test/environments/tictactoe.jl
new file mode 100644
index 000000000..0b21949b7
--- /dev/null
+++ b/src/ReinforcementLearningCore/test/environments/tictactoe.jl
@@ -0,0 +1,171 @@
+import ReinforcementLearningBase: RLBase
+import ReinforcementLearningCore: RLCore
+import CommonRLInterface
+
+mutable struct TicTacToeEnv <: AbstractEnv
+    board::BitArray{3}
+    player::Symbol
+end
+
+function TicTacToeEnv()
+    board = BitArray{3}(undef, 3, 3, 3)
+    fill!(board, false)
+    board[:, :, 1] .= true
+    TicTacToeEnv(board, :Cross)
+end
+
+function RLBase.reset!(env::TicTacToeEnv)
+    fill!(env.board, false)
+    env.board[:, :, 1] .= true
+    env.player = :Cross
+end
+
+struct TicTacToeInfo
+    is_terminated::Bool
+    winner::Union{Nothing,Symbol}
+end
+
+const TIC_TAC_TOE_STATE_INFO = Dict{
+    TicTacToeEnv,
+    NamedTuple{
+        (:index, :is_terminated, :winner),
+        Tuple{Int,Bool,Union{Nothing,Symbol}},
+    },
+}()
+
+Base.hash(env::TicTacToeEnv, h::UInt) = hash(env.board, h)
+Base.isequal(a::TicTacToeEnv, b::TicTacToeEnv) = isequal(a.board, b.board)
+
+Base.to_index(::TicTacToeEnv, player) = player == :Cross ? 2 : 3
+
+RLBase.action_space(::TicTacToeEnv, player) = Base.OneTo(9)
+
+RLBase.legal_action_space(env::TicTacToeEnv, p) = findall(legal_action_space_mask(env))
+
+function RLBase.legal_action_space_mask(env::TicTacToeEnv, p)
+    if is_win(env, :Cross) || is_win(env, :Nought)
+        falses(9)
+    else
+        vec(env.board[:, :, 1])
+    end
+end
+
+RLBase.act!(env::TicTacToeEnv, action::Int) = RLBase.act!(env, CartesianIndices((3, 3))[action])
+
+function RLBase.act!(env::TicTacToeEnv, action::CartesianIndex{2})
+    env.board[action, 1] = false
+    env.board[action, Base.to_index(env, current_player(env))] = true
+end
+
+function RLBase.next_player!(env::TicTacToeEnv)
+    env.player = env.player == :Cross ? :Nought : :Cross
+end
+
+RLBase.players(::TicTacToeEnv) = (:Cross, :Nought)
+
+RLBase.state(env::TicTacToeEnv) = state(env, Observation{Int}(), 1)
+RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
+RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), 1)
+RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
+    get_tic_tac_toe_state_info()[env].index
+
+RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = ArrayProductDomain(fill(false:true, 3, 3, 3))
+RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
+    Base.OneTo(length(get_tic_tac_toe_state_info()))
+RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p) = fullspace(String)
+
+RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), 1)
+
+function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p)
+    buff = IOBuffer()
+    for i in 1:3
+        for j in 1:3
+            if env.board[i, j, 1]
+                x = '.'
+            elseif env.board[i, j, 2]
+                x = 'x'
+            else
+                x = 'o'
+            end
+            print(buff, x)
+        end
+        print(buff, '\n')
+    end
+    String(take!(buff))
+end
+
+RLBase.is_terminated(env::TicTacToeEnv) = get_tic_tac_toe_state_info()[env].is_terminated
+
+function RLBase.reward(env::TicTacToeEnv, player::Symbol)
+    if is_terminated(env)
+        winner = get_tic_tac_toe_state_info()[env].winner
+        if isnothing(winner)
+            0
+        elseif winner === player
+            1
+        else
+            -1
+        end
+    else
+        0
+    end
+end
+
+function is_win(env::TicTacToeEnv, player::Symbol)
+    b = env.board
+    p = Base.to_index(env, player)
+    @inbounds begin
+        b[1, 1, p] & b[1, 2, p] & b[1, 3, p] ||
+            b[2, 1, p] & b[2, 2, p] & b[2, 3, p] ||
+            b[3, 1, p] & b[3, 2, p] & b[3, 3, p] ||
+            b[1, 1, p] & b[2, 1, p] & b[3, 1, p] ||
+            b[1, 2, p] & b[2, 2, p] & b[3, 2, p] ||
+            b[1, 3, p] & b[2, 3, p] & b[3, 3, p] ||
+            b[1, 1, p] & b[2, 2, p] & b[3, 3, p] ||
+            b[1, 3, p] & b[2, 2, p] & b[3, 1, p]
+    end
+end
+
+function get_tic_tac_toe_state_info()
+    if isempty(TIC_TAC_TOE_STATE_INFO)
+        @info "initializing tictactoe state info cache..."
+        t = @elapsed begin
+            n = 1
+            root = TicTacToeEnv()
+            TIC_TAC_TOE_STATE_INFO[root] =
+                (index=n, is_terminated=false, winner=nothing)
+            walk(root) do env
+                if !haskey(TIC_TAC_TOE_STATE_INFO, env)
+                    n += 1
+                    has_empty_pos = any(view(env.board, :, :, 1))
+                    w = if is_win(env, :Cross)
+                        :Cross
+                    elseif is_win(env, :Nought)
+                        :Nought
+                    else
+                        nothing
+                    end
+                    TIC_TAC_TOE_STATE_INFO[env] = (
+                        index=n,
+                        is_terminated=!(has_empty_pos && isnothing(w)),
+                        winner=w,
+                    )
+                end
+            end
+        end
+        @info "finished initializing tictactoe state info cache in $t seconds"
+    end
+    TIC_TAC_TOE_STATE_INFO
+end
+
+RLBase.current_player(env::TicTacToeEnv) = env.player
+
+RLBase.NumAgentStyle(::TicTacToeEnv) = MultiAgent(2)
+RLBase.DynamicStyle(::TicTacToeEnv) = SEQUENTIAL
+RLBase.ActionStyle(::TicTacToeEnv) = FULL_ACTION_SET
+RLBase.InformationStyle(::TicTacToeEnv) = PERFECT_INFORMATION
+RLBase.StateStyle(::TicTacToeEnv) =
+    (Observation{Int}(), Observation{String}(), Observation{BitArray{3}}())
+RLBase.RewardStyle(::TicTacToeEnv) = TERMINAL_REWARD
+RLBase.UtilityStyle(::TicTacToeEnv) = ZERO_SUM
+RLBase.ChanceStyle(::TicTacToeEnv) = DETERMINISTIC
\ No newline at end of file
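Reviewer note (not part of the patch): a short interaction sketch for `TicTacToeEnv`, assuming `RLBase` is in scope. Actions are linear indices into the 3×3 board; the `Observation{String}` state is used here to avoid building the full `Observation{Int}` state-info cache.

```julia
env = TicTacToeEnv()
@assert RLBase.current_player(env) == :Cross
RLBase.act!(env, 1)                        # :Cross marks the top-left cell
RLBase.next_player!(env)
RLBase.act!(env, 5)                        # :Nought takes the centre
@assert length(RLBase.legal_action_space(env, :Cross)) == 7
print(RLBase.state(env, RLBase.Observation{String}(), :Cross))
# x..
# .o.
# ...
```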
diff --git a/src/ReinforcementLearningCore/test/policies/agent.jl b/src/ReinforcementLearningCore/test/policies/agent.jl
index 81297ce77..cacff4f3e 100644
--- a/src/ReinforcementLearningCore/test/policies/agent.jl
+++ b/src/ReinforcementLearningCore/test/policies/agent.jl
@@ -1,6 +1,5 @@
-using ReinforcementLearningBase, ReinforcementLearningEnvironments
-using ReinforcementLearningCore: SRT
-using ReinforcementLearningCore
+using ReinforcementLearningBase
+import ReinforcementLearningCore.SRT
 
 @testset "agent.jl" begin
     @testset "Agent Tests" begin
diff --git a/src/ReinforcementLearningCore/test/policies/multi_agent.jl b/src/ReinforcementLearningCore/test/policies/multi_agent.jl
index fc1229497..0eba11ffa 100644
--- a/src/ReinforcementLearningCore/test/policies/multi_agent.jl
+++ b/src/ReinforcementLearningCore/test/policies/multi_agent.jl
@@ -1,5 +1,4 @@
 using Test
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories
 using ReinforcementLearningCore
 using ReinforcementLearningBase
@@ -85,8 +84,8 @@ end
     @test multiagent_hook.hooks[:Cross].steps[1] > 0
 
     @test RLBase.is_terminated(env)
-    @test RLEnvs.is_win(env, :Cross) isa Bool
-    @test RLEnvs.is_win(env, :Nought) isa Bool
+    @test is_win(env, :Cross) isa Bool
+    @test is_win(env, :Nought) isa Bool
     @test RLBase.reward(env, :Cross) == (RLBase.reward(env, :Nought) * -1)
     @test RLBase.legal_action_space_mask(env, :Cross) == falses(9)
     @test RLBase.legal_action_space(env) == []
diff --git a/src/ReinforcementLearningCore/test/runtests.jl b/src/ReinforcementLearningCore/test/runtests.jl
index d11899e4d..bbe01ebd6 100644
--- a/src/ReinforcementLearningCore/test/runtests.jl
+++ b/src/ReinforcementLearningCore/test/runtests.jl
@@ -1,6 +1,5 @@
 using ReinforcementLearningBase
 using ReinforcementLearningCore
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories
 using Test
 
@@ -8,6 +7,9 @@ using CUDA
 using CircularArrayBuffers
 using Flux
 
+include("environments/randomwalk1D.jl")
+include("environments/tictactoe.jl")
+include("environments/rockpaperscissors.jl")
 @testset "ReinforcementLearningCore.jl" begin
     include("core/core.jl")
     include("core/stop_conditions.jl")
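Reviewer note (not part of the patch): with the three environment files include'd before the top-level testset, downstream tests can construct these environments directly instead of depending on ReinforcementLearningEnvironments. A hypothetical smoke test, assuming `using Test` and `using ReinforcementLearningBase` as in `runtests.jl`:

```julia
@testset "local test environments" begin
    @test RLBase.state(RandomWalk1D()) == 4
    @test RLBase.current_player(RockPaperScissorsEnv()) == RLBase.SIMULTANEOUS_PLAYER
    @test RLBase.current_player(TicTacToeEnv()) == :Cross
end
```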