Skip to content
This repository has been archived by the owner on May 6, 2021. It is now read-only.

Commit

Permalink
Simplify env wrapper (#127)
Browse files Browse the repository at this point in the history
* simplify the definition of environment wrappers

* simplify further
  • Loading branch information
findmyway authored Feb 19, 2021
1 parent 9f50b34 commit 9e02182
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 119 deletions.
32 changes: 14 additions & 18 deletions src/environments/examples/KuhnPokerEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,20 @@ const KUHN_POKER_CARDS = (:J, :Q, :K)
const KUHN_POKER_CARD_COMBINATIONS =
((:J, :Q), (:J, :K), (:Q, :J), (:Q, :K), (:K, :J), (:K, :Q))
const KUHN_POKER_ACTIONS = (:pass, :bet)
const KUHN_POKER_STATES = (
(),
const KUHN_POKER_STATES = ((),
map(tuple, KUHN_POKER_CARDS)...,
KUHN_POKER_CARD_COMBINATIONS...,
(
(cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for
actions in (
(),
(cards..., actions...) for cards in ((), map(tuple, KUHN_POKER_CARDS)...) for actions in ((),
(:bet,),
(:bet, :bet),
(:bet, :pass),
(:pass,),
(:pass, :pass),
(:pass, :bet),
(:pass, :bet, :pass),
(:pass, :bet, :bet),
)
)...,
)
(:pass, :bet, :bet),)
)...,)

"""
![](https://upload.wikimedia.org/wikipedia/commons/a/a9/Kuhn_poker_tree.svg)
Expand Down Expand Up @@ -146,15 +141,16 @@ end

RLBase.current_player(env::KuhnPokerEnv) =
if length(env.cards) < 2
CHANCE_PLAYER
elseif length(env.actions) == 0
1
elseif length(env.actions) == 1
2
elseif length(env.actions) == 2
1
else
end
CHANCE_PLAYER
elseif length(env.actions) == 0
1
elseif length(env.actions) == 1
2
elseif length(env.actions) == 2
1
else
2 # actually the game is over now
end

RLBase.NumAgentStyle(::KuhnPokerEnv) = MultiAgent(2)
RLBase.DynamicStyle(::KuhnPokerEnv) = SEQUENTIAL
Expand Down
17 changes: 3 additions & 14 deletions src/environments/wrappers/ActionTransformedEnv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export ActionTransformedEnv

struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnvWrapper
struct ActionTransformedEnv{P,M,E <: AbstractEnv} <: AbstractEnvWrapper
action_space_mapping::P
action_mapping::M
env::E
Expand All @@ -15,23 +15,12 @@ feeding it into `env`.
"""
function ActionTransformedEnv(
env;
action_space_mapping = identity,
action_mapping = identity,
action_space_mapping=identity,
action_mapping=identity,
)
ActionTransformedEnv(action_space_mapping, action_mapping, env)
end

# Forward the generic RLBase API to the wrapped environment. `action_space` and
# `legal_action_space` are left out of the generic forwarding so that the
# wrapper can provide its own methods for them.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:action_space, :legal_action_space)
        @eval RLBase.$f(x::ActionTransformedEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# Queries that carry an explicit state style go straight to the wrapped environment.
RLBase.state(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
RLBase.state_space(env::ActionTransformedEnv, ss::RLBase.AbstractStateStyle) =
    state_space(env.env, ss)

# Map the wrapped environment's action space through `action_space_mapping`.
# NOTE(review): the extra `args...` are handed to the mapping, not to the inner
# `action_space` call. Confirm this is intended, since the default `identity`
# mapping accepts a single argument only.
function RLBase.action_space(env::ActionTransformedEnv, args...)
    inner_space = action_space(env.env)
    return env.action_space_mapping(inner_space, args...)
end

Expand Down
19 changes: 0 additions & 19 deletions src/environments/wrappers/DefaultStateStyle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,3 @@ Reset the result of `DefaultStateStyle` without changing the original behavior.
# Convenience constructor: only the state style `S` is given explicitly; the
# environment type `E` is taken from the argument.
function DefaultStateStyleEnv{S}(env::E) where {S,E}
    return DefaultStateStyleEnv{S,E}(env)
end

# The state style reported for the wrapper is its type parameter `S`.
function RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S}
    return S
end

# Forward the generic RLBase API to the wrapped environment. `DefaultStateStyle`,
# `state` and `state_space` are left out because this wrapper defines them itself.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:DefaultStateStyle, :state, :state_space)
        @eval RLBase.$f(x::DefaultStateStyleEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# Calling the wrapper calls the wrapped environment with the same arguments.
(env::DefaultStateStyleEnv)(args...; kwargs...) = env.env(args...; kwargs...)

# When a state style is passed explicitly (optionally with a player `p`), the
# query is delegated unchanged to the wrapped environment.
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) =
    state(env.env, ss, p)

RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
    state_space(env.env, ss)

RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle, p) =
    state_space(env.env, ss, p)
15 changes: 2 additions & 13 deletions src/environments/wrappers/MaxTimeoutEnv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export MaxTimeoutEnv

mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnvWrapper
mutable struct MaxTimeoutEnv{E <: AbstractEnv} <: AbstractEnvWrapper
env::E
max_t::Int
current_t::Int
Expand All @@ -11,29 +11,18 @@ end
Force `is_terminated(env)` to return `true` after `max_t` interactions.
"""
MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} =
MaxTimeoutEnv(env::E, max_t::Int; current_t::Int=1) where {E <: AbstractEnv} =
MaxTimeoutEnv(env, max_t, current_t)

# Step the wrapped environment with the given arguments, then count the
# interaction. The call evaluates to the updated step counter.
function (env::MaxTimeoutEnv)(args...; kwargs...)
    inner = env.env
    inner(args...; kwargs...)
    env.current_t += 1
    return env.current_t
end

# Forward the generic RLBase API to the wrapped environment. `is_terminated` and
# `reset!` are left out because this wrapper defines them itself to account for
# the step counter.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:is_terminated, :reset!)
        @eval RLBase.$f(x::MaxTimeoutEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# The episode is over once the step counter exceeds `max_t`; otherwise the
# wrapped environment decides. The inner check is only run when not timed out.
function RLBase.is_terminated(env::MaxTimeoutEnv)
    timed_out = env.current_t > env.max_t
    return timed_out || is_terminated(env.env)
end

# Restart the step counter at 1 and reset the wrapped environment; the result of
# the inner `reset!` is returned.
function RLBase.reset!(env::MaxTimeoutEnv)
    env.current_t = 1
    return RLBase.reset!(env.env)
end

# Queries that carry an explicit state style are delegated to the wrapped environment.
function RLBase.state(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::MaxTimeoutEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
15 changes: 1 addition & 14 deletions src/environments/wrappers/RewardOverriddenEnv.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
export RewardOverriddenEnv

struct RewardOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
struct RewardOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper
env::E
f::F
end

# Calling the wrapper calls the wrapped environment with the same arguments.
function (env::RewardOverriddenEnv)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

"""
    RewardOverriddenEnv(f)

Return a closure that wraps an environment in a `RewardOverriddenEnv` using `f`,
so that the wrapper can be applied with `env |> RewardOverriddenEnv(f)`.

The struct stores its fields in the order `(env, f)`, so the closure must call
the two-argument constructor with the environment first.
"""
RewardOverriddenEnv(f) = env -> RewardOverriddenEnv(env, f)

# Forward every RLBase API function except `reward` to the wrapped environment.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api != :reward || continue
    @eval function RLBase.$api(x::RewardOverriddenEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# The reward of the wrapped environment is passed through `env.f` before it is
# returned.
function RLBase.reward(env::RewardOverriddenEnv, args...; kwargs...)
    raw_reward = reward(env.env, args...; kwargs...)
    return env.f(raw_reward)
end

# Queries that carry an explicit state style are delegated to the wrapped environment.
function RLBase.state(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::RewardOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
13 changes: 1 addition & 12 deletions src/environments/wrappers/StateCachedEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some
environments are stateful during each `state(env)`. For example:
`StateOverriddenEnv(StackFrames(...))`.
"""
mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper
mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper
s::S
env::E
is_state_cached::Bool
Expand All @@ -28,14 +28,3 @@ function RLBase.state(env::StateCachedEnv, args...; kwargs...)
env.s
end
end

# Forward every RLBase API function except `state` to the wrapped environment.
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    api != :state || continue
    @eval function RLBase.$api(x::StateCachedEnv, args...; kwargs...)
        return $api(x.env, args...; kwargs...)
    end
end

# Queries that carry an explicit state style are delegated to the wrapped environment.
function RLBase.state(env::StateCachedEnv, ss::RLBase.AbstractStateStyle)
    return state(env.env, ss)
end

function RLBase.state_space(env::StateCachedEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
17 changes: 1 addition & 16 deletions src/environments/wrappers/StateOverriddenEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,12 @@ Apply `f` to override `state(env)`.
If the meaning of the state space changes after applying `f`, one should
manually redefine the `RLBase.state_space(env::YourSpecificEnv)`.
"""
struct StateOverriddenEnv{F,E<:AbstractEnv} <: AbstractEnvWrapper
struct StateOverriddenEnv{F,E <: AbstractEnv} <: AbstractEnvWrapper
env::E
f::F
end

"""
    StateOverriddenEnv(f)

Return a closure that wraps an environment in a `StateOverriddenEnv` using `f`,
so that the wrapper can be applied with `env |> StateOverriddenEnv(f)`.

The struct stores its fields in the order `(env, f)`, so the closure must call
the two-argument constructor with the environment first.
"""
StateOverriddenEnv(f) = env -> StateOverriddenEnv(env, f)

# Calling the wrapper calls the wrapped environment with the same arguments.
function (env::StateOverriddenEnv)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# Forward the generic RLBase API to the wrapped environment. `state` is left out
# because this wrapper defines it itself in order to apply `f`.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:state,)
        @eval RLBase.$f(x::StateOverriddenEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# The state of the wrapped environment is passed through `env.f` before it is
# returned.
function RLBase.state(env::StateOverriddenEnv, args...; kwargs...)
    raw_state = state(env.env, args...; kwargs...)
    return env.f(raw_state)
end

# Same as above for calls that carry an explicit state style.
function RLBase.state(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle)
    raw_state = state(env.env, ss)
    return env.f(raw_state)
end

# The state space is reported unchanged from the wrapped environment; `f` is not
# applied here.
function RLBase.state_space(env::StateOverriddenEnv, ss::RLBase.AbstractStateStyle)
    return state_space(env.env, ss)
end
15 changes: 2 additions & 13 deletions src/environments/wrappers/StochasticEnv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ export StochasticEnv

using StatsBase: sample, Weights

struct StochasticEnv{E<:AbstractEnv,R} <: AbstractEnv
struct StochasticEnv{E <: AbstractEnv,R} <: AbstractEnvWrapper
env::E
rng::R
end

function StochasticEnv(env; rng = Random.GLOBAL_RNG)
function StochasticEnv(env; rng=Random.GLOBAL_RNG)
ChanceStyle(env) === EXPLICIT_STOCHASTIC ||
throw(ArgumentError("only environments of EXPLICIT_STOCHASTIC style is supported"))
env = StochasticEnv(env, rng)
Expand Down Expand Up @@ -39,14 +39,3 @@ RLBase.ChanceStyle(::StochasticEnv) = STOCHASTIC
# All players of the wrapped environment except its chance player.
function RLBase.players(env::StochasticEnv)
    inner = env.env
    return [p for p in players(inner) if p != chance_player(inner)]
end

# Seeding the wrapper seeds the random number generator it holds.
function Random.seed!(env::StochasticEnv, s)
    return Random.seed!(env.rng, s)
end

# Forward the generic RLBase API to the wrapped environment. `players`,
# `ChanceStyle` and `reset!` are left out of the generic forwarding so that the
# wrapper can provide its own methods for them.
for f in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    if f ∉ (:players, :ChanceStyle, :reset!)
        @eval RLBase.$f(x::StochasticEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# Queries that carry an explicit state style go straight to the wrapped environment.
RLBase.state(env::StochasticEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
RLBase.state_space(env::StochasticEnv, ss::RLBase.AbstractStateStyle) =
    state_space(env.env, ss)
16 changes: 16 additions & 0 deletions src/environments/wrappers/wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
export AbstractEnvWrapper

"""
    AbstractEnvWrapper

Supertype of environment wrappers. The generic methods defined for it read the
wrapped environment from a field named `env`.
"""
abstract type AbstractEnvWrapper <: AbstractEnv end

# Name of a wrapper: the wrapped environment's name, then `|>`, then the wrapper
# type's name.
function Base.nameof(env::AbstractEnvWrapper)
    return string(nameof(env.env), " |> ", nameof(typeof(env)))
end

# `env[]` unwraps one level and yields the wrapped environment.
function Base.getindex(env::AbstractEnvWrapper)
    return env.env
end

# Calling a wrapper calls the wrapped environment with the same arguments.
function (env::AbstractEnvWrapper)(args...; kwargs...)
    return env.env(args...; kwargs...)
end

# Give every wrapper a default method for each RLBase API function that simply
# delegates to the wrapped environment (`x[]`).
for api in vcat(RLBase.ENV_API, RLBase.MULTI_AGENT_ENV_API)
    @eval function RLBase.$api(x::AbstractEnvWrapper, args...; kwargs...)
        return $api(x[], args...; kwargs...)
    end
end

# Explicit methods for calls that carry a state style (and optionally a player
# `p`), added to avoid method ambiguity errors. Each delegates to the wrapped
# environment.
function RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p)
    return state(env[], ss, p)
end

function RLBase.state(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle)
    return state(env[], ss)
end

function RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle)
    return state_space(env[], ss)
end

function RLBase.state_space(env::AbstractEnvWrapper, ss::RLBase.AbstractStateStyle, p)
    return state_space(env[], ss, p)
end

include("ActionTransformedEnv.jl")
include("DefaultStateStyle.jl")
include("MaxTimeoutEnv.jl")
Expand Down

0 comments on commit 9e02182

Please sign in to comment.