Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Commit

Permalink
add env that terminates after max threshold time (#85)
Browse files Browse the repository at this point in the history
* add env that terminates after max threshold time

* add partial constructor with optional keyword args

Co-authored-by: Jun Tian <[email protected]>
  • Loading branch information
Sid-Bhatia-0 and findmyway authored Oct 13, 2020
1 parent d2d82b4 commit 670b9e3
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/implementations/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ export SubjectiveEnv,
StateOverriddenEnv,
RewardOverriddenEnv,
ActionTransformedEnv,
StateCachedEnv
StateCachedEnv,
MaxTimeoutEnv

using MacroTools: @forward
using Random
Expand Down Expand Up @@ -147,6 +148,33 @@ end
get_reward(env::RewardOverriddenEnv, args...; kwargs...) =
foldl(|>, env.processors; init = get_reward(env.env, args...; kwargs...))


#####
# MaxTimeoutEnv
#####

mutable struct MaxTimeoutEnv{E<:AbstractEnv} <: AbstractEnv
env::E
max_t::Int
current_t::Int
end

function (env::MaxTimeoutEnv)(args...; kwargs...)
env.env(args...; kwargs...)
env.current_t = env.current_t + 1
end

# partial constructor to allow chaining
MaxTimeoutEnv(max_t::Int; current_t::Int = 1) = env -> MaxTimeoutEnv(env, max_t, current_t)

for f in vcat(ENV_API, MULTI_AGENT_ENV_API)
if f != :get_terminal
@eval $f(x::MaxTimeoutEnv, args...; kwargs...) = $f(x.env, args...; kwargs...)
end
end

get_terminal(env::MaxTimeoutEnv) = (env.current_t > env.max_t) || get_terminal(env.env)

#####
# ActionTransformedEnv
#####
Expand Down

0 comments on commit 670b9e3

Please sign in to comment.