From 6e40776715e6dddbb0a2f49c4d05730716124d8d Mon Sep 17 00:00:00 2001 From: Johannes Fischer Date: Thu, 24 Aug 2023 18:09:19 +0200 Subject: [PATCH] Fix multi thread env state in case N=1 --- .../src/algorithms/policy_gradient/multi_thread_env.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl index 11d175833..649242f44 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl @@ -115,7 +115,12 @@ const MULTI_THREAD_ENV_CACHE = IdDict{AbstractEnv,Dict{Symbol,Array}}() function RLBase.state(env::MultiThreadEnv) N = ndims(env.states) @sync for i in 1:length(env) - @spawn selectdim(env.states, N, i) .= state(env[i]) + @spawn begin + if N == 1 + env.states[i] .= state(env[i]) + else + selectdim(env.states, N, i) .= state(env[i]) + end end env.states end @@ -167,7 +172,7 @@ function RLBase.plan!(π::QBasedPolicy, env::MultiThreadEnv, ::FullActionSet, A) ] end -function RLBase.plan!(π::QBasedPolicy, +function RLBase.plan!(π::QBasedPolicy, env::MultiThreadEnv, ::MinimalActionSet, ::Space{<:Vector{<:Base.OneTo{<:Integer}}},