Skip to content

Commit

Permalink
Reduce find_all_max allocations and increase speed based on chatgpt suggestion (#938)
Browse files Browse the repository at this point in the history

* Reduce find_all_max allocations and increase speed based on chatgpt suggestion

* Use GPUArrays to dispatch find_all_max GPU fallback

* Drop excess function
  • Loading branch information
jeremiahpslewis authored Aug 7, 2023
1 parent 3d97a4f commit 75a6dab
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/ReinforcementLearningCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Expand All @@ -36,6 +37,7 @@ Distributions = "0.25"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13, 1"
Flux = "0.13, 0.14"
Functors = "0.1, 0.2, 0.3, 0.4"
GPUArrays = "8"
Parsers = "2"
ProgressMeter = "1"
Reexport = "1"
Expand Down
16 changes: 15 additions & 1 deletion src/ReinforcementLearningCore/src/utils/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export global_norm,
orthogonal

using FillArrays: Trues
using GPUArrays

#####
# Zygote
Expand Down Expand Up @@ -87,7 +88,20 @@ flatten_batch(x::AbstractArray) = reshape(x, size(x)[1:end-2]..., :)
# RLUtils
#####

function find_all_max(x)
"""
    find_all_max(x::AbstractArray)

Return a tuple `(v, indices)` where `v` is `maximum(x)` and `indices` is a
`Vector{Int}` holding the linear index of every element of `x` that compares
`==` to `v`, in iteration order.

The index vector is sized up front with `count`, so it is allocated exactly once.
Indices are taken from `LinearIndices(x)` rather than `eachindex(x)`, so arrays whose
`eachindex` yields `CartesianIndex` values (e.g. adjoints or non-contiguous views)
work as well instead of failing when the index is stored into a `Vector{Int}`.

Any error raised by `maximum(x)` propagates unchanged.
"""
function find_all_max(x::A) where {A <: AbstractArray}
    v = maximum(x)
    indices = Vector{Int}(undef, count(==(v), x))
    j = 1
    # `LinearIndices(x)` and `x` iterate in the same order, so each element is
    # paired with its own integer index regardless of the array's index style.
    for (i, xi) in zip(LinearIndices(x), x)
        if xi == v
            indices[j] = i
            j += 1
        end
    end
    return v, indices
end

"""
    find_all_max(x::AbstractGPUArray)

GPU-array method: return `(v, indices)` with `v = maximum(x)` and `indices` the
result of `findall(==(v), x)`, avoiding an element-by-element loop over `x`.
"""
function find_all_max(x::A) where {A <: AbstractGPUArray}
    peak = maximum(x)
    # Delegate the search to `findall` instead of scalar-indexing the array.
    hits = findall(==(peak), x)
    return peak, hits
end
Expand Down
4 changes: 3 additions & 1 deletion src/ReinforcementLearningCore/src/utils/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ export normlogpdf, mvnormlogpdf, diagnormlogpdf, mvnormkldivergence, diagnormkld
using Flux: unsqueeze
using LinearAlgebra

using GPUArrays

# watch https://github.com/JuliaStats/Distributions.jl/issues/1183
const log2π = log(2.0f0π)

Expand Down Expand Up @@ -70,7 +72,7 @@ end
Log-determinant of the positive-semi-definite matrix A = L*U (Cholesky lower and upper triangular factors), given L or U.
Has a sign uncertainty for non-PSD matrices.
"""
function logdetLorU(LorU::Union{A, LowerTriangular{T, A}, UpperTriangular{T, A}}) where {T, A <: CuArray}
function logdetLorU(LorU::Union{A, LowerTriangular{T, A}, UpperTriangular{T, A}}) where {T, A <: AbstractGPUArray}
    # Only the diagonal of the factor is needed.
    factor_diag = diag(LorU)
    # Sum of the element-wise logs of the diagonal entries.
    log_diag_sum = sum(log.(factor_diag))
    # Doubled because A = L*U and only one of the two factors is given.
    return 2 * log_diag_sum
end

Expand Down

0 comments on commit 75a6dab

Please sign in to comment.