Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit

Permalink
Format .jl files (#83)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] authored Oct 22, 2020
1 parent 00125d6 commit 842837b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
21 changes: 14 additions & 7 deletions src/implementations/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnv
current_t::Int
end

"""
    MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv}

Construct a `MaxTimeoutEnv` around `env` with the given `max_t`, taking the initial
`current_t` as a keyword argument (default `1`).
"""
MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} =
    # Pass the environment instance itself; the type parameter `E` is not an environment.
    MaxTimeoutEnv(env, max_t, current_t)

function (env::MaxTimeoutEnv)(args...; kwargs...)
env.env(args...; kwargs...)
Expand Down Expand Up @@ -195,19 +196,25 @@ end
`processors` will be applied to the `action` before sending it to the inner environment.
This has the same effect as `env(action |> processors)`.
"""
function ActionTransformedEnv(processors...; mapping = identity)
    # Curried form: capture `processors` and `mapping` now and return a function that
    # wraps whatever environment it is later given, e.g. `env |> ActionTransformedEnv(...)`.
    return env -> ActionTransformedEnv(processors, mapping, env)
end

# Forward every API function to the wrapped environment `x.env`, except
# `get_actions` and `get_legal_actions`: those have dedicated `ActionTransformedEnv`
# methods because their results may have to go through the action `mapping`.
for f in vcat(ENV_API, MULTI_AGENT_ENV_API)
    if f ∉ (:get_actions, :get_legal_actions)
        @eval $f(x::ActionTransformedEnv, args...; kwargs...) =
            $f(x.env, args...; kwargs...)
    end
end

# Fast path: when the second type parameter is `typeof(identity)` (presumably the type
# of the `mapping` field -- confirm against the struct definition) there is nothing to
# translate, so the query is handed straight to the wrapped environment.
function get_actions(env::ActionTransformedEnv{<:Any,typeof(identity)}, args...)
    return get_actions(env.env, args...)
end

function get_legal_actions(env::ActionTransformedEnv{<:Any,typeof(identity)}, args...)
    return get_legal_actions(env.env, args...)
end

# General case: each action reported by the wrapped environment is passed through
# `env.mapping` before being returned.
function get_actions(env::ActionTransformedEnv, args...)
    inner_actions = get_actions(env.env, args...)
    return map(env.mapping, inner_actions)
end

function get_legal_actions(env::ActionTransformedEnv, args...)
    inner_legal_actions = get_legal_actions(env.env, args...)
    return map(env.mapping, inner_legal_actions)
end

# Calling the wrapper with an action: pipe the action through each processor in order
# (left to right), then invoke the wrapped environment with the processed action and
# the remaining arguments.
function (env::ActionTransformedEnv)(action, args...; kwargs...)
    processed_action = foldl(|>, env.processors; init = action)
    return env.env(processed_action, args...; kwargs...)
end
Expand Down
3 changes: 2 additions & 1 deletion src/implementations/policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ RandomPolicy(::FullActionSet, env::AbstractEnv, rng) = RandomPolicy(nothing, rng
# Ideally these methods would return a Categorical distribution, but that would add
# Distributions as an extra dependency.
# See https://github.com/JuliaStats/Distributions.jl/issues/1139

# One uniform probability vector per sub-space of the policy's action space.
function get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv)
    return [fill(1 / length(s), length(s)) for s in p.action_space]
end

# One result per element obtained by iterating `env`.
function get_prob(p::RandomPolicy, env::MultiThreadEnv)
    return [get_prob(p, x) for x in env]
end

# Uniform probabilities over the policy's own action space.
function get_prob(p::RandomPolicy, env)
    n_actions = length(p.action_space)
    return fill(1 / n_actions, n_actions)
end

# Without a stored action space, dispatch further on the environment's `ChanceStyle`.
function get_prob(p::RandomPolicy{Nothing}, env)
    return get_prob(p, env, ChanceStyle(env))
end
Expand Down
3 changes: 2 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ Get the probability distribution of actions based on policy `π` given an `env`.
Only valid for environments with discrete action space.
"""
@api get_prob(π::AbstractPolicy, env, action) =
    # Locate `action` in `get_actions(env)` and use that position to index the full
    # distribution returned by `get_prob(π, env)`.
    get_prob(π, env)[findfirst(==(action), get_actions(env))]

"""
get_priority(π::AbstractPolicy, experience)
Expand Down
7 changes: 4 additions & 3 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
run(policy, env)
@test get_terminal(env)

# Wrap `env` so that it is driven by action indices rather than by the raw actions.
discrete_env =
    env |> ActionTransformedEnv(
        a -> get_actions(env)[a]; # action index to action
        # The lookup table must be keyed by each enumerated action `a`; keying it by the
        # queried value `x` would map every action to the last index.
        mapping = x -> Dict(a => i for (i, a) in enumerate(get_actions(env)))[x], # arbitrary vector to DiscreteSpace
    )
policy = RandomPolicy(discrete_env)
reset!(discrete_env)
Expand Down

0 comments on commit 842837b

Please sign in to comment.