40 changes: 40 additions & 0 deletions HISTORY.md
@@ -1,5 +1,45 @@
# DynamicPPL Changelog

## 0.38.0

**Breaking changes**

foo

**Other changes**

### Thread-safe execution

This release removes `ThreadSafeVarInfo`, a construct used to ensure thread-safe accumulation of log-likelihood terms when using `Threads.@threads`.
However, `Threads.@threads` is no longer the recommended way to perform multithreaded tasks: see e.g. [this Julia blog post](https://julialang.org/blog/2023/07/PSA-dont-use-threadid/).

In its place, a new macro, `@pobserve`, is introduced, which uses `Threads.@spawn` under the hood.
**From a user's point of view you simply need to replace `Threads.@threads` with `@pobserve`.**
For example, here the likelihood contributions for each element of `y` are calculated in parallel:

```julia
@model function f(y)
mu ~ Normal()
yplusones = @pobserve for i in eachindex(y)
y[i] ~ Normal(mu)
return y[i] + 1
end
end
```

Furthermore, the `@pobserve` block returns the final value inside the block, so it can also be used to parallelise arbitrary computation. In the model above, `yplusones` will be a vector of the same length as `y`, in which each element is `y[i] + 1`.
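
As a minimal usage sketch (assuming the model `f` above, and that calling a `Model` object evaluates it and returns the model's return value; the expected output is an assumption):

```julia
using DynamicPPL, Distributions

y = [0.5, 1.5, 2.5]
model = f(y)

# Evaluating the model samples `mu` from its prior, observes each `y[i]`, and
# returns the value of the last expression in the model body, i.e. `yplusones`.
yplusones = model()  # expected: [1.5, 2.5, 3.5]
```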

Please note that this only works for **likelihood terms**, i.e., observed variables (hence the macro name).
It is a long-term goal to be able to parallelise other parts of model execution such as the sampling of new variables, but this is not presently possible.

`@pobserve` is also not currently compatible with Turing's particle samplers (because Libtask.jl does not work with `Threads.@spawn`).
This is, in fact, a good thing, because the previous behaviour of particle samplers with `Threads.@threads` was to silently give a wrong result.

### Other minor changes

The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl.
Their behaviour is otherwise identical.
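
As a minimal sketch of the new call site (the `(varname, value)` argument order and the expected result are assumptions):

```julia
using AbstractPPL

# Split a vector-valued variable `x` into one VarName per element,
# calling the function via AbstractPPL rather than DynamicPPL.
leaves = collect(AbstractPPL.varname_leaves(@varname(x), [1.0, 2.0]))
# expected to contain the VarNames x[1] and x[2]
```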

## 0.37.3

Prevents inlining of `DynamicPPL.istrans` with Enzyme, which allows Enzyme to differentiate models where `VarName`s have the same symbol but different types.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.3"
version = "0.38.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -46,7 +46,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.13"
AbstractPPL = "0.13.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
2 changes: 1 addition & 1 deletion benchmarks/Project.toml
@@ -22,7 +22,7 @@ DynamicPPL = {path = "../"}
ADTypes = "1.14.0"
BenchmarkTools = "1.6.0"
Distributions = "0.25.117"
DynamicPPL = "0.37"
DynamicPPL = "0.38"
ForwardDiff = "0.10.38, 1"
LogDensityProblems = "2.1.2"
Mooncake = "0.4"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -18,7 +18,7 @@ Accessors = "0.1"
Distributions = "0.25"
Documenter = "1"
DocumenterMermaid = "0.1, 0.2"
DynamicPPL = "0.37"
DynamicPPL = "0.38"
FillArrays = "0.13, 1"
ForwardDiff = "0.10, 1"
JET = "0.9, 0.10"
41 changes: 33 additions & 8 deletions docs/src/api.md
@@ -160,6 +160,12 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
@addlogprob!
```

If you want to perform observations in parallel (using Julia threads), you can use the following macro.

```@docs
@pobserve
```

Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.

```@docs
@@ -435,8 +441,6 @@ DynamicPPL.maybe_invlink_before_eval!!
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.varname_leaves
DynamicPPL.varname_and_value_leaves
```

### Evaluation Contexts
@@ -449,11 +453,6 @@ AbstractPPL.evaluate!!

This method mutates the `varinfo` used for execution.
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:

```@docs
DynamicPPL.evaluate_and_sample!!
```

The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
Contexts are subtypes of `AbstractPPL.AbstractContext`.
@@ -463,6 +462,32 @@ SamplingContext
DefaultContext
PrefixContext
ConditionContext
InitContext
```

### VarInfo initialisation

The function `init!!` is used to initialise, or overwrite, values in a `VarInfo`.
It is essentially a thin wrapper around `evaluate!!` with an `InitContext`.

```@docs
DynamicPPL.init!!
```

To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.
There are three concrete strategies provided in DynamicPPL:

```@docs
InitFromPrior
InitFromUniform
InitFromParams
```
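
As a rough end-to-end sketch (the `demo` model is hypothetical; the call patterns are inferred from the docstrings above and from how `init!!` is used elsewhere in this PR):

```julia
using DynamicPPL, Distributions, Random

@model function demo()
    x ~ Normal()
    return x
end

model = demo()
vi = VarInfo(model)

# Overwrite every value in `vi` with a fresh draw from the prior.
_, vi = DynamicPPL.init!!(Random.default_rng(), model, vi, InitFromPrior())

# Supply `x` explicitly; anything not listed falls back to the prior.
_, vi = DynamicPPL.init!!(
    model, vi, InitFromParams(Dict(@varname(x) => 0.5), InitFromPrior())
)
```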

If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.

```@docs
DynamicPPL.AbstractInitStrategy
DynamicPPL.init
```
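
For illustration, here is a hypothetical strategy that initialises every variable to the mean of its prior. The exact signature of `init` is not shown in this diff; the argument order assumed below (rng, varname, distribution, strategy) should be checked against the docstring above.

```julia
using DynamicPPL, Distributions, Random

struct InitFromMean <: DynamicPPL.AbstractInitStrategy end

# Assumed signature: called once per variable with its VarName and distribution.
function DynamicPPL.init(
    rng::Random.AbstractRNG, vn::VarName, dist::Distribution, ::InitFromMean
)
    return mean(dist)
end
```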

### Samplers
@@ -486,7 +511,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported functions.
```@docs
DynamicPPL.initialstep
DynamicPPL.loadstate
DynamicPPL.initialsampler
DynamicPPL.init_strategy
```

Finally, the varinfo type that a [`Sampler`](@ref) should use for a given [`Model`](@ref) is specified by [`DynamicPPL.default_varinfo`](@ref), which can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
43 changes: 25 additions & 18 deletions ext/DynamicPPLJETExt.jl
@@ -6,7 +6,6 @@ using JET: JET
function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
@@ -21,32 +20,40 @@ end
function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model; only_ddpl::Bool=true
)
# Use SamplingContext to test type stability.
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(model.context)
)

# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(sampling_model)
# Generate a typed varinfo to test model type stability with
varinfo = DynamicPPL.typed_varinfo(model)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
sampling_model, varinfo; only_ddpl
# Check type stability of evaluation (i.e. DefaultContext)
model = DynamicPPL.contextualize(
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext())
)
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
model, varinfo; only_ddpl
)
if !eval_issuccess
@debug "Evaluation with typed varinfo failed with the following issues:"
@debug eval_result
end

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
# Check type stability of initialisation (i.e. InitContext)
model = DynamicPPL.contextualize(
model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
)
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
model, varinfo; only_ddpl
)
if !init_issuccess
@debug "Initialisation with typed varinfo failed with the following issues:"
@debug init_result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
# If neither of them failed, we can return the typed varinfo as it's type stable.
return if (eval_issuccess && init_issuccess)
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(sampling_model)
DynamicPPL.untyped_varinfo(model)
end
end

53 changes: 34 additions & 19 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -1,12 +1,7 @@
module DynamicPPLMCMCChainsExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
end
using DynamicPPL: DynamicPPL, AbstractPPL
using MCMCChains: MCMCChains

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
@@ -28,7 +23,7 @@ end

function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using `VarName`s.")
error("This `Chains` object does not support indexing using `VarName`s.")
end

function DynamicPPL.getindex_varname(
@@ -42,6 +37,17 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

function chain_sample_to_varname_dict(
c::MCMCChains.Chains{Tval}, sample_idx, chain_idx
) where {Tval}
_check_varname_indexing(c)
d = Dict{DynamicPPL.VarName,Tval}()
for vn in DynamicPPL.varnames(c)
d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx)
end
return d
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

@@ -114,14 +120,20 @@ function DynamicPPL.predict(

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo))

# Extract values from the chain
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`
_, varinfo = DynamicPPL.init!!(
rng,
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
vcat,
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
@@ -248,13 +260,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702.
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to the `model`.
model(deepcopy(varinfo))
# Extract values from the chain
values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx)
# Resample any variables that are not present in `values_dict`, and
# return the model's retval.
retval, _ = DynamicPPL.init!!(
model,
varinfo,
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
)
retval
end
end

12 changes: 10 additions & 2 deletions src/DynamicPPL.jl
@@ -108,6 +108,12 @@ export AbstractVarInfo,
ConditionContext,
assume,
tilde_assume,
# Initialisation
InitContext,
AbstractInitStrategy,
InitFromPrior,
InitFromUniform,
InitFromParams,
# Pseudo distributions
NamedDist,
NoDist,
Expand All @@ -127,6 +133,7 @@ export AbstractVarInfo,
to_submodel,
# Convenience macros
@addlogprob!,
@pobserve,
value_iterator_from_chain,
check_model,
check_model_and_trace,
@@ -169,21 +176,22 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
# Necessary forward declarations
include("utils.jl")
include("chains.jl")
include("contexts.jl")
include("contexts/init.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
include("default_accumulators.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("pobserve_macro.jl")
include("pointwise_logdensities.jl")
include("transforming.jl")
include("logdensityfunction.jl")
2 changes: 1 addition & 1 deletion src/context_implementations.jl
@@ -135,7 +135,7 @@ function assume(
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
vi::VarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.