[WIP] Sketch for a wrapper for Distribution that enables batching and GPU sampling #22
I started this PR to address the lack of GPU support for distributions, first raised in issue #12. Two strategies were discussed there for resolving it. Both concern `Distribution`, and the second aims at `Distribution`s that can work on GPU. The latter approach, being less code-intensive, forms the basis of this PR.
The concept of shapes for distributions used in the TensorFlow Probability, PyMC, and Pyro packages is employed here. These include `Event Shape`, `Sample Shape` and `Batch Shape`. `Event Shape` is essentially `length(d::Distribution)`. `Sample Shape` is given explicitly when calling the `rand` function. `Batch Shape` is particularly significant for this implementation.

`BatchDistributionWrapper` is used to dispatch to specific implementations of functions, and these functions may target the GPU for high performance. By using the type of the parameter arrays as a surrogate for the device type (`Array` for CPU and `CuArray` for GPU), we can dispatch on both the `Distribution` type and the relevant device type.

Here's the way `BatchDistributionWrapper` is defined:

The `rand` function can be implemented as:

Test example for `BatchDistributionWrapper`:

returns
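As a rough illustration of the dispatch idea described above, and not the definition used in this PR, here is a minimal dependency-free Julia sketch. The names `NormalFamily`, `BatchWrapperSketch` and `batch_shape` are hypothetical stand-ins introduced only for this example. It also assumes that the batch shape is the shape of the parameter arrays and that the sample shape is appended as a trailing dimension.

```julia
using Random

# Hypothetical tag type standing in for a Distributions.jl family such as Normal.
struct NormalFamily end

# Hypothetical wrapper (not the PR's BatchDistributionWrapper).
# `F` selects the distribution family; `A`, the parameter array type,
# acts as a surrogate for the device (Array => CPU).
struct BatchWrapperSketch{F,A<:AbstractArray}
    family::F
    μ::A
    σ::A
end

# Assumption: the batch shape is the shape of the parameter arrays.
batch_shape(d::BatchWrapperSketch) = size(d.μ)

# CPU method, selected because the parameters are plain `Array`s.
# A GPU method would dispatch on `CuArray` instead (CUDA.jl is not loaded here).
function Base.rand(rng::AbstractRNG, d::BatchWrapperSketch{NormalFamily,<:Array}, n::Integer)
    # Draw standard normals of size (batch_shape..., n), then shift and scale.
    z = randn(rng, eltype(d.μ), batch_shape(d)..., n)
    return d.μ .+ d.σ .* z
end

d = BatchWrapperSketch(NormalFamily(), zeros(3), ones(3))
x = rand(Random.default_rng(), d, 5)
println(size(x))  # batch shape (3,) followed by the sample dimension 5
```

Because the method signature mentions both the family and the array type, adding another family or another device means adding a method rather than branching inside `rand`.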
Some further considerations
- Consider a distribution `my_dist` that requires two parameters for construction. If the first parameter is a scalar and the second is a vector, calls such as `my_dist(rand(2), rand(2, 2))` and `my_dist(rand(2, 1), rand(2, 2))` should both be valid and produce identical results. As of now, our implementation only supports cases where all parameters have the same number of dimensions.
- `bijectors`, and `transformed distributions`
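One possible way to relax the same-number-of-dimensions restriction mentioned for `my_dist` is to pad each parameter array with trailing singleton dimensions before broadcasting. The sketch below assumes that the batch dimension comes first. `align_ndims` is a hypothetical helper that is not part of this PR, and the example works on plain arrays rather than on `my_dist` itself.

```julia
# Hypothetical helper: reshape every parameter array so that all of them
# share the same number of dimensions, padding with trailing 1s.
function align_ndims(params::AbstractArray...)
    n = maximum(ndims, params)
    return map(p -> reshape(p, size(p)..., ntuple(_ -> 1, n - ndims(p))...), params)
end

a = rand(2)      # batch of 2 scalar parameters
b = rand(2, 2)   # batch of 2 vector parameters of length 2

a2, b2 = align_ndims(a, b)
println(size(a2))  # the 1-d array gains a trailing singleton dimension
println(size(b2))  # the 2-d array is unchanged
```

Julia's broadcasting already treats missing trailing dimensions as size 1, so `a .+ b` and `a2 .+ b2` give the same result. The explicit reshape only matters for code that dispatches on or checks `ndims`.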