diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl index 11d175833..10a17157c 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/multi_thread_env.jl @@ -131,7 +131,7 @@ function RLBase.is_terminated(env::MultiThreadEnv) end function RLBase.legal_action_space_mask(env::MultiThreadEnv) - N = ndims(env.states) + N = ndims(env.legal_action_space_mask) @sync for i in 1:length(env) @spawn selectdim(env.legal_action_space_mask, N, i) .= legal_action_space_mask(env[i])