Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit

Permalink
Enhance_ActionTransformedEnv (#88)
Browse files Browse the repository at this point in the history
* enhance_ActionTransformedEnv

* remove overhead with environments of identity maping
  • Loading branch information
findmyway authored Oct 19, 2020
1 parent 6ee5216 commit 00125d6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/implementations/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,34 @@ get_terminal(env::MaxTimeoutEnv) = (env.current_t > env.max_t) || get_terminal(e
# ActionTransformedEnv
#####

struct ActionTransformedEnv{P,E<:AbstractEnv} <: AbstractEnv
struct ActionTransformedEnv{P,M,E<:AbstractEnv} <: AbstractEnv
processors::P
mapping::M
env::E
end

# partial constructor to allow chaining
ActionTransformedEnv(processors...) = env -> ActionTransformedEnv(processors, env)
"""
ActionTransformedEnv(processors;mapping=identity)
`mapping` will be applied to the result of `get_actions(env)`.
`processors` will be applied to the `action` before sending it to the inner environment.
The same effect like `env(action |> processors)`.
"""
ActionTransformedEnv(processors...;mapping=identity) = env -> ActionTransformedEnv(processors, mapping, env)

for f in vcat(ENV_API, MULTI_AGENT_ENV_API)
@eval $f(x::ActionTransformedEnv, args...; kwargs...) = $f(x.env, args...; kwargs...)
if f (:get_actions, :get_legal_actions)
@eval $f(x::ActionTransformedEnv, args...; kwargs...) = $f(x.env, args...; kwargs...)
end
end

get_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_actions(env.env, args...)
get_legal_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_legal_actions(env.env, args...)

get_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_actions(env.env, args...))
get_legal_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_legal_actions(env.env, args...))

(env::ActionTransformedEnv)(action, args...; kwargs...) =
env.env(foldl(|>, env.processors; init = action), args...; kwargs...)

Expand Down
9 changes: 9 additions & 0 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,13 @@
reset!(env)
run(policy, env)
@test get_terminal(env)

discrete_env = env |> ActionTransformedEnv(
a -> get_actions(env)[a]; # action index to action
mapping = x -> Dict(x => i for (i, a) in enumerate(get_actions(env)))[x] # arbitrary vector to DiscreteSpace
)
policy = RandomPolicy(discrete_env)
reset!(discrete_env)
run(policy, discrete_env)
@test get_terminal(discrete_env)
end

0 comments on commit 00125d6

Please sign in to comment.