diff --git a/src/implementations/environments.jl b/src/implementations/environments.jl index 8199ae9..03ba8c1 100644 --- a/src/implementations/environments.jl +++ b/src/implementations/environments.jl @@ -159,7 +159,8 @@ mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnv current_t::Int end -MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where E<:AbstractEnv = MaxTimeoutEnv(E, max_t, current_t) +MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} = + MaxTimeoutEnv(E, max_t, current_t) function (env::MaxTimeoutEnv)(args...; kwargs...) env.env(args...; kwargs...) @@ -195,19 +196,25 @@ end `processors` will be applied to the `action` before sending it to the inner environment. The same effect like `env(action |> processors)`. """ -ActionTransformedEnv(processors...;mapping=identity) = env -> ActionTransformedEnv(processors, mapping, env) +ActionTransformedEnv(processors...; mapping = identity) = + env -> ActionTransformedEnv(processors, mapping, env) for f in vcat(ENV_API, MULTI_AGENT_ENV_API) if f ∉ (:get_actions, :get_legal_actions) - @eval $f(x::ActionTransformedEnv, args...; kwargs...) = $f(x.env, args...; kwargs...) + @eval $f(x::ActionTransformedEnv, args...; kwargs...) = + $f(x.env, args...; kwargs...) end end -get_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_actions(env.env, args...) -get_legal_actions(env::ActionTransformedEnv{<:Any, typeof(identity)}, args...) = get_legal_actions(env.env, args...) +get_actions(env::ActionTransformedEnv{<:Any,typeof(identity)}, args...) = + get_actions(env.env, args...) +get_legal_actions(env::ActionTransformedEnv{<:Any,typeof(identity)}, args...) = + get_legal_actions(env.env, args...) -get_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_actions(env.env, args...)) -get_legal_actions(env::ActionTransformedEnv, args...) = map(env.mapping, get_legal_actions(env.env, args...)) +get_actions(env::ActionTransformedEnv, args...) = + map(env.mapping, get_actions(env.env, args...)) +get_legal_actions(env::ActionTransformedEnv, args...) = + map(env.mapping, get_legal_actions(env.env, args...)) (env::ActionTransformedEnv)(action, args...; kwargs...) = env.env(foldl(|>, env.processors; init = action), args...; kwargs...) diff --git a/src/implementations/policies.jl b/src/implementations/policies.jl index 773731e..7d4b13c 100644 --- a/src/implementations/policies.jl +++ b/src/implementations/policies.jl @@ -40,7 +40,8 @@ RandomPolicy(::FullActionSet, env::AbstractEnv, rng) = RandomPolicy(nothing, rng # Ideally we should return a Categorical distribution. # But this means we need to introduce an extra dependency of Distributions # watch https://github.com/JuliaStats/Distributions.jl/issues/1139 -get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv) = [fill(1/length(s), length(s)) for s in p.action_space] +get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv) = + [fill(1 / length(s), length(s)) for s in p.action_space] get_prob(p::RandomPolicy, env::MultiThreadEnv) = [get_prob(p, x) for x in env] get_prob(p::RandomPolicy, env) = fill(1 / length(p.action_space), length(p.action_space)) get_prob(p::RandomPolicy{Nothing}, env) = get_prob(p, env, ChanceStyle(env)) diff --git a/src/interface.jl b/src/interface.jl index dae6272..04408d4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -46,7 +46,8 @@ Get the probability distribution of actions based on policy `π` given an `env`. Only valid for environments with discrete action space. """ -@api get_prob(π::AbstractPolicy, env, action) = get_prob(π, env)[findfirst(==(action), get_actions(env))] +@api get_prob(π::AbstractPolicy, env, action) = + get_prob(π, env)[findfirst(==(action), get_actions(env))] """ get_priority(π::AbstractPolicy, experience) diff --git a/test/base.jl b/test/base.jl index c71b127..d38cc30 100644 --- a/test/base.jl +++ b/test/base.jl @@ -7,9 +7,10 @@ run(policy, env) @test get_terminal(env) - discrete_env = env |> ActionTransformedEnv( - a -> get_actions(env)[a]; # action index to action - mapping = x -> Dict(x => i for (i, a) in enumerate(get_actions(env)))[x] # arbitrary vector to DiscreteSpace + discrete_env = + env |> ActionTransformedEnv( + a -> get_actions(env)[a]; # action index to action + mapping = x -> Dict(x => i for (i, a) in enumerate(get_actions(env)))[x], # arbitrary vector to DiscreteSpace ) policy = RandomPolicy(discrete_env) reset!(discrete_env)