Skip to content

Commit

Permalink
Fix type piracy
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiahpslewis committed Mar 4, 2024
1 parent 2570f48 commit d6a39b5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, x::Vector{I}, mask::Trues) w
"""
    RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values, mask)

Select an action index from `values`, restricted to the entries where `mask` is `true`.

A single draw from `s.rng` decides the branch: when the draw is at least `ϵ`, the index
of the largest unmasked value (via `findmax_masked`) is returned; otherwise an index is
drawn at random from `findall(mask)`. `s.step` is incremented on every call, after `ϵ`
has been read from `get_ϵ(s)`.
"""
function RLBase.plan!(s::EpsilonGreedyExplorer{<:Any,false}, values::Vector{I}, mask::M) where {I<:Real, M<:Union{BitVector, Vector{Bool}}}
    ϵ = get_ϵ(s)  # read before the step counter advances
    s.step += 1
    # Exactly one exploit-vs-explore draw per call. The greedy branch goes through the
    # package-local `findmax_masked` rather than a two-argument `findmax(values, mask)`.
    return rand(s.rng) >= ϵ ? findmax_masked(values, mask)[2] : rand(s.rng, findall(mask))
end

#####
Expand Down Expand Up @@ -188,7 +188,7 @@ function RLBase.prob(s::EpsilonGreedyExplorer{<:Any,false}, values, mask)
ϵ, n = get_ϵ(s), length(values)
probs = zeros(n)
probs[mask] .= ϵ / sum(mask)
probs[findmax(values, mask)[2]] += 1 - ϵ
probs[findmax_masked(values, mask)[2]] += 1 - ϵ
Categorical(probs; check_args=false)
end

Expand All @@ -201,7 +201,7 @@ struct GreedyExplorer <: AbstractExplorer end
# `Trues` mask: the mask is not consulted; the call is forwarded to `s(x)`.
RLBase.plan!(s::GreedyExplorer, x, mask::Trues) = s(x)

# No mask: the action is the index of the largest value.
RLBase.plan!(s::GreedyExplorer, values) = findmax(values)[2]

# Masked: the index comes from the package-local `findmax_masked`. Keep exactly one
# definition for this signature; a second one would silently overwrite the first.
RLBase.plan!(s::GreedyExplorer, values, mask) = findmax_masked(values, mask)[2]

# Action distribution for the unmasked case: a one-hot encoding (over `1:length(values)`)
# of the index of the largest value, wrapped in a `Categorical` built with
# `check_args=false`.
function RLBase.prob(s::GreedyExplorer, values)
    _, best = findmax(values)
    support = 1:length(values)
    return Categorical(onehot(best, support); check_args=false)
end
Expand All @@ -210,4 +210,4 @@ RLBase.prob(s::GreedyExplorer, values, action::Integer) =
findmax(values)[2] == action ? 1.0 : 0.0

# Masked variant: one-hot at the index chosen by `findmax_masked(values, mask)`, wrapped
# in a `Categorical` built with `check_args=false`.
# NOTE(review): the unmasked method passes `1:length(values)` as the second argument of
# `onehot`, while this one passes `length(values)` — confirm which form `onehot` expects.
RLBase.prob(s::GreedyExplorer, values, mask) =
    Categorical(onehot(findmax_masked(values, mask)[2], length(values)); check_args=false)
9 changes: 2 additions & 7 deletions src/ReinforcementLearningCore/src/utils/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,11 @@ function find_all_max(x, mask::AbstractVector{Bool})
v, [k for (m, k) in zip(mask, keys(x)) if m && x[k] == v]
end

# !!! watch https://github.com/JuliaLang/julia/pull/35316#issuecomment-622629895
# Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
# _rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)

# !!! type piracy
"""
    findmax_masked(A::AbstractVector, mask::AbstractVector{Bool}) -> (value, index)

Return the largest element of `A` among the positions where `mask` is `true`, together
with its index. Elements are compared with `isless`, and the first maximal element wins,
matching `findmax`.

Only unmasked positions are ever candidates, so the returned index is always an allowed
one, even when an allowed value equals `typemin(T)` (for example `-Inf`).

If no position is allowed, `(typemin(T), firstindex(A))` is returned; callers should
treat that as "no valid choice".

# Throws
- `DimensionMismatch` if `A` and `mask` do not share the same indices.
- `ArgumentError` if `A` is empty.
"""
function findmax_masked(A::AbstractVector{T}, mask::AbstractVector{Bool}) where {T}
    idxs = eachindex(A, mask)  # throws DimensionMismatch when the indices differ
    isempty(idxs) && throw(ArgumentError("collection must be non-empty"))
    best = nothing
    for i in idxs
        mask[i] || continue
        # Strict `isless` keeps the first of several equal maxima.
        if best === nothing || isless(A[best], A[i])
            best = i
        end
    end
    # Every position is masked out: fall back to the sentinel pair.
    best === nothing && return (typemin(T), firstindex(A))
    return (A[best], best)
end

Base.findmax(A::AbstractVector, mask::Trues) = findmax(A)

# `Trues` mask: the mask argument is not consulted; the result is the plain `findmax(A)`.
function findmax_masked(A::AbstractVector, mask::Trues)
    return findmax(A)
end

const VectorOrMatrix = Union{AbstractMatrix,AbstractVector}

Expand Down

0 comments on commit d6a39b5

Please sign in to comment.