[WIP] Sketch for a wrapper for Distribution that enables batching and GPU sampling #22

Closed
wants to merge 3 commits

Conversation

@sunxd3 (Member) commented Jul 30, 2023

I started this PR to address the lack of GPU support for distributions, as initially raised in issue #12. Two strategies were discussed to resolve this:

  1. Create a specific GPU version for each Distribution.
  2. Develop a wrapper for Distributions that can work on GPU.

The latter approach, being less code-intensive, forms the basis of this PR.

The concept of shapes for distributions used in the TensorFlow Probability, PyMC, and Pyro packages is employed here. These include Event Shape, Sample Shape, and Batch Shape. Event Shape is essentially length(d::Distribution). Sample Shape is given explicitly when calling the rand function. Batch Shape is the one particularly significant for this implementation.
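For concreteness, here is a minimal illustration of the first two shapes with a plain Distributions.jl object (Batch Shape has no native counterpart in Distributions.jl, which is what the wrapper below adds):

using Distributions, LinearAlgebra

d = MvNormal(zeros(3), Diagonal(ones(3)))

length(d)         # 3: the Event Shape, i.e. the size of a single draw
size(rand(d, 5))  # (3, 5): the 5 is the Sample Shape, requested explicitly in the rand call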

BatchDistributionWrapper is used to dispatch to specific implementations of functions, which may target the GPU for high performance. By using the type of the parameter arrays as a surrogate for the device type (Array for CPU and CuArray for GPU), we can dispatch on both the Distribution type and the relevant device type.

Here's the way BatchDistributionWrapper is defined:

using Distributions, Functors

struct BatchDistributionWrapper{D<:Distribution, T<:AbstractArray}
    distribution::Type{D}
    parameters::NTuple{M, T} where M     # one array per distribution parameter
    batch_shape::NTuple{N, Int} where N  # leading batch dimensions shared by the parameters
end
@functor BatchDistributionWrapper (parameters,)  # let gpu/cpu map over the parameter arrays only

function BatchDistributionWrapper(dist::Symbol, params, batch_shape=())
    @assert isdefined(Distributions, dist) "Distribution $dist is not defined"

    if any(iszero ∘ ndims, params) # if any of the parameters is a scalar, return the plain Distribution
        return getfield(Distributions, dist)(params...)
    end

    @assert all(map(x -> eltype(x) != Any, params)) "all parameters should have a concrete element type"
    @assert all(map(x -> eltype(x) == eltype(params[1]), params)) "all parameters should have the same element type"

    D = getfield(Distributions, dist)
    return BatchDistributionWrapper{D, eltype(params)}(D, params, batch_shape)
end
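For illustration, under this constructor the two paths would behave roughly as follows:

BatchDistributionWrapper(:Normal, (0.0, 1.0))  # scalar parameters: returns the plain Normal(0.0, 1.0)
BatchDistributionWrapper(:Normal, (rand(2, 2, 1), rand(2, 2, 1)), (2, 2))  # arrays: returns the wrapper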

The rand function can then be implemented per distribution and device, e.g. for Normal on the GPU:

using CUDA, Random

function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D<:Normal, T<:CuArray}
    μ, σ = d.parameters
    x = similar(μ)   # allocate the result on the GPU
    CUDA.randn!(x)   # fill with standard normal draws
    x .*= σ          # in-place affine transform: x = μ + σ .* z
    x .+= μ
    return x
end
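The device dispatch described above would pair this with a CPU method; a minimal sketch under the same two-parameter layout:

function Random.rand(rng::AbstractRNG, d::BatchDistributionWrapper{D, T}) where {D<:Normal, T<:Array}
    μ, σ = d.parameters
    z = randn(rng, eltype(μ), size(μ))  # standard normal draws on the CPU
    return μ .+ σ .* z
end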

Test example for BatchDistributionWrapper:

d = BatchDistributionWrapper(:Normal, (rand(2, 2, 1), rand(2, 2, 1)), (2, 2))
rand(gpu(d))  # gpu maps over the parameters via the @functor declaration; the Float64 inputs come back as Float32 CuArrays

returns

2×2×1 CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}:
[:, :, 1] =
 0.377751  0.249488
 0.769026  0.433008

Some further considerations

  • The constructor as currently implemented is somewhat simplistic, and there are several areas where it could potentially be improved:
    • Implicit Broadcasting: Consider a distribution my_dist that requires two parameters for construction. If the parameters differ in dimensionality (say one is lower-dimensional than the other), calls such as my_dist(rand(2), rand(2, 2)) and my_dist(rand(2, 1), rand(2, 2)) should both be valid and produce identical results. As of now, the implementation only supports cases where all parameters have the same number of dimensions.
    • Batch Shape Specification: It may be possible to eliminate the need to specify the batch shape by inferring it from the sizes of the inputs to the distribution constructor; see the sketch after this list.
  • A similar concept could also be extended to bijectors and transformed distributions.
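On the batch-shape point above, one possible inference is to broadcast the parameter sizes together (a sketch; infer_batch_shape is a made-up name, and Broadcast.broadcast_shape is an internal Base function):

# infer the batch shape as the broadcast of the parameter sizes,
# so the caller does not have to pass it explicitly
infer_batch_shape(params) = Base.Broadcast.broadcast_shape(map(size, params)...)

infer_batch_shape((rand(2, 1), rand(2, 2)))  # (2, 2)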

@Red-Portal (Member) commented:
Can't wait to see this happen! But isn't BatchDistributionWrapper quite a mouthful? How about something like BatchDist or BatchProduct? The latter tries to draw some weak connections to Distributions.Product.

@sunxd3 (Member, Author) commented Aug 16, 2023

@Red-Portal yeah, I agree about the name. But for now we are not pursuing the approach in this PR; instead, see https://github.com/TuringLang/NormalizingFlows.jl/blob/torfjelde/cuda/ext/NormalizingFlowsCUDAExt.jl

@sunxd3 (Member, Author) commented Aug 17, 2023

StructArrays is a good alternative.

using StructArrays, Distributions
μ1 = [0.0, 0.0]
Σ1 = [1.0 0.5; 0.5 1.0]
μ2 = [1.0, 2.0]
Σ2 = [2.0 1.0; 1.0 2.0]

s = StructArray([MvNormal(μ1, Σ1), MvNormal(μ2, Σ2)])
julia> s.Σ 
2-element Vector{PDMats.PDMat{Float64, Matrix{Float64}}}:
 [1.0 0.5; 0.5 1.0]
 [2.0 1.0; 1.0 2.0]

julia> s.μ
2-element Vector{Vector{Float64}}:
 [0.0, 0.0]
 [1.0, 2.0]

julia> typeof(s)
StructVector{FullNormal, NamedTuple{(:μ, :Σ), Tuple{Vector{Vector{Float64}}, Vector{PDMat{Float64, Matrix{Float64}}}}}, Int64} (alias for StructArray{MvNormal{Float64, PDMats.PDMat{Float64, Array{Float64, 2}}, Array{Float64, 1}}, 1, NamedTuple{(:μ, :Σ), Tuple{Array{Array{Float64, 1}, 1}, Array{PDMats.PDMat{Float64, Array{Float64, 2}}, 1}}}, Int64})

julia> rand.(s)
2-element Vector{Vector{Float64}}:
 [0.5337271429228262, -0.10615266026389553]
 [2.516493294473824, 4.66390606518843]

Many functions should just work through broadcasting.

And if we want, we can still dispatch on the Distribution type and the array type.
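For example, continuing with the StructArray s from above (batched_logpdf is a hypothetical name):

xs = rand.(s)   # one draw per component distribution
logpdf.(s, xs)  # elementwise log-densities via broadcasting

# a specialized method can still dispatch on the distribution eltype
# (and, through the field types, on the underlying array/device type)
batched_logpdf(s::StructArray{<:MvNormal}, xs) = logpdf.(s, xs)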

@sunxd3 sunxd3 closed this Feb 27, 2024
@sunxd3 sunxd3 deleted the batch_disribution branch February 29, 2024 13:54