remove rlenv dep for tests (#989)
HenriDeh authored Oct 12, 2023
1 parent dd19ee0 commit 3b21982
Showing 7 changed files with 309 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -51,10 +51,10 @@ cuDNN = "1"
 julia = "1.9"

 [extras]
+CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["DomainSets", "Test", "Random", "ReinforcementLearningEnvironments"]
+test = ["CommonRLInterface","DomainSets", "Test", "Random"]
66 changes: 66 additions & 0 deletions src/ReinforcementLearningCore/test/environments/randomwalk1D.jl
@@ -0,0 +1,66 @@
import CommonRLInterface

"""
RandomWalk1D(;rewards=-1. => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])
An agent is placed at the `start_pos` and can move left or right (stride is
defined in actions). The game terminates when the agent reaches either end and
receives a reward correspondingly.
Compared to the [`MultiArmBanditsEnv`](@ref):
1. The state space is more complicated (well, not that complicated though).
1. It's a sequential game of multiple action steps.
1. It's a deterministic game instead of stochastic game.
"""
Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
    rewards::Pair{Float64,Float64} = -1.0 => 1.0
    N::Int = 7
    actions::Vector{Int} = [-1, 1]
    start_pos::Int = (N + 1) ÷ 2
    pos::Int = start_pos

    action_space::Base.OneTo = Base.OneTo(length(actions))
    state_space::Base.OneTo = Base.OneTo(N)
end

RLBase.action_space(env::RandomWalk1D) = env.action_space

function RLBase.act!(env::RandomWalk1D, action::Int)
    env.pos += env.actions[action]
    if env.pos > env.N
        env.pos = env.N
    elseif env.pos < 1
        env.pos = 1
    end
    return
end

RLBase.state(env::RandomWalk1D) = env.pos
RLBase.state_space(env::RandomWalk1D) = env.state_space
RLBase.is_terminated(env::RandomWalk1D) = env.pos == 1 || env.pos == env.N
RLBase.reset!(env::RandomWalk1D) = env.pos = env.start_pos

RLBase.reward(env::RandomWalk1D) = random_walk_reward(env.pos, env.rewards, env.N)

function random_walk_reward(pos::Int, rewards::Pair{Float64,Float64}, N::Int)
    if pos == 1
        return random_walk_reward_first(rewards)
    elseif pos == N
        return random_walk_reward_last(rewards)
    else
        return 0.0
    end
end

random_walk_reward_first(rewards::Pair{Float64,Float64}) = first(rewards)
random_walk_reward_last(rewards::Pair{Float64,Float64}) = last(rewards)

RLBase.NumAgentStyle(::RandomWalk1D) = SINGLE_AGENT
RLBase.DynamicStyle(::RandomWalk1D) = SEQUENTIAL
RLBase.ActionStyle(::RandomWalk1D) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::RandomWalk1D) = PERFECT_INFORMATION
RLBase.StateStyle(::RandomWalk1D) = Observation{Int}()
RLBase.RewardStyle(::RandomWalk1D) = TERMINAL_REWARD
RLBase.UtilityStyle(::RandomWalk1D) = GENERAL_SUM
RLBase.ChanceStyle(::RandomWalk1D) = DETERMINISTIC
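
For orientation, a minimal sketch of stepping this environment by hand; it assumes the RLBase generics (act!, reward, is_terminated, action_space) are in scope, as they are in the test suite:

env = RandomWalk1D()                                   # agent starts at (7 + 1) ÷ 2 = 4
while !RLBase.is_terminated(env)
    RLBase.act!(env, rand(RLBase.action_space(env)))   # action 1 steps left, action 2 steps right
end
RLBase.reward(env)                                     # -1.0 at the left wall, 1.0 at the right
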
63 changes: 63 additions & 0 deletions src/ReinforcementLearningCore/test/environments/rockpaperscissors.jl
@@ -0,0 +1,63 @@
"""
RockPaperScissorsEnv()
[Rock Paper Scissors](https://en.wikipedia.org/wiki/Rock_paper_scissors) is a
simultaneous, zero sum game.
"""
Base.@kwdef mutable struct RockPaperScissorsEnv <: AbstractEnv
    reward::NamedTuple{(Symbol(1), Symbol(2)), Tuple{Int64, Int64}} = (; Symbol(1) => 0, Symbol(2) => 0)
    is_done::Bool = false
end

RLBase.players(::RockPaperScissorsEnv) = (Symbol(1), Symbol(2))

"""
Note that although this is a two player game. The current player is always a
dummy simultaneous player.
"""
RLBase.current_player(::RockPaperScissorsEnv) = SIMULTANEOUS_PLAYER

RLBase.action_space(::RockPaperScissorsEnv, ::Symbol) = ('💎', '📃', '✂')

RLBase.action_space(::RockPaperScissorsEnv, ::SimultaneousPlayer) =
    Tuple((i, j) for i in ('💎', '📃', '✂') for j in ('💎', '📃', '✂'))

RLBase.action_space(env::RockPaperScissorsEnv) = action_space(env, SIMULTANEOUS_PLAYER)

RLBase.legal_action_space(env::RockPaperScissorsEnv, p) =
    is_terminated(env) ? () : action_space(env, p)

"Since it's a one-shot game, the state space doesn't have much meaning."
RLBase.state_space(::RockPaperScissorsEnv, ::Observation, p) = Base.OneTo(1)

"""
For multi-agent environments, we usually implement the most detailed one.
"""
RLBase.state(::RockPaperScissorsEnv, ::Observation, p) = 1

RLBase.reward(env::RockPaperScissorsEnv) = env.is_done ? env.reward : (; Symbol(1) => 0, Symbol(2) => 0)
RLBase.reward(env::RockPaperScissorsEnv, p::Symbol) = reward(env)[p]

RLBase.is_terminated(env::RockPaperScissorsEnv) = env.is_done
RLBase.reset!(env::RockPaperScissorsEnv) = env.is_done = false

# TODO: Consider using CRL.all_act! and adjusting run function accordingly
function RLBase.act!(env::RockPaperScissorsEnv, (x, y))
    if x == y
        env.reward = (; Symbol(1) => 0, Symbol(2) => 0)
    elseif x == '💎' && y == '✂' || x == '✂' && y == '📃' || x == '📃' && y == '💎'
        env.reward = (; Symbol(1) => 1, Symbol(2) => -1)
    else
        env.reward = (; Symbol(1) => -1, Symbol(2) => 1)
    end
    env.is_done = true
end

RLBase.NumAgentStyle(::RockPaperScissorsEnv) = MultiAgent(2)
RLBase.DynamicStyle(::RockPaperScissorsEnv) = SIMULTANEOUS
RLBase.ActionStyle(::RockPaperScissorsEnv) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::RockPaperScissorsEnv) = IMPERFECT_INFORMATION
RLBase.StateStyle(::RockPaperScissorsEnv) = Observation{Int}()
RLBase.RewardStyle(::RockPaperScissorsEnv) = TERMINAL_REWARD
RLBase.UtilityStyle(::RockPaperScissorsEnv) = ZERO_SUM
RLBase.ChanceStyle(::RockPaperScissorsEnv) = DETERMINISTIC
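
A hedged sketch of one round, again assuming the RLBase generics are in scope; since both players move simultaneously, a single act! call takes the joint action tuple, and rewards are read per player afterwards:

env = RockPaperScissorsEnv()
RLBase.act!(env, ('💎', '✂'))     # rock beats scissors
RLBase.reward(env, Symbol(1))      # 1
RLBase.reward(env, Symbol(2))      # -1
RLBase.is_terminated(env)          # true — one joint action ends the game
RLBase.reset!(env)                 # ready for another round
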
171 changes: 171 additions & 0 deletions src/ReinforcementLearningCore/test/environments/tictactoe.jl
@@ -0,0 +1,171 @@
import ReinforcementLearningBase: RLBase
import ReinforcementLearningCore: RLCore
import CommonRLInterface

mutable struct TicTacToeEnv <: AbstractEnv
    board::BitArray{3}
    player::Symbol
end

function TicTacToeEnv()
    board = BitArray{3}(undef, 3, 3, 3)
    fill!(board, false)
    board[:, :, 1] .= true
    TicTacToeEnv(board, :Cross)
end

function RLBase.reset!(env::TicTacToeEnv)
    fill!(env.board, false)
    env.board[:, :, 1] .= true
    env.player = :Cross
end

struct TicTacToeInfo
    is_terminated::Bool
    winner::Union{Nothing,Symbol}
end

const TIC_TAC_TOE_STATE_INFO = Dict{
    TicTacToeEnv,
    NamedTuple{
        (:index, :is_terminated, :winner),
        Tuple{Int,Bool,Union{Nothing,Symbol}},
    },
}()

Base.hash(env::TicTacToeEnv, h::UInt) = hash(env.board, h)
Base.isequal(a::TicTacToeEnv, b::TicTacToeEnv) = isequal(a.board, b.board)

Base.to_index(::TicTacToeEnv, player) = player == :Cross ? 2 : 3

RLBase.action_space(::TicTacToeEnv, player) = Base.OneTo(9)

RLBase.legal_action_space(env::TicTacToeEnv, p) = findall(legal_action_space_mask(env))

function RLBase.legal_action_space_mask(env::TicTacToeEnv, p)
    if is_win(env, :Cross) || is_win(env, :Nought)
        falses(9)
    else
        vec(env.board[:, :, 1])
    end
end

RLBase.act!(env::TicTacToeEnv, action::Int) = RLBase.act!(env, CartesianIndices((3, 3))[action])

function RLBase.act!(env::TicTacToeEnv, action::CartesianIndex{2})
    env.board[action, 1] = false
    env.board[action, Base.to_index(env, current_player(env))] = true
end

function RLBase.next_player!(env::TicTacToeEnv)
    env.player = env.player == :Cross ? :Nought : :Cross
end

RLBase.players(::TicTacToeEnv) = (:Cross, :Nought)

RLBase.state(env::TicTacToeEnv) = state(env, Observation{Int}(), 1)
RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), 1)
RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
    get_tic_tac_toe_state_info()[env].index

RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = ArrayProductDomain(fill(false:true, 3, 3, 3))
RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
    Base.OneTo(length(get_tic_tac_toe_state_info()))
RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p) = fullspace(String)

RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), 1)

function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p)
    buff = IOBuffer()
    for i in 1:3
        for j in 1:3
            if env.board[i, j, 1]
                x = '.'
            elseif env.board[i, j, 2]
                x = 'x'
            else
                x = 'o'
            end
            print(buff, x)
        end
        print(buff, '\n')
    end
    String(take!(buff))
end

RLBase.is_terminated(env::TicTacToeEnv) = get_tic_tac_toe_state_info()[env].is_terminated

function RLBase.reward(env::TicTacToeEnv, player::Symbol)
    if is_terminated(env)
        winner = get_tic_tac_toe_state_info()[env].winner
        if isnothing(winner)
            0
        elseif winner === player
            1
        else
            -1
        end
    else
        0
    end
end

function is_win(env::TicTacToeEnv, player::Symbol)
    b = env.board
    p = Base.to_index(env, player)
    @inbounds begin
        b[1, 1, p] & b[1, 2, p] & b[1, 3, p] ||
            b[2, 1, p] & b[2, 2, p] & b[2, 3, p] ||
            b[3, 1, p] & b[3, 2, p] & b[3, 3, p] ||
            b[1, 1, p] & b[2, 1, p] & b[3, 1, p] ||
            b[1, 2, p] & b[2, 2, p] & b[3, 2, p] ||
            b[1, 3, p] & b[2, 3, p] & b[3, 3, p] ||
            b[1, 1, p] & b[2, 2, p] & b[3, 3, p] ||
            b[1, 3, p] & b[2, 2, p] & b[3, 1, p]
    end
end

function get_tic_tac_toe_state_info()
    if isempty(TIC_TAC_TOE_STATE_INFO)
        @info "initializing tictactoe state info cache..."
        t = @elapsed begin
            n = 1
            root = TicTacToeEnv()
            TIC_TAC_TOE_STATE_INFO[root] =
                (index=n, is_terminated=false, winner=nothing)
            walk(root) do env
                if !haskey(TIC_TAC_TOE_STATE_INFO, env)
                    n += 1
                    has_empty_pos = any(view(env.board, :, :, 1))
                    w = if is_win(env, :Cross)
                        :Cross
                    elseif is_win(env, :Nought)
                        :Nought
                    else
                        nothing
                    end
                    TIC_TAC_TOE_STATE_INFO[env] = (
                        index=n,
                        is_terminated=!(has_empty_pos && isnothing(w)),
                        winner=w,
                    )
                end
            end
        end
        @info "finished initializing tictactoe state info cache in $t seconds"
    end
    TIC_TAC_TOE_STATE_INFO
end

RLBase.current_player(env::TicTacToeEnv) = env.player

RLBase.NumAgentStyle(::TicTacToeEnv) = MultiAgent(2)
RLBase.DynamicStyle(::TicTacToeEnv) = SEQUENTIAL
RLBase.ActionStyle(::TicTacToeEnv) = FULL_ACTION_SET
RLBase.InformationStyle(::TicTacToeEnv) = PERFECT_INFORMATION
RLBase.StateStyle(::TicTacToeEnv) =
    (Observation{Int}(), Observation{String}(), Observation{BitArray{3}}())
RLBase.RewardStyle(::TicTacToeEnv) = TERMINAL_REWARD
RLBase.UtilityStyle(::TicTacToeEnv) = ZERO_SUM
RLBase.ChanceStyle(::TicTacToeEnv) = DETERMINISTIC
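
A short usage sketch, assuming the usual RLBase exports are in scope. Actions index the board column-major, so action 5 is the center cell; the first Observation{Int} state query walks the whole game tree once to fill TIC_TAC_TOE_STATE_INFO (environments hash by board, so later lookups are cheap):

env = TicTacToeEnv()
RLBase.act!(env, 5)                        # :Cross marks the center
RLBase.next_player!(env)
RLBase.legal_action_space(env, :Nought)    # the eight remaining cells
print(RLBase.state(env, Observation{String}(), :Nought))
# prints:
# ...
# .x.
# ...
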
5 changes: 2 additions & 3 deletions src/ReinforcementLearningCore/test/policies/agent.jl
@@ -1,6 +1,5 @@
-using ReinforcementLearningBase, ReinforcementLearningEnvironments
-using ReinforcementLearningCore: SRT
 using ReinforcementLearningCore
 using ReinforcementLearningBase
+import ReinforcementLearningCore.SRT

-@testset "agent.jl" begin
+@testset "Agent Tests" begin
5 changes: 2 additions & 3 deletions src/ReinforcementLearningCore/test/policies/multi_agent.jl
@@ -1,5 +1,4 @@
 using Test
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories
 using ReinforcementLearningCore
 using ReinforcementLearningBase
@@ -85,8 +84,8 @@ end
 @test multiagent_hook.hooks[:Cross].steps[1] > 0

 @test RLBase.is_terminated(env)
-@test RLEnvs.is_win(env, :Cross) isa Bool
-@test RLEnvs.is_win(env, :Nought) isa Bool
+@test is_win(env, :Cross) isa Bool
+@test is_win(env, :Nought) isa Bool
 @test RLBase.reward(env, :Cross) == (RLBase.reward(env, :Nought) * -1)
 @test RLBase.legal_action_space_mask(env, :Cross) == falses(9)
 @test RLBase.legal_action_space(env) == []
4 changes: 3 additions & 1 deletion src/ReinforcementLearningCore/test/runtests.jl
@@ -1,13 +1,15 @@
 using ReinforcementLearningBase
 using ReinforcementLearningCore
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories

 using Test
 using CUDA
 using CircularArrayBuffers
 using Flux

+include("environments/randomwalk1D.jl")
+include("environments/tictactoe.jl")
+include("environments/rockpaperscissors.jl")
 @testset "ReinforcementLearningCore.jl" begin
     include("core/core.jl")
     include("core/stop_conditions.jl")
