From 9e0218277ea80ee2b4a6da2455481c329dfb5bbd Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 19 Feb 2021 18:18:58 +0800 Subject: [PATCH] Simplify env wrapper (#127) * simplify the definition of environment wrappers * simplify further --- src/environments/examples/KuhnPokerEnv.jl | 32 ++++++++----------- .../wrappers/ActionTransformedEnv.jl | 17 ++-------- .../wrappers/DefaultStateStyle.jl | 19 ----------- src/environments/wrappers/MaxTimeoutEnv.jl | 15 ++------- .../wrappers/RewardOverriddenEnv.jl | 15 +-------- src/environments/wrappers/StateCachedEnv.jl | 13 +------- .../wrappers/StateOverriddenEnv.jl | 17 +--------- src/environments/wrappers/StochasticEnv.jl | 15 ++------- src/environments/wrappers/wrappers.jl | 16 ++++++++++ 9 files changed, 40 insertions(+), 119 deletions(-) diff --git a/src/environments/examples/KuhnPokerEnv.jl b/src/environments/examples/KuhnPokerEnv.jl index 61c0c2c..fcede73 100644 --- a/src/environments/examples/KuhnPokerEnv.jl +++ b/src/environments/examples/KuhnPokerEnv.jl @@ -4,14 +4,11 @@ const KUHN_POKER_CARDS = (:J, :Q, :K) const KUHN_POKER_CARD_COMBINATIONS = ((:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q)) const KUHN_POKER_ACTIONS = (:pass, :bet) -const KUHN_POKER_STATES = ( - (), +const KUHN_POKER_STATES = ((), map(tuple, KUHN_POKER_CARDS)..., KUHN_POKER_CARD_COMBINATIONS..., ( - (cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for - actions in ( - (), + (cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for actions in ((), (:bet,), (:bet, :bet), (:bet, :pass), @@ -19,10 +16,8 @@ const KUHN_POKER_STATES = ( (:pass, :pass), (:pass, :bet), (:pass, :bet, :pass), - (:pass, :bet, :bet), - ) - )..., -) + (:pass, :bet, :bet),) + )...,) """ ![](https://upload.wikimedia.org/wikipedia/commons/a/a9/Kuhn_poker_tree.svg) @@ -146,15 +141,16 @@ end RLBase.current_player(env::KuhnPokerEnv) = if length(env.cards) < 2 - CHANCE_PLAYER - elseif length(env.actions) == 0 - 1 - elseif length(env.actions) == 1 - 2 - elseif length(env.actions) == 2 - 1 - else - end + CHANCE_PLAYER +elseif length(env.actions) == 0 + 1 +elseif length(env.actions) == 1 + 2 +elseif length(env.actions) == 2 + 1 +else + 2 # actually the game is over now +end RLBase.NumAgentStyle(::KuhnPokerEnv) = MultiAgent(2) RLBase.DynamicStyle(::KuhnPokerEnv) = SEQUENTIAL diff --git a/src/environments/wrappers/ActionTransformedEnv.jl b/src/environments/wrappers/ActionTransformedEnv.jl index 63f84d5..0bcf2e0 100644 --- a/src/environments/wrappers/ActionTransformedEnv.jl +++ b/src/environments/wrappers/ActionTransformedEnv.jl @@ -1,6 +1,6 @@ export ActionTransformedEnv -struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnvWrapper +struct ActionTransformedEnv{P,M,E <: AbstractEnv} <: AbstractEnvWrapper action_space_mapping::P action_mapping::M env::E @@ -15,23 +15,12 @@ feeding it into `env`. """ function ActionTransformedEnv( env; - action_space_mapping = identity, - action_mapping = identity, + action_space_mapping=identity, + action_mapping=identity, ) ActionTransformedEnv(action_space_mapping, action_mapping, env) end -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f ∉ (:action_space, :legal_action_space) - @eval RLBase.$f(x::ActionTransformedEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - -RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) - RLBase.action_space(env::ActionTransformedEnv, args...) = env.action_space_mapping(action_space(env.env), args...) diff --git a/src/environments/wrappers/DefaultStateStyle.jl b/src/environments/wrappers/DefaultStateStyle.jl index 6db5439..7506ca6 100644 --- a/src/environments/wrappers/DefaultStateStyle.jl +++ b/src/environments/wrappers/DefaultStateStyle.jl @@ -12,22 +12,3 @@ Reset the result of `DefaultStateStyle` without changing the original behavior. DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env) RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S - -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f ∉ (:DefaultStateStyle, :state, :state_space) - @eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - -(env::DefaultStateStyleEnv)(args...; kwargs...) = env.env(args...; kwargs...) - -RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) = - state(env.env, ss, p) - -RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) - -RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) = - state_space(env.env, ss, p) diff --git a/src/environments/wrappers/MaxTimeoutEnv.jl b/src/environments/wrappers/MaxTimeoutEnv.jl index aed04b2..38dc19d 100644 --- a/src/environments/wrappers/MaxTimeoutEnv.jl +++ b/src/environments/wrappers/MaxTimeoutEnv.jl @@ -1,6 +1,6 @@ export MaxTimeoutEnv -mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnvWrapper +mutable struct MaxTimeoutEnv{E <: AbstractEnv} <: AbstractEnvWrapper env::E max_t::Int current_t::Int @@ -11,7 +11,7 @@ end Force `is_terminated(env)` return `true` after `max_t` interactions. """ -MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} = +MaxTimeoutEnv(env::E, max_t::Int; current_t::Int=1) where {E <: AbstractEnv} = MaxTimeoutEnv(env, max_t, current_t) function (env::MaxTimeoutEnv)(args...; kwargs...) @@ -19,13 +19,6 @@ function (env::MaxTimeoutEnv)(args...; kwargs...) env.current_t = env.current_t + 1 end -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f ∉ (:is_terminated, :reset!) - @eval RLBase.$f(x::MaxTimeoutEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - RLBase.is_terminated(env::MaxTimeoutEnv) = (env.current_t > env.max_t) || is_terminated(env.env) @@ -33,7 +26,3 @@ function RLBase.reset!(env::MaxTimeoutEnv) env.current_t = 1 RLBase.reset!(env.env) end - -RLBase.state(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) diff --git a/src/environments/wrappers/RewardOverriddenEnv.jl b/src/environments/wrappers/RewardOverriddenEnv.jl index 9852deb..69a305e 100644 --- a/src/environments/wrappers/RewardOverriddenEnv.jl +++ b/src/environments/wrappers/RewardOverriddenEnv.jl @@ -1,24 +1,11 @@ export RewardOverriddenEnv -struct RewardOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper +struct RewardOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper env::E f::F end -(env::RewardOverriddenEnv)(args...; kwargs...) = env.env(args...; kwargs...) - RewardOverriddenEnv(f) = env -> RewardOverriddenEnv(f, env) -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f != :reward - @eval RLBase.$f(x::RewardOverriddenEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - RLBase.reward(env::RewardOverriddenEnv, args...; kwargs...) = env.f(reward(env.env, args...; kwargs...)) - -RLBase.state(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state_space(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) diff --git a/src/environments/wrappers/StateCachedEnv.jl b/src/environments/wrappers/StateCachedEnv.jl index 4cb6704..5c727ce 100644 --- a/src/environments/wrappers/StateCachedEnv.jl +++ b/src/environments/wrappers/StateCachedEnv.jl @@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some environments are stateful during each `state(env)`. For example: `StateOverriddenEnv(StackFrames(...))`. """ -mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper +mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper s::S env::E is_state_cached::Bool @@ -28,14 +28,3 @@ function RLBase.state(env::StateCachedEnv, args...; kwargs...) env.s end end - -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f != :state - @eval RLBase.$f(x::StateCachedEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - -RLBase.state(env::StateCachedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state_space(env::StateCachedEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) diff --git a/src/environments/wrappers/StateOverriddenEnv.jl b/src/environments/wrappers/StateOverriddenEnv.jl index 6192f72..86c2da4 100644 --- a/src/environments/wrappers/StateOverriddenEnv.jl +++ b/src/environments/wrappers/StateOverriddenEnv.jl @@ -9,27 +9,12 @@ Apply `f` to override `state(env)`. If the meaning of state space is changed after apply `f`, one should manually redefine the `RLBase.state_space(env::YourSpecificEnv)`. """ -struct StateOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper +struct StateOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper env::E f::F end StateOverriddenEnv(f) = env -> StateOverriddenEnv(f, env) -(env::StateOverriddenEnv)(args...; kwargs...) = env.env(args...; kwargs...) - -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f ∉ (:state,) - @eval RLBase.$f(x::StateOverriddenEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - RLBase.state(env::StateOverriddenEnv, args...; kwargs...) = env.f(state(env.env, args...; kwargs...)) - -RLBase.state(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) = - env.f(state(env.env, ss)) - -RLBase.state_space(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) diff --git a/src/environments/wrappers/StochasticEnv.jl b/src/environments/wrappers/StochasticEnv.jl index a5ebced..2182b6b 100644 --- a/src/environments/wrappers/StochasticEnv.jl +++ b/src/environments/wrappers/StochasticEnv.jl @@ -2,12 +2,12 @@ export StochasticEnv using StatsBase: sample, Weights -struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnv +struct StochasticEnv{E <: AbstractEnv,R} <: AbstractEnvWrapper env::E rng::R end -function StochasticEnv(env; rng = Random.GLOBAL_RNG) +function StochasticEnv(env; rng=Random.GLOBAL_RNG) ChanceStyle(env) === EXPLICIT_STOCHASTIC || throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style is supported")) env = StochasticEnv(env, rng) @@ -39,14 +39,3 @@ RLBase.ChanceStyle(::StochasticEnv) = STOCHASTIC RLBase.players(env::StochasticEnv) = [p for p in players(env.env) if p != chance_player(env.env)] Random.seed!(env::StochasticEnv, s) = Random.seed!(env.rng, s) - -for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) - if f ∉ (:players, :ChanceStyle, :reset!) - @eval RLBase.$f(x::StochasticEnv, args...; kwargs...) = - $f(x.env, args...; kwargs...) - end -end - -RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = - state_space(env.env, ss) diff --git a/src/environments/wrappers/wrappers.jl b/src/environments/wrappers/wrappers.jl index 5324ce5..608b73b 100644 --- a/src/environments/wrappers/wrappers.jl +++ b/src/environments/wrappers/wrappers.jl @@ -1,7 +1,23 @@ +export AbstractEnvWrapper + abstract type AbstractEnvWrapper <: AbstractEnv end Base.nameof(env::AbstractEnvWrapper) = "$(nameof(env.env)) |> $(nameof(typeof(env)))" +Base.getindex(env::AbstractEnvWrapper) = env.env + +(env::AbstractEnvWrapper)(args...; kwargs...) = env.env(args...; kwargs...) + +for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API) + @eval RLBase.$f(x::AbstractEnvWrapper, args...; kwargs...) = $f(x[], args...; kwargs...) +end + +# avoid ambiguous +RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = state(env[], ss, p) +RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state(env[], ss) +RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle) = state_space(env[], ss) +RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p) = state_space(env[], ss, p) + include("ActionTransformedEnv.jl") include("DefaultStateStyle.jl") include("MaxTimeoutEnv.jl")