Skip to content

Commit

Permalink
Merge branch 'master' into enhance_wrapped_distr
Browse files Browse the repository at this point in the history
  • Loading branch information
ParadaCarleton authored Dec 19, 2022
2 parents e9291c7 + 8c8cfc6 commit 078731f
Show file tree
Hide file tree
Showing 24 changed files with 1,497 additions and 477 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.20.2"
version = "0.21.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -30,6 +30,6 @@ Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
MacroTools = "0.5.6"
OrderedCollections = "1"
Setfield = "0.7.1, 0.8"
Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"
4 changes: 3 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Distributions = "0.25"
Documenter = "0.27"
Setfield = "0.7.1, 0.8"
FillArrays = "0.13"
Setfield = "0.7.1, 0.8, 1"
StableRNGs = "1"
10 changes: 8 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@ using DynamicPPL
using DynamicPPL: AbstractPPL

# Doctest setup
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
DocMeta.setdocmeta!(
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
)

makedocs(;
sitename="DynamicPPL",
format=Documenter.HTML(),
modules=[DynamicPPL],
pages=["Home" => "index.md", "API" => "api.md"],
pages=[
"Home" => "index.md",
"API" => "api.md",
"Tutorials" => ["tutorials/prob-interface.md"],
],
strict=true,
checkdocs=:exports,
)
Expand Down
35 changes: 33 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,23 +156,56 @@ AbstractVarInfo

### Common API

#### Accumulation of log-probabilities

```@docs
getlogp
setlogp!!
acclogp!!
resetlogp!!
```

#### Variables and their realizations

```@docs
keys
getindex
DynamicPPL.getindex_raw
push!!
empty!!
isempty
```

```@docs
values_as
```

#### Transformations

```@docs
DynamicPPL.AbstractTransformation
DynamicPPL.NoTransformation
DynamicPPL.DynamicTransformation
DynamicPPL.StaticTransformation
```

```@docs
DynamicPPL.istrans
DynamicPPL.settrans!!
DynamicPPL.transformation
DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```

#### Utils

```@docs
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
```

#### `SimpleVarInfo`

```@docs
Expand All @@ -191,10 +224,8 @@ TypedVarInfo
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
tonamedtuple
link!
invlink!
istrans
```

```@docs
Expand Down
File renamed without changes
98 changes: 98 additions & 0 deletions docs/src/tutorials/prob-interface.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# The Probability Interface

The easiest way to manipulate and query DynamicPPL models is via the DynamicPPL probability
interface.

Let's use a simple model of normally-distributed data as an example.
```@example probinterface
using DynamicPPL
using Distributions
using FillArrays
using LinearAlgebra
using Random
Random.seed!(1776) # Set seed for reproducibility
@model function gdemo(n)
μ ~ Normal(0, 1)
x ~ MvNormal(Fill(μ, n), I)
return nothing
end
nothing # hide
```

We generate some data using `μ = 0` and `σ = 1`:

```@example probinterface
dataset = randn(100)
nothing # hide
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations, conditioning and deconditioning (also known as marginalization).
Conditioning takes a variable and fixes its value as known.
We do this by passing a model and a named tuple of conditioned variables to `|`:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset, μ=0, σ=1)
nothing # hide
```

This operation can be reversed by applying `decondition`:
```@example probinterface
decondition(model)
nothing # hide
```

We can also decondition only some of the variables:
```@example probinterface
decondition(model, :μ)
nothing # hide
```

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density for an event.
This probability might be a prior, a likelihood, or a posterior (joint) density.
DynamicPPL provides convenient functions for this.
For example, if we wanted to calculate the probability of a draw from the prior:
```@example probinterface
model = gdemo(length(dataset)) | (x=dataset,)
x1 = rand(model)
logjoint(model, x1)
```

For convenience, we provide the functions `loglikelihood` and `logjoint` to calculate probabilities for a named tuple, given a model:
```@example probinterface
@assert logjoint(model, x1) ≈ loglikelihood(model, x1) + logprior(model, x1)
```

## Example: Cross-validation

To give an example of the probability interface in use, we can use it to estimate the performance of our model using cross-validation. In cross-validation, we split the dataset into several equal parts. Then, we choose one of these sets to serve as the validation set. Here, we measure fit using the cross entropy (Bayes loss).¹
``` @example probinterface
function cross_val(model, dataset)
training_loss = zero(logjoint(model, rand(model)))
# Partition our dataset into 5 folds with 20 observations:
test_folds = collect(Iterators.partition(dataset, 20))
train_folds = setdiff.((dataset,), test_folds)
for (train, test) in zip(train_folds, test_folds)
# First, we train the model on the training set.
# For normally-distributed data, the posterior can be solved in closed form:
posterior = Normal(mean(train), 1)
# Sample from the posterior
samples = NamedTuple{(:μ,)}.(rand(posterior, 1000))
# Test
testing_model = gdemo(length(test)) | (x = test,)
training_loss += sum(samples) do sample
logjoint(testing_model, sample)
end
end
return training_loss
end
cross_val(model, dataset)
```

¹See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.
15 changes: 12 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ export AbstractVarInfo,
setorder!,
istrans,
link!,
link!!,
invlink!,
invlink!!,
tonamedtuple,
values_as,
# VarName (reexport from AbstractPPL)
Expand Down Expand Up @@ -126,27 +128,33 @@ export loglikelihood
# Used here and overloaded in Turing
function getspace end

# Necessary forward declarations
"""
AbstractVarInfo
Abstract supertype for data structures that capture random variables when executing a
probabilistic model and accumulate log densities such as the log likelihood or the
log joint probability of the model.
See also: [`VarInfo`](@ref)
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
"""
abstract type AbstractVarInfo <: AbstractModelTrace end

const LEGACY_WARNING = """
!!! warning
This method is considered legacy, and is likely to be deprecated in the future.
"""

# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
Expand All @@ -155,5 +163,6 @@ include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")

end # module
Loading

0 comments on commit 078731f

Please sign in to comment.