From 00125d67da2d0a07afffed0dee45b56e4e53e4e2 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Mon, 19 Oct 2020 13:15:51 +0800 Subject: [PATCH] Enhance_ActionTransformedEnv (#88) * enhance_ActionTransformedEnv * remove overhead with environments of identity maping --- src/implementations/environments.jl | 22 +++++++++++++++++++--- test/base.jl | 9 +++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/implementations/environments.jl b/src/implementations/environments.jl index 361ac15..8199ae9 100644 --- a/src/implementations/environments.jl +++ b/src/implementations/environments.jl @@ -181,18 +181,34 @@ get_terminal(env::MaxTimeoutEnv) = (env.current_t > env.max_t) || get_terminal(e # ActionTransformedEnv ##### -struct ActionTransformedEnv{P,E<:AbstractEnv} <: AbstractEnv +struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnv processors::P + mapping::M env::E end # partial constructor to allow chaining -ActionTransformedEnv(processors...) = env -> ActionTransformedEnv(processors, env) +""" + ActionTransformedEnv(processors;mapping=identity) + +`mapping` will be applied to the result of `get_actions(env)`. +`processors` will be applied to the `action` before sending it to the inner environment. +The same effect like `env(action |> processors)`. +""" +ActionTransformedEnv(processors...;mapping=identity) = env -> ActionTransformedEnv(processors, mapping, env) for f in vcat(ENV_API, MULTI_AGENT_ENV_API) - @eval $f(x::ActionTransformedEnv, args...; kwargs...) = $f(x.env, args...; kwargs...) + if f ∉ (:get_actions, :get_legal_actions) + @eval $f(x::ActionTransformedEnv, args...; kwargs...) = $f(x.env, args...; kwargs...) + end end +get_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_actions(env.env, args...) +get_legal_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_legal_actions(env.env, args...) + +get_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_actions(env.env, args...)) +get_legal_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_legal_actions(env.env, args...)) + (env::ActionTransformedEnv)(action, args...; kwargs...) = env.env(foldl(|>, env.processors; init = action), args...; kwargs...) diff --git a/test/base.jl b/test/base.jl index 21a3747..c71b127 100644 --- a/test/base.jl +++ b/test/base.jl @@ -6,4 +6,13 @@ reset!(env) run(policy, env) @test get_terminal(env) + + discrete_env = env |> ActionTransformedEnv( + a -> get_actions(env)[a]; # action index to action + mapping = x -> Dict(x => i for (i, a) in enumerate(get_actions(env)))[x] # arbitrary vector to DiscreteSpace + ) + policy = RandomPolicy(discrete_env) + reset!(discrete_env) + run(policy, discrete_env) + @test get_terminal(discrete_env) end