From d6a39b5fdc7c185dbf9c402dd1706fe7778ec59c Mon Sep 17 00:00:00 2001 From: Jeremiah Lewis <4462211+jeremiahpslewis@users.noreply.github.com> Date: Mon, 4 Mar 2024 20:37:57 +0100 Subject: [PATCH] Fix type piracy --- .../src/policies/explorers/epsilon_greedy_explorer.jl | 8 ++++---- src/ReinforcementLearningCore/src/utils/basic.jl | 9 ++------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl index 6855810c8..22c387487 100644 --- a/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl +++ b/src/ReinforcementLearningCore/src/policies/explorers/epsilon_greedy_explorer.jl @@ -126,7 +126,7 @@ RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::Vector{I}, mask::Trues) w function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}} ϵ = get_ϵ(s) s.step += 1 - rand(s.rng) >= ϵ ? findmax(values, mask)[2] : rand(s.rng, findall(mask)) + rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask)) end ##### @@ -188,7 +188,7 @@ function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,false}, values, mask) ϵ, n = get_ϵ(s), length(values) probs = zeros(n) probs[mask] .= ϵ / sum(mask) - probs[findmax(values, mask)[2]] += 1 - ϵ + probs[findmax_masked(values, mask)[2]] += 1 - ϵ Categorical(probs; check_args=false) end @@ -201,7 +201,7 @@ struct GreedyExplorer <: AbstractExplorer end RLBase.plan!(s::GreedyExplorer, x, mask::Trues) = s(x) RLBase.plan!(s::GreedyExplorer, values) = findmax(values)[2] -RLBase.plan!(s::GreedyExplorer, values, mask) = findmax(values, mask)[2] +RLBase.plan!(s::GreedyExplorer, values, mask) = findmax_masked(values, mask)[2] RLBase.prob(s::GreedyExplorer, values) = Categorical(onehot(findmax(values)[2], 1:length(values)); check_args=false) @@ -210,4 +210,4 @@ RLBase.prob(s::GreedyExplorer, values, action::Integer) = findmax(values)[2] == action ? 1.0 : 0.0 RLBase.prob(s::GreedyExplorer, values, mask) = - Categorical(onehot(findmax(values, mask)[2], length(values)); check_args=false) + Categorical(onehot(findmax_masked(values, mask)[2], length(values)); check_args=false) diff --git a/src/ReinforcementLearningCore/src/utils/basic.jl b/src/ReinforcementLearningCore/src/utils/basic.jl index b501367bc..578593e4a 100644 --- a/src/ReinforcementLearningCore/src/utils/basic.jl +++ b/src/ReinforcementLearningCore/src/utils/basic.jl @@ -113,16 +113,11 @@ function find_all_max(x, mask::AbstractVector{Bool}) v, [k for (m, k) in zip(mask, keys(x)) if m && x[k] == v] end -# !!! watch https://github.com/JuliaLang/julia/pull/35316#issuecomment-622629895 -# Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain) -# _rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m) -# !!! type piracy -Base.findmax(A::AbstractVector{T}, mask::AbstractVector{Bool}) where {T} = +findmax_masked(A::AbstractVector{T}, mask::AbstractVector{Bool}) where {T} = findmax(ifelse.(mask, A, typemin(T))) -Base.findmax(A::AbstractVector, mask::Trues) = findmax(A) - +findmax_masked(A::AbstractVector, mask::Trues) = findmax(A) const VectorOrMatrix = Union{AbstractMatrix,AbstractVector}