remove rlenv dep for tests #989

Merged
merged 1 commit into from Oct 12, 2023
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -51,10 +51,10 @@ cuDNN = "1"
 julia = "1.9"
 
 [extras]
+CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
 DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-ReinforcementLearningEnvironments = "25e41dd2-4622-11e9-1641-f1adca772921"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["DomainSets", "Test", "Random", "ReinforcementLearningEnvironments"]
+test = ["CommonRLInterface","DomainSets", "Test", "Random"]
66 changes: 66 additions & 0 deletions src/ReinforcementLearningCore/test/environments/randomwalk1D.jl
@@ -0,0 +1,66 @@
import CommonRLInterface

"""
    RandomWalk1D(;rewards=-1.0 => 1.0, N=7, start_pos=(N+1) ÷ 2, actions=[-1,1])

An agent starts at `start_pos` and can move left or right (the stride of each
move is defined in `actions`). The game terminates when the agent reaches
either end, receiving the corresponding reward from `rewards`.

Compared to the [`MultiArmBanditsEnv`](@ref):

1. The state space is more complicated (well, not that complicated though).
1. It's a sequential game of multiple action steps.
1. It's a deterministic game instead of a stochastic one.
"""
Base.@kwdef mutable struct RandomWalk1D <: AbstractEnv
    rewards::Pair{Float64,Float64} = -1.0 => 1.0
    N::Int = 7
    actions::Vector{Int} = [-1, 1]
    start_pos::Int = (N + 1) ÷ 2
    pos::Int = start_pos

    action_space::Base.OneTo = Base.OneTo(length(actions))
    state_space::Base.OneTo = Base.OneTo(N)
end

RLBase.action_space(env::RandomWalk1D) = env.action_space

function RLBase.act!(env::RandomWalk1D, action::Int)
    env.pos += env.actions[action]
    if env.pos > env.N
        env.pos = env.N
    elseif env.pos < 1
        env.pos = 1
    end
    return
end

RLBase.state(env::RandomWalk1D) = env.pos
RLBase.state_space(env::RandomWalk1D) = env.state_space

RLBase.is_terminated(env::RandomWalk1D) = env.pos == 1 || env.pos == env.N
RLBase.reset!(env::RandomWalk1D) = env.pos = env.start_pos

RLBase.reward(env::RandomWalk1D) = random_walk_reward(env.pos, env.rewards, env.N)

function random_walk_reward(pos::Int, rewards::Pair{Float64,Float64}, N::Int)
    if pos == 1
        return random_walk_reward_first(rewards)
    elseif pos == N
        return random_walk_reward_last(rewards)
    else
        return 0.0
    end
end

random_walk_reward_first(rewards::Pair{Float64,Float64}) = first(rewards)
random_walk_reward_last(rewards::Pair{Float64,Float64}) = last(rewards)

RLBase.NumAgentStyle(::RandomWalk1D) = SINGLE_AGENT
RLBase.DynamicStyle(::RandomWalk1D) = SEQUENTIAL

RLBase.ActionStyle(::RandomWalk1D) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::RandomWalk1D) = PERFECT_INFORMATION
RLBase.StateStyle(::RandomWalk1D) = Observation{Int}()
RLBase.RewardStyle(::RandomWalk1D) = TERMINAL_REWARD
RLBase.UtilityStyle(::RandomWalk1D) = GENERAL_SUM
RLBase.ChanceStyle(::RandomWalk1D) = DETERMINISTIC

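For reference, a minimal usage sketch (illustrative only, not part of the diff; it assumes this file has been included and ReinforcementLearningBase is loaded, as in runtests.jl) that rolls out one episode with uniformly random actions:

env = RandomWalk1D()
RLBase.reset!(env)
while !RLBase.is_terminated(env)
    # action 1 moves left, action 2 moves right (indices into `actions`)
    RLBase.act!(env, rand(RLBase.action_space(env)))
end
RLBase.reward(env)  # -1.0 if the walk ended at position 1, 1.0 if at position N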
63 changes: 63 additions & 0 deletions src/ReinforcementLearningCore/test/environments/rockpaperscissors.jl
@@ -0,0 +1,63 @@
"""
    RockPaperScissorsEnv()

[Rock Paper Scissors](https://en.wikipedia.org/wiki/Rock_paper_scissors) is a
simultaneous, zero-sum game.
"""
Base.@kwdef mutable struct RockPaperScissorsEnv <: AbstractEnv
    reward::NamedTuple{(Symbol(1), Symbol(2)), Tuple{Int64, Int64}} = (; Symbol(1) => 0, Symbol(2) => 0)
    is_done::Bool = false
end

RLBase.players(::RockPaperScissorsEnv) = (Symbol(1), Symbol(2))

"""
Note that although this is a two-player game, the current player is always a
dummy simultaneous player.
"""
RLBase.current_player(::RockPaperScissorsEnv) = SIMULTANEOUS_PLAYER

RLBase.action_space(::RockPaperScissorsEnv, ::Symbol) = ('💎', '📃', '✂')

RLBase.action_space(::RockPaperScissorsEnv, ::SimultaneousPlayer) =
    Tuple((i, j) for i in ('💎', '📃', '✂') for j in ('💎', '📃', '✂'))

RLBase.action_space(env::RockPaperScissorsEnv) = action_space(env, SIMULTANEOUS_PLAYER)

RLBase.legal_action_space(env::RockPaperScissorsEnv, p) =
    is_terminated(env) ? () : action_space(env, p)

"Since it's a one-shot game, the state space doesn't have much meaning."
RLBase.state_space(::RockPaperScissorsEnv, ::Observation, p) = Base.OneTo(1)

"""
For multi-agent environments, we usually implement the most detailed state representation.
"""
RLBase.state(::RockPaperScissorsEnv, ::Observation, p) = 1

RLBase.reward(env::RockPaperScissorsEnv) = env.is_done ? env.reward : (; Symbol(1) => 0, Symbol(2) => 0)
RLBase.reward(env::RockPaperScissorsEnv, p::Symbol) = reward(env)[p]

RLBase.is_terminated(env::RockPaperScissorsEnv) = env.is_done
RLBase.reset!(env::RockPaperScissorsEnv) = env.is_done = false

# TODO: Consider using CRL.all_act! and adjusting run function accordingly
function RLBase.act!(env::RockPaperScissorsEnv, (x, y))
    if x == y
        env.reward = (; Symbol(1) => 0, Symbol(2) => 0)
    elseif x == '💎' && y == '✂' || x == '✂' && y == '📃' || x == '📃' && y == '💎'
        env.reward = (; Symbol(1) => 1, Symbol(2) => -1)
    else
        env.reward = (; Symbol(1) => -1, Symbol(2) => 1)
    end
    env.is_done = true
end

RLBase.NumAgentStyle(::RockPaperScissorsEnv) = MultiAgent(2)

RLBase.DynamicStyle(::RockPaperScissorsEnv) = SIMULTANEOUS
RLBase.ActionStyle(::RockPaperScissorsEnv) = MINIMAL_ACTION_SET
RLBase.InformationStyle(::RockPaperScissorsEnv) = IMPERFECT_INFORMATION

RLBase.StateStyle(::RockPaperScissorsEnv) = Observation{Int}()
RLBase.RewardStyle(::RockPaperScissorsEnv) = TERMINAL_REWARD
RLBase.UtilityStyle(::RockPaperScissorsEnv) = ZERO_SUM
RLBase.ChanceStyle(::RockPaperScissorsEnv) = DETERMINISTIC

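A minimal usage sketch (illustrative only, not part of the diff): both players submit their actions at once as a tuple, and the game terminates after a single step:

env = RockPaperScissorsEnv()
RLBase.act!(env, ('💎', '✂'))  # rock beats scissors
RLBase.is_terminated(env)      # true: it is a one-shot game
RLBase.reward(env, Symbol(1))  # 1
RLBase.reward(env, Symbol(2))  # -1
RLBase.reset!(env)             # ready for the next round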
171 changes: 171 additions & 0 deletions src/ReinforcementLearningCore/test/environments/tictactoe.jl
@@ -0,0 +1,171 @@
import ReinforcementLearningBase: RLBase
import ReinforcementLearningCore: RLCore
import CommonRLInterface

mutable struct TicTacToeEnv <: AbstractEnv
    board::BitArray{3}
    player::Symbol
end

function TicTacToeEnv()
    board = BitArray{3}(undef, 3, 3, 3)
    fill!(board, false)
    board[:, :, 1] .= true
    TicTacToeEnv(board, :Cross)
end

function RLBase.reset!(env::TicTacToeEnv)
    fill!(env.board, false)
    env.board[:, :, 1] .= true
    env.player = :Cross
end

struct TicTacToeInfo
    is_terminated::Bool
    winner::Union{Nothing,Symbol}
end

const TIC_TAC_TOE_STATE_INFO = Dict{
    TicTacToeEnv,
    NamedTuple{
        (:index, :is_terminated, :winner),
        Tuple{Int,Bool,Union{Nothing,Symbol}},
    },
}()

Base.hash(env::TicTacToeEnv, h::UInt) = hash(env.board, h)
Base.isequal(a::TicTacToeEnv, b::TicTacToeEnv) = isequal(a.board, b.board)

Base.to_index(::TicTacToeEnv, player) = player == :Cross ? 2 : 3

RLBase.action_space(::TicTacToeEnv, player) = Base.OneTo(9)

RLBase.legal_action_space(env::TicTacToeEnv, p) = findall(legal_action_space_mask(env))

function RLBase.legal_action_space_mask(env::TicTacToeEnv, p)
    if is_win(env, :Cross) || is_win(env, :Nought)
        falses(9)
    else
        vec(env.board[:, :, 1])
    end
end

RLBase.act!(env::TicTacToeEnv, action::Int) = RLBase.act!(env, CartesianIndices((3, 3))[action])

function RLBase.act!(env::TicTacToeEnv, action::CartesianIndex{2})
    env.board[action, 1] = false
    env.board[action, Base.to_index(env, current_player(env))] = true
end

function RLBase.next_player!(env::TicTacToeEnv)
    env.player = env.player == :Cross ? :Nought : :Cross
end

RLBase.players(::TicTacToeEnv) = (:Cross, :Nought)

RLBase.state(env::TicTacToeEnv) = state(env, Observation{Int}(), 1)
RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
RLBase.state(env::TicTacToeEnv, ::RLBase.AbstractStateStyle) = state(env::TicTacToeEnv, Observation{Int}(), 1)

RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
    get_tic_tac_toe_state_info()[env].index

RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = ArrayProductDomain(fill(false:true, 3, 3, 3))
RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
    Base.OneTo(length(get_tic_tac_toe_state_info()))
RLBase.state_space(env::TicTacToeEnv, ::Observation{String}, p) = fullspace(String)

RLBase.state(env::TicTacToeEnv, ::Observation{String}) = state(env::TicTacToeEnv, Observation{String}(), 1)

function RLBase.state(env::TicTacToeEnv, ::Observation{String}, p)
    buff = IOBuffer()
    for i in 1:3
        for j in 1:3
            if env.board[i, j, 1]
                x = '.'
            elseif env.board[i, j, 2]
                x = 'x'
            else
                x = 'o'
            end
            print(buff, x)
        end
        print(buff, '\n')
    end
    String(take!(buff))
end

RLBase.is_terminated(env::TicTacToeEnv) = get_tic_tac_toe_state_info()[env].is_terminated

function RLBase.reward(env::TicTacToeEnv, player::Symbol)
    if is_terminated(env)
        winner = get_tic_tac_toe_state_info()[env].winner
        if isnothing(winner)
            0
        elseif winner === player
            1
        else
            -1
        end
    else
        0
    end
end

function is_win(env::TicTacToeEnv, player::Symbol)
    b = env.board
    p = Base.to_index(env, player)
    @inbounds begin
        b[1, 1, p] & b[1, 2, p] & b[1, 3, p] ||
        b[2, 1, p] & b[2, 2, p] & b[2, 3, p] ||
        b[3, 1, p] & b[3, 2, p] & b[3, 3, p] ||
        b[1, 1, p] & b[2, 1, p] & b[3, 1, p] ||
        b[1, 2, p] & b[2, 2, p] & b[3, 2, p] ||
        b[1, 3, p] & b[2, 3, p] & b[3, 3, p] ||
        b[1, 1, p] & b[2, 2, p] & b[3, 3, p] ||
        b[1, 3, p] & b[2, 2, p] & b[3, 1, p]
    end
end

function get_tic_tac_toe_state_info()
    if isempty(TIC_TAC_TOE_STATE_INFO)
        @info "initializing tictactoe state info cache..."
        t = @elapsed begin
            n = 1
            root = TicTacToeEnv()
            TIC_TAC_TOE_STATE_INFO[root] =
                (index=n, is_terminated=false, winner=nothing)
            walk(root) do env
                if !haskey(TIC_TAC_TOE_STATE_INFO, env)
                    n += 1
                    has_empty_pos = any(view(env.board, :, :, 1))
                    w = if is_win(env, :Cross)
                        :Cross
                    elseif is_win(env, :Nought)
                        :Nought
                    else
                        nothing
                    end
                    TIC_TAC_TOE_STATE_INFO[env] = (
                        index=n,
                        is_terminated=!(has_empty_pos && isnothing(w)),
                        winner=w,
                    )
                end
            end
        end
        @info "finished initializing tictactoe state info cache in $t seconds"
    end
    TIC_TAC_TOE_STATE_INFO
end

RLBase.current_player(env::TicTacToeEnv) = env.player

RLBase.NumAgentStyle(::TicTacToeEnv) = MultiAgent(2)

RLBase.DynamicStyle(::TicTacToeEnv) = SEQUENTIAL
RLBase.ActionStyle(::TicTacToeEnv) = FULL_ACTION_SET
RLBase.InformationStyle(::TicTacToeEnv) = PERFECT_INFORMATION

RLBase.StateStyle(::TicTacToeEnv) =
    (Observation{Int}(), Observation{String}(), Observation{BitArray{3}}())
RLBase.RewardStyle(::TicTacToeEnv) = TERMINAL_REWARD
RLBase.UtilityStyle(::TicTacToeEnv) = ZERO_SUM
RLBase.ChanceStyle(::TicTacToeEnv) = DETERMINISTIC

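A minimal usage sketch (illustrative only, not part of the diff): play uniformly random legal moves until the game ends, then render the board. Note that the first terminal-state query triggers the one-time construction of the state-info cache:

env = TicTacToeEnv()
while !RLBase.is_terminated(env)
    a = rand(RLBase.legal_action_space(env, RLBase.current_player(env)))
    RLBase.act!(env, a)
    RLBase.next_player!(env)
end
print(RLBase.state(env, Observation{String}(), 1))         # final board: '.' empty, 'x' Cross, 'o' Nought
RLBase.reward(env, :Cross) + RLBase.reward(env, :Nought)   # 0: zero-sum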
5 changes: 2 additions & 3 deletions src/ReinforcementLearningCore/test/policies/agent.jl
@@ -1,6 +1,5 @@
-using ReinforcementLearningBase, ReinforcementLearningEnvironments
-using ReinforcementLearningCore: SRT
 using ReinforcementLearningCore
+using ReinforcementLearningBase
+import ReinforcementLearningCore.SRT
 
-@testset "agent.jl" begin
 @testset "Agent Tests" begin
5 changes: 2 additions & 3 deletions src/ReinforcementLearningCore/test/policies/multi_agent.jl
@@ -1,5 +1,4 @@
 using Test
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories
 using ReinforcementLearningCore
 using ReinforcementLearningBase
@@ -85,8 +84,8 @@ end
 @test multiagent_hook.hooks[:Cross].steps[1] > 0
 
 @test RLBase.is_terminated(env)
-@test RLEnvs.is_win(env, :Cross) isa Bool
-@test RLEnvs.is_win(env, :Nought) isa Bool
+@test is_win(env, :Cross) isa Bool
+@test is_win(env, :Nought) isa Bool
 @test RLBase.reward(env, :Cross) == (RLBase.reward(env, :Nought) * -1)
 @test RLBase.legal_action_space_mask(env, :Cross) == falses(9)
 @test RLBase.legal_action_space(env) == []
4 changes: 3 additions & 1 deletion src/ReinforcementLearningCore/test/runtests.jl
@@ -1,13 +1,15 @@
 using ReinforcementLearningBase
 using ReinforcementLearningCore
-using ReinforcementLearningEnvironments
 using ReinforcementLearningTrajectories
 
 using Test
 using CUDA
 using CircularArrayBuffers
 using Flux
 
+include("environments/randomwalk1D.jl")
+include("environments/tictactoe.jl")
+include("environments/rockpaperscissors.jl")
 @testset "ReinforcementLearningCore.jl" begin
     include("core/core.jl")
     include("core/stop_conditions.jl")