This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit: add tests for POMDP models (#123)

findmyway authored Feb 1, 2021
1 parent b0aba19 · commit 4b31d85
Showing 5 changed files with 92 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -12,5 +12,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [compat]
 AbstractTrees = "0.3"
-CommonRLInterface = "0.2"
+CommonRLInterface = "0.3"
 julia = "1.3"
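In Pkg's compat notation, "0.3" allows any 0.3.x release but not 0.4, so this bump opts into the breaking CommonRLInterface 0.2 to 0.3 changes that the source file below adapts to.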
80 changes: 45 additions & 35 deletions src/CommonRLInterface.jl
@@ -10,52 +10,54 @@ struct CommonRLEnv{T<:AbstractEnv} <: CRL.AbstractEnv
     env::T
 end
 
-struct CommonRLMarkovEnv{T<:AbstractEnv} <: CRL.AbstractMarkovEnv
-    env::T
-end
-
-struct CommonRLZeroSumEnv{T<:AbstractEnv} <: CRL.AbstractZeroSumEnv
-    env::T
-end
-
-const CommonRLEnvs = Union{CommonRLEnv,CommonRLMarkovEnv,CommonRLZeroSumEnv}
-
-function Base.convert(::Type{CRL.AbstractEnv}, env::AbstractEnv)
-    if NumAgentStyle(env) === SINGLE_AGENT
-        convert(CRL.AbstractMarkovEnv, env)
-    elseif NumAgentStyle(env) isa MultiAgent{2} && UtilityStyle(env) === ZERO_SUM
-        convert(CRL.AbstractZeroSumEnv, env)
-    else
-        CommonRLEnv(env)
-    end
-end
-
-Base.convert(::Type{CRL.AbstractMarkovEnv}, env::AbstractEnv) = CommonRLMarkovEnv(env)
-Base.convert(::Type{CRL.AbstractZeroSumEnv}, env::AbstractEnv) = CommonRLZeroSumEnv(env)
-
-CRL.@provide CRL.reset!(env::CommonRLEnvs) = reset!(env.env)
-CRL.@provide CRL.actions(env::CommonRLEnvs) = action_space(env.env)
-CRL.@provide CRL.observe(env::CommonRLEnvs) = state(env.env)
-CRL.state(env::CommonRLEnvs) = state(env.env)
-CRL.provided(::typeof(CRL.state), env::CommonRLEnvs) =
-    InformationStyle(env.env) === PERFECT_INFORMATION
-CRL.@provide CRL.terminated(env::CommonRLEnvs) = is_terminated(env.env)
-CRL.@provide CRL.player(env::CommonRLEnvs) = current_player(env.env)
-CRL.@provide CRL.clone(env::CommonRLEnvs) = CommonRLEnv(copy(env.env))
-
-CRL.@provide function CRL.act!(env::CommonRLEnvs, a)
-    env.env(a)
-    reward(env.env)
-end
+function Base.convert(::Type{CRL.AbstractEnv}, env::AbstractEnv)
+    CommonRLEnv(env)
+end
+
+CRL.@provide CRL.reset!(env::CommonRLEnv) = reset!(env.env)
+CRL.@provide CRL.actions(env::CommonRLEnv) = action_space(env.env)
+CRL.@provide CRL.terminated(env::CommonRLEnv) = is_terminated(env.env)
+
+CRL.@provide function CRL.act!(env::CommonRLEnv, a)
+    env.env(a)
+    reward(env.env)
+end
+
+function find_state_style(env::AbstractEnv, s)
+    find_state_style(StateStyle(env), s)
+end
+
+find_state_style(::Tuple{}, s) = nothing
+
+function find_state_style(ss::Tuple, s)
+    x = first(ss)
+    if x isa s
+        x
+    else
+        find_state_style(Base.tail(ss), s)
+    end
+end
+
+# !!! may need to be extended by user
+CRL.@provide CRL.observe(env::CommonRLEnv) = state(env.env)
+
+CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = !isnothing(find_state_style(env.env, InternalState))
+CRL.state(env::CommonRLEnv) = state(env.env, find_state_style(env.env, InternalState))
+
+CRL.@provide CRL.clone(env::CommonRLEnv) = CommonRLEnv(copy(env.env))
+CRL.@provide CRL.render(env::CommonRLEnv) = @error "unsupported yet..."
+CRL.@provide CRL.player(env::CommonRLEnv) = current_player(env.env)
 
-CRL.valid_actions(x::CommonRLEnvs) = legal_action_space(x.env)
-CRL.provided(::typeof(CRL.valid_actions), env::CommonRLEnvs) =
+CRL.valid_actions(x::CommonRLEnv) = legal_action_space(x.env)
+CRL.provided(::typeof(CRL.valid_actions), env::CommonRLEnv) =
     ActionStyle(env.env) === FullActionSet()
 
-CRL.valid_action_mask(x::CommonRLEnvs) = legal_action_space_mask(x.env)
-CRL.provided(::typeof(CRL.valid_action_mask), env::CommonRLEnvs) =
+CRL.valid_action_mask(x::CommonRLEnv) = legal_action_space_mask(x.env)
+CRL.provided(::typeof(CRL.valid_action_mask), env::CommonRLEnv) =
     ActionStyle(env.env) === FullActionSet()
 
+CRL.@provide CRL.observations(env::CommonRLEnv) = state_space(env.env)
+
 #####
 # RLBaseEnv
 #####
@@ -68,8 +70,16 @@ end
 Base.convert(::Type{AbstractEnv}, env::CRL.AbstractEnv) = convert(RLBaseEnv, env)
 Base.convert(::Type{RLBaseEnv}, env::CRL.AbstractEnv) = RLBaseEnv(env, 0.0f0) # can not determine reward ahead. Assume `Float32`.
 
-state(env::RLBaseEnv) = CRL.observe(env.env)
-state_space(env::RLBaseEnv) = CRL.observations(env.env)
+RLBase.StateStyle(env::RLBaseEnv) = (
+    (CRL.provided(CRL.observe, env.env) ? (Observation{Any}(),) : ())...,
+    (CRL.provided(CRL.state, env.env) ? (InternalState{Any}(),) : ())...,
+)
+
+state(env::RLBaseEnv, ::Observation) = CRL.observe(env.env)
+state(env::RLBaseEnv, ::InternalState) = CRL.state(env.env)
+
+state_space(env::RLBaseEnv, ::Observation) = CRL.observations(env.env)
+
 action_space(env::RLBaseEnv) = CRL.actions(env.env)
 reward(env::RLBaseEnv) = env.r
 is_terminated(env::RLBaseEnv) = CRL.terminated(env.env)
@@ -82,6 +92,6 @@ Base.copy(env::CommonRLEnv) = RLBaseEnv(CRL.clone(env.env), env.r)
 
 ActionStyle(env::RLBaseEnv) =
    CRL.provided(CRL.valid_actions, env.env) ? FullActionSet() : MinimalActionSet()
-UtilityStyle(env::RLBaseEnv) = GENERAL_SUM
-UtilityStyle(env::RLBaseEnv{<:CRL.AbstractZeroSumEnv}) = ZERO_SUM
-InformationStyle(env::RLBaseEnv) = IMPERFECT_INFORMATION
+
+current_player(env::RLBaseEnv) = CRL.player(env.env)
+players(env::RLBaseEnv) = CRL.players(env.env)
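The find_state_style helpers above walk the StateStyle trait tuple of the wrapped RLBase environment and return the first element matching the requested trait type, or nothing if none matches; CRL.state is therefore only `provided` when the environment actually exposes an InternalState. A minimal, self-contained sketch of that traversal (mock trait types standing in for RLBase's Observation and InternalState; not part of the commit):

    # Mock trait types mirroring RLBase's Observation / InternalState.
    abstract type AbstractStateStyle end
    struct Observation{T} <: AbstractStateStyle end
    struct InternalState{T} <: AbstractStateStyle end

    # Recursive tuple traversal, as in the diff above: return the first
    # element that is an instance of the requested trait type `s`.
    find_state_style(::Tuple{}, s) = nothing
    function find_state_style(ss::Tuple, s)
        x = first(ss)
        x isa s ? x : find_state_style(Base.tail(ss), s)
    end

    styles = (Observation{Any}(), InternalState{Int}())
    find_state_style(styles, InternalState)  # -> InternalState{Int}()
    find_state_style(styles, Observation)    # -> Observation{Any}()
    find_state_style((), InternalState)      # -> nothing

The same idea runs in reverse on the RLBase side: RLBase.StateStyle(env::RLBaseEnv) is assembled by splatting in an Observation{Any}() and/or InternalState{Any}() entry depending on which of CRL.observe and CRL.state the wrapped CommonRLInterface environment provides.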
34 changes: 34 additions & 0 deletions test/CommonRLInterface.jl
@@ -0,0 +1,34 @@
+@testset "CommonRLInterface" begin
+    @testset "MDPEnv" begin
+        struct RLTestMDP <: MDP{Int, Int} end
+
+        POMDPs.actions(m::RLTestMDP) = [-1, 1]
+        POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+        POMDPs.initialstate(m::RLTestMDP) = Deterministic(1)
+        POMDPs.isterminal(m::RLTestMDP, s) = s == 3
+        POMDPs.reward(m::RLTestMDP, s, a, sp) = sp
+        POMDPs.states(m::RLTestMDP) = 1:3
+
+        env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP()))
+        RLBase.test_runnable!(env)
+    end
+
+    @testset "POMDPEnv" begin
+
+        struct RLTestPOMDP <: POMDP{Int, Int, Int} end
+
+        POMDPs.actions(m::RLTestPOMDP) = [-1, 1]
+        POMDPs.states(m::RLTestPOMDP) = 1:3
+        POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+        POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1)
+        POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1)
+        POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1)
+        POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3
+        POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp
+        POMDPs.observations(m::RLTestPOMDP) = 2:4
+
+        env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP()))
+
+        RLBase.test_runnable!(env)
+    end
+end
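Note the two-step conversion in the tests: the inner convert(CRL.AbstractEnv, RLTestMDP()) comes from the POMDPs.jl side of the ecosystem (this commit only defines conversions for RLBase's AbstractEnv), and the convert added above then wraps that result in an RLBaseEnv, so RLBase.test_runnable! exercises the full chain. A hedged sketch of driving the intermediate CommonRLInterface environment by hand, using only functions visible in the diffs above and RLTestPOMDP as defined in the test file:

    # A random rollout through a CommonRLInterface environment;
    # `act!` applies the action and returns the step reward.
    function rollout(env)
        CRL.reset!(env)
        total = 0.0
        while !CRL.terminated(env)
            a = rand(CRL.actions(env))
            total += CRL.act!(env, a)
        end
        return total, CRL.observe(env)  # return and final observation
    end

    crl_env = convert(CRL.AbstractEnv, RLTestPOMDP())
    rollout(crl_env)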
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -1,2 +1,5 @@
 [deps]
+CommonRLInterface = "d842c3ba-07a1-494f-bbec-f5741b0a3e98"
+POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415"
+POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10 changes: 9 additions & 1 deletion test/runtests.jl
@@ -1,4 +1,12 @@
 using ReinforcementLearningBase
 using Test
 
-@testset "ReinforcementLearningBase" begin end
+using CommonRLInterface
+const CRL = CommonRLInterface
+
+using POMDPs
+using POMDPModelTools: Deterministic
+
+@testset "ReinforcementLearningBase" begin
+    include("CommonRLInterface.jl")
+end
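With runtests.jl now including the new file, the suite runs through the standard Pkg workflow, for example:

    using Pkg
    Pkg.test("ReinforcementLearningBase")

or, from a checkout of the repository, pkg> activate . followed by pkg> test.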
