
Add more extensive usage docs #63

Merged
merged 14 commits on May 4, 2022
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Pathfinder"
uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
authors = ["Seth Axen <seth.axen@gmail.com> and contributors"]
version = "0.4.1"
version = "0.4.2"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# Pathfinder
# Pathfinder.jl: Parallel quasi-Newton variational inference

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://sethaxen.github.io/Pathfinder.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://sethaxen.github.io/Pathfinder.jl/dev)
18 changes: 16 additions & 2 deletions docs/Project.toml
@@ -1,11 +1,25 @@
[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AdvancedHMC = "0.3"
Documenter = "0.27"
DynamicHMC = "3"
ForwardDiff = "0.10"
LogDensityProblems = "0.11"
Pathfinder = "0.4"
Plots = "1"
StatsFuns = "0.9"
StatsPlots = "0.14"
TransformVariables = "0.6"
Turing = "0.21"
11 changes: 9 additions & 2 deletions docs/make.jl
@@ -15,8 +15,15 @@ makedocs(;
),
pages=[
"Home" => "index.md",
"Single-path Pathfinder" => "pathfinder.md",
"Multi-path Pathfinder" => "multipathfinder.md",
"Library" => [
"Public" => "lib/public.md",
"Internals" => "lib/internals.md",
],
"Examples" => [
"Quickstart" => "examples/quickstart.md",
"Initializing HMC" => "examples/initializing-hmc.md",
"Turing usage" => "examples/turing.md",
]
],
)

239 changes: 239 additions & 0 deletions docs/src/examples/initializing-hmc.md
@@ -0,0 +1,239 @@
# Initializing HMC with Pathfinder

## The MCMC warm-up phase

When using MCMC to draw samples from some target distribution, sampling typically begins with a lengthy warm-up consisting of two phases:
1. converge to the _typical set_ (the region of the target distribution where the bulk of the probability mass is located)
2. adapt any tunable parameters of the MCMC sampler (optional)

While (1) often happens fairly quickly, (2) usually requires a lengthy exploration of the typical set to iteratively adapt parameters suitable for further exploration.
An example is the widely used windowed adaptation scheme of Hamiltonian Monte Carlo (HMC) in Stan, where a step size and positive definite metric (aka mass matrix) are adapted.[^1]
For posteriors with complex geometry, the adaptation phase can require many evaluations of the gradient of the log density function of the target distribution.

Pathfinder can be used to initialize MCMC, and in particular HMC, in 3 ways:
1. Initialize MCMC from one of Pathfinder's draws (replace phase 1 of the warm-up).
2. Initialize an HMC metric adaptation from the inverse of the covariance of the multivariate normal approximation (replace the first window of phase 2 of the warm-up).
3. Use the inverse of the covariance as the metric without further adaptation (replace phase 2 of the warm-up).

This tutorial demonstrates all three approaches with [DynamicHMC.jl](https://tamaspapp.eu/DynamicHMC.jl/stable/) and [AdvancedHMC.jl](https://github.com/TuringLang/AdvancedHMC.jl).
Both of these packages have standalone implementations of adaptive HMC (aka NUTS) and can be used independently of any probabilistic programming language (PPL).
Both the [Turing](https://turing.ml/stable/) and [Soss](https://github.com/cscherrer/Soss.jl) PPLs have some DynamicHMC integration, while Turing also integrates with AdvancedHMC.

For demonstration purposes, we'll use the following dummy data:

```@example 1
using LinearAlgebra, Pathfinder, Random, StatsFuns, StatsPlots

Random.seed!(91)

x = 0:0.01:1
y = @. sin(10x) + randn() * 0.2 + x

scatter(x, y; xlabel="x", ylabel="y", legend=false, msw=0, ms=2)
```

We'll fit this using a simple polynomial regression model:

```math
\begin{aligned}
\sigma &\sim \text{Half-Normal}(\mu=0, \sigma=1)\\
\alpha, \beta_j &\sim \mathrm{Normal}(\mu=0, \sigma=1)\\
\hat{y}_i &= \alpha + \sum_{j=1}^J x_i^j \beta_j\\
y_i &\sim \mathrm{Normal}(\mu=\hat{y}_i, \sigma=\sigma)
\end{aligned}
```

We just need to implement the log-density function of the resulting posterior.

```@example 1
struct RegressionProblem{X,Z,Y}
x::X
J::Int
z::Z
y::Y
end
function RegressionProblem(x, J, y)
z = x .* (1:J)'
return RegressionProblem(x, J, z, y)
end

function (prob::RegressionProblem)(θ)
σ = θ.σ
α = θ.α
β = θ.β
z = prob.z
y = prob.y
lp = normlogpdf(σ) + logtwo
lp += normlogpdf(α)
lp += sum(normlogpdf, β)
y_hat = muladd(z, β, α)
lp += sum(eachindex(y_hat, y)) do i
return normlogpdf(y_hat[i], σ, y[i])
end
return lp
end

J = 3
dim = J + 2
model = RegressionProblem(x, J, y)
ndraws = 1_000;
nothing # hide
```

## DynamicHMC.jl

To use DynamicHMC, we first need to transform our model to an unconstrained space using [TransformVariables.jl](https://tamaspapp.eu/TransformVariables.jl/stable/) and wrap it in a type that implements the [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface:

```@example 1
using DynamicHMC, LogDensityProblems, TransformVariables

transform = as((σ=asℝ₊, α=asℝ, β=as(Array, J)))
P = TransformedLogDensity(transform, model)
∇P = ADgradient(:ForwardDiff, P)
```

Pathfinder, on the other hand, expects a log-density function:

```@example 1
logp(x) = LogDensityProblems.logdensity(P, x)
∇logp(x) = LogDensityProblems.logdensity_and_gradient(∇P, x)[2]
result_pf = pathfinder(logp, ∇logp; dim)
```

```@example 1
init_params = result_pf.draws[:, 1]
```

```@example 1
inv_metric = result_pf.fit_distribution.Σ
```

### Initializing from Pathfinder's draws

Here we just need to pass one of the draws as the initial point `q`:

```@example 1
result_dhmc1 = mcmc_with_warmup(
Random.GLOBAL_RNG,
∇P,
ndraws;
initialization=(; q=init_params),
reporter=NoProgressReport(),
)
```

### Initializing metric adaptation from Pathfinder's estimate

To start with Pathfinder's inverse metric estimate, we just need to initialize a `GaussianKineticEnergy` object with it as input:

```@example 1
result_dhmc2 = mcmc_with_warmup(
Random.GLOBAL_RNG,
∇P,
ndraws;
initialization=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
warmup_stages=default_warmup_stages(; M=Symmetric),
reporter=NoProgressReport(),
)
```

We also specified that DynamicHMC should tune a dense `Symmetric` matrix.
However, we could also have requested a `Diagonal` metric.
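For example, a minimal sketch of the diagonal variant might look like this (the `result_dhmc2_diag` name is hypothetical; it reuses `∇P`, `ndraws`, `init_params`, and `inv_metric` from above and keeps only the diagonal of Pathfinder's estimate):

```julia
# Sketch: adapt only a diagonal metric, initialized from the diagonal of
# Pathfinder's inverse metric estimate.
result_dhmc2_diag = mcmc_with_warmup(
    Random.GLOBAL_RNG,
    ∇P,
    ndraws;
    initialization=(; q=init_params, κ=GaussianKineticEnergy(Diagonal(diag(Matrix(inv_metric))))),
    warmup_stages=default_warmup_stages(; M=Diagonal),
    reporter=NoProgressReport(),
)
```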

### Use Pathfinder's metric estimate for sampling

To turn off metric adaptation entirely and use Pathfinder's estimate, we just set the number and size of the metric adaptation windows to 0.

```@example 1
result_dhmc3 = mcmc_with_warmup(
Random.GLOBAL_RNG,
∇P,
ndraws;
initialization=(; q=init_params, κ=GaussianKineticEnergy(inv_metric)),
warmup_stages=default_warmup_stages(; middle_steps=0, doubling_stages=0),
reporter=NoProgressReport(),
)
```

## AdvancedHMC.jl

Similar to Pathfinder, AdvancedHMC works with an unconstrained log density function and its gradient.
We'll just use the `logp` we already created above.

```@example 1
using AdvancedHMC, ForwardDiff

nadapts = 500;
nothing # hide
```

### Initializing from Pathfinder's draws

```@example 1
metric = DiagEuclideanMetric(dim)
hamiltonian = Hamiltonian(metric, logp, ForwardDiff)
ϵ = find_good_stepsize(hamiltonian, init_params)
integrator = Leapfrog(ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StepSizeAdaptor(0.8, integrator)
samples_ahmc1, stats_ahmc1 = sample(
hamiltonian,
proposal,
init_params,
ndraws + nadapts,
adaptor,
nadapts;
drop_warmup=true,
progress=false,
)
```

### Initializing metric adaptation from Pathfinder's estimate

We can't pass Pathfinder's inverse metric estimate to AdvancedHMC directly.
Instead, we first need to either extract its diagonal for a `DiagEuclideanMetric` or convert it to a dense matrix for a `DenseEuclideanMetric`:

```@example 1
metric = DenseEuclideanMetric(Matrix(inv_metric))
hamiltonian = Hamiltonian(metric, logp, ForwardDiff)
ϵ = find_good_stepsize(hamiltonian, init_params)
integrator = Leapfrog(ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StepSizeAdaptor(0.8, integrator)
samples_ahmc2, stats_ahmc2 = sample(
hamiltonian,
proposal,
init_params,
ndraws + nadapts,
adaptor,
nadapts;
drop_warmup=true,
progress=false,
)
```
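A corresponding sketch for the diagonal case (reusing `inv_metric`, `logp`, and `init_params`) would instead build a `DiagEuclideanMetric` from the diagonal:

```julia
# Sketch: initialize diagonal metric adaptation from the diagonal of
# Pathfinder's inverse metric estimate; the remaining sampling calls are as above.
metric = DiagEuclideanMetric(diag(Matrix(inv_metric)))
hamiltonian = Hamiltonian(metric, logp, ForwardDiff)
```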

### Use Pathfinder's metric estimate for sampling

To use Pathfinder's metric with no metric adaptation, we need to use Pathfinder's own `RankUpdateEuclideanMetric` type, which just wraps our inverse metric estimate for use with AdvancedHMC:

```@example 1
nadapts = 75
metric = Pathfinder.RankUpdateEuclideanMetric(inv_metric)
hamiltonian = Hamiltonian(metric, logp, ForwardDiff)
ϵ = find_good_stepsize(hamiltonian, init_params)
integrator = Leapfrog(ϵ)
proposal = NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StepSizeAdaptor(0.8, integrator)
samples_ahmc3, stats_ahmc3 = sample(
hamiltonian,
proposal,
init_params,
ndraws + nadapts,
adaptor,
nadapts;
drop_warmup=true,
progress=false,
)
```

[^1]: https://mc-stan.org/docs/reference-manual/hmc-algorithm-parameters.html
176 changes: 176 additions & 0 deletions docs/src/examples/quickstart.md
@@ -0,0 +1,176 @@
# Quickstart

This page introduces basic Pathfinder usage with examples.

## A 5-dimensional multivariate normal

For a simple example, we'll run Pathfinder on a multivariate normal distribution with
a dense covariance matrix.

```@example 1
using LinearAlgebra, Pathfinder, Printf, StatsPlots, Random
Random.seed!(42)
Σ = [
2.71 0.5 0.19 0.07 1.04
0.5 1.11 -0.08 -0.17 -0.08
0.19 -0.08 0.26 0.07 -0.7
0.07 -0.17 0.07 0.11 -0.21
1.04 -0.08 -0.7 -0.21 8.65
]
μ = [-0.55, 0.49, -0.76, 0.25, 0.94]
P = inv(Symmetric(Σ))
function logp_mvnormal(x)
z = x - μ
return -dot(z, P, z) / 2
end
nothing # hide
```

Now we run [`pathfinder`](@ref).

```@example 1
result = pathfinder(logp_mvnormal; dim=5, init_scale=4)
```

`result` is a [`PathfinderResult`](@ref).
See its docstring for a description of its fields.

`result.fit_distribution` is a multivariate normal approximation to our target distribution.
Its mean and covariance are quite close to our target distribution's.

```@example 1
result.fit_distribution.μ
```

```@example 1
result.fit_distribution.Σ
```
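As a rough check, we can measure how far the fit is from the true parameters (a small sketch; `μ` and `Σ` are the true values defined above):

```julia
# Both quantities should be small if the normal approximation is accurate.
norm(result.fit_distribution.μ - μ)
norm(result.fit_distribution.Σ - Σ) / norm(Σ)
```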

`result.draws` is a `Matrix` whose columns are the requested draws from `result.fit_distribution`:
```@example 1
result.draws
```

```@example 1
iterations = length(result.optim_trace) - 1
trace_points = result.optim_trace.points
trace_dists = result.fit_distributions
xrange = -5:0.1:5
yrange = -5:0.1:5
μ_marginal = μ[1:2]
P_marginal = inv(Σ[1:2,1:2])
logp_mvnormal_marginal(x) = -dot(x - μ_marginal, P_marginal, x - μ_marginal) / 2
anim = @animate for i in 1:iterations
contour(xrange, yrange, (x, y) -> logp_mvnormal_marginal([x, y]), label="")
trace = trace_points[1:(i + 1)]
dist = trace_dists[i + 1]
plot!(first.(trace), last.(trace); seriestype=:scatterpath, color=:black, msw=0, label="trace")
covellipse!(dist.μ[1:2], dist.Σ[1:2, 1:2]; n_std=2.45, alpha=0.7, color=1, linecolor=1, label="MvNormal 95% ellipsoid")
title = "Iteration $i"
plot!(; xlims=extrema(xrange), ylims=extrema(yrange), xlabel="x₁", ylabel="x₂", legend=:bottomright, title)
end
gif(anim, fps=5)
```

## A 100-dimensional funnel

Especially for complicated target distributions, it's more useful to run multi-path Pathfinder.
One difficult distribution to sample is Neal's funnel:

```math
\begin{aligned}
\tau &\sim \mathrm{Normal}(\mu=0, \sigma=3)\\
\beta_i &\sim \mathrm{Normal}(\mu=0, \sigma=e^{\tau/2})
\end{aligned}
```

Such funnel geometries appear in other models (e.g. many hierarchical models) and typically frustrate MCMC sampling.
Multi-path Pathfinder can't sample the funnel well, but it can quickly give us draws that can help us diagnose that we have a funnel.

In this example, we draw from a 100-dimensional funnel and visualize 2 dimensions.

```@example 1
Random.seed!(68)
function logp_funnel(x)
n = length(x)
τ = x[1]
β = view(x, 2:n)
return ((τ / 3)^2 + (n - 1) * τ + sum(b -> abs2(b * exp(-τ / 2)), β)) / -2
end
dim = 100
init_scale = 10
nothing # hide
```

First, let's fit this posterior with single-path Pathfinder.

```@example 1
result_single = pathfinder(logp_funnel; dim, init_scale)
```

The L-BFGS optimizer constructs an approximation to the inverse Hessian of the negative log density using the limited history of previous points and gradients.
For each iteration, Pathfinder uses this estimate as an approximation to the covariance matrix of a multivariate normal that approximates the target distribution.
The distribution that maximizes the evidence lower bound (ELBO) is returned.
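We can also inspect this sequence directly; a brief sketch using the result fields shown elsewhere on this page:

```julia
# One fitted normal per L-BFGS iteration (plus the initial point), and the
# ELBO estimate computed for each candidate distribution.
length(result_single.fit_distributions)
result_single.elbo_estimates
```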

Let's visualize this sequence of multivariate normals for the first two dimensions.

```@example 1
iterations = min(length(result_single.optim_trace) - 1, 15)
trace_points = result_single.optim_trace.points
trace_dists = result_single.fit_distributions
τ_range = -15:0.01:5
β₁_range = -5:0.01:5
anim = @animate for i in 1:iterations
contour(β₁_range, τ_range, (β, τ) -> exp(logp_funnel([τ, β])), label="")
trace = trace_points[1:(i + 1)]
dist = trace_dists[i + 1]
plot!(map(x -> x[2], trace), first.(trace); seriestype=:scatterpath, color=:black, msw=0, label="trace")
covellipse!(dist.μ[[2, 1]], dist.Σ[[2, 1], [2, 1]]; n_std=2.45, alpha=0.7, color=1, linecolor=1, label="MvNormal 95% ellipsoid")
est = result_single.elbo_estimates[i]
title = "Iteration $i ELBO estimate: " * @sprintf("%.1f", est.value)
plot!(; xlims=extrema(β₁_range), ylims=extrema(τ_range), xlabel="β₁", ylabel="τ", legend=:bottomright, title)
end
gif(anim, fps=2)
```

For this challenging posterior, we can see that most of the approximations are not great, because this distribution is far from normal.
Also, this distribution has a pole instead of a mode, so there is no MAP estimate, and no Laplace distribution exists.
As optimization proceeds, the approximation goes from very bad to less bad and finally extremely bad.
The ELBO-maximizing distribution is at the neck of the funnel, which would be a good location to initialize MCMC.

It is always a good idea to run [`multipathfinder`](@ref) directly, which runs single-path Pathfinder multiple times.

```@example 1
ndraws = 1_000
result = multipathfinder(logp_funnel, ndraws; nruns=20, dim, init_scale)
```

`result` is a [`MultiPathfinderResult`](@ref).
See its docstring for a description of its fields.

`result.fit_distribution` is a uniformly-weighted `Distributions.MixtureModel`.
Each component is the result of a single Pathfinder run.
It's possible that some runs fit the target distribution much better than others, so instead of just drawing samples from `result.fit_distribution`, `multipathfinder` draws many samples from each component and then uses Pareto-smoothed importance resampling from these draws to better target `logp_funnel`.

The Pareto shape diagnostic also informs us on the quality of these draws.
Here [PSIS.jl](https://psis.julia.arviz.org/stable/), which smooths the importance weights, warns us that the importance weights are unsuitable for computing estimates, so we should definitely run MCMC to get better draws.
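To read the diagnostic programmatically, a sketch (assuming the PSIS result is stored in a `psis_result` field; see the [`MultiPathfinderResult`](@ref) docstring):

```julia
# Pareto shape estimate from importance resampling; values much greater than 0.7
# indicate the draws should not be trusted for computing estimates.
result.psis_result.pareto_shape
```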

Here we can see that the bulk of Pathfinder's draws come from the neck of the funnel, where the fit from the single path we examined was located.

```@example 1
τ_approx = result.draws[1, :]
β₁_approx = result.draws[2, :]
contour(β₁_range, τ_range, (β, τ) -> exp(logp_funnel([τ, β])))
scatter!(β₁_approx, τ_approx; msw=0, ms=2, alpha=0.5, color=1)
plot!(xlims=extrema(β₁_range), ylims=extrema(τ_range), xlabel="β₁", ylabel="τ", legend=false)
```
110 changes: 110 additions & 0 deletions docs/src/examples/turing.md
@@ -0,0 +1,110 @@
# Running Pathfinder on Turing.jl models

This tutorial demonstrates how [Turing](https://turing.ml/stable/) can be used with Pathfinder.

We'll demonstrate with a regression example.

```@example 1
using AdvancedHMC, LinearAlgebra, Pathfinder, Random, Turing
Random.seed!(39)
@model function regress(x, y)
α ~ Normal()
β ~ Normal()
σ ~ truncated(Normal(); lower=0)
y .~ Normal.(α .+ β .* x, σ)
end
x = 0:0.1:10
y = @. 2x + 1.5 + randn() * 0.2
nothing # hide
```

```@example 1
model = regress(collect(x), y)
```

The first way we can use Turing with Pathfinder is via its mode estimation functionality.
We can use `Turing.optim_function` to generate a `SciMLBase.OptimizationFunction`, which [`pathfinder`](@ref) and [`multipathfinder`](@ref) can take as inputs.

```@example 1
fun = optim_function(model, MAP(); constrained=false)
```

```@example 1
dim = length(fun.init())
pathfinder(fun.func; dim)
```

```@example 1
multipathfinder(fun.func, 1_000; dim, nruns=8)
```

However, for convenience, `pathfinder` and `multipathfinder` can take Turing models as inputs and produce `MCMCChains.Chains` objects as outputs.

```@example 1
result_single = pathfinder(model; ndraws=1_000)
```

```@example 1
result_multi = multipathfinder(model, 1_000; nruns=8)
```

Here, the Pareto shape diagnostic indicates that it is likely safe to use these draws to compute posterior estimates.

When passed a `Model`, Pathfinder also gives access to the posterior draws in a familiar `MCMCChains.Chains` object.

```@example 1
result_multi.draws_transformed
```

We can also use these posterior draws to initialize MCMC sampling.

```@example 1
init_params = collect.(eachrow(result_multi.draws_transformed.value[1:4, :, 1]))
```

```@example 1
chns = sample(model, Turing.NUTS(), MCMCThreads(), 1_000, 4; init_params, progress=false)
```

To use Pathfinder's estimate of the metric and skip warm-up, at the moment one needs to use AdvancedHMC directly.

```@example 1
ℓπ(x) = -fun.func.f(x, nothing)
function ∂ℓπ∂θ(x)
g = similar(x)
fun.func.grad(g, x, nothing)
rmul!(g, -1)
return ℓπ(x), g
end
ndraws = 1_000
nadapts = 50
inv_metric = result_multi.pathfinder_results[1].fit_distribution.Σ
metric = Pathfinder.RankUpdateEuclideanMetric(inv_metric)
hamiltonian = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
ϵ = find_good_stepsize(hamiltonian, init_params[1])
integrator = Leapfrog(ϵ)
proposal = AdvancedHMC.NUTS{MultinomialTS,GeneralisedNoUTurn}(integrator)
adaptor = StepSizeAdaptor(0.8, integrator)
samples, stats = sample(
hamiltonian,
proposal,
init_params[1],
ndraws + nadapts,
adaptor,
nadapts;
drop_warmup=true,
progress=false,
)
```

Now we pack the samples into an `MCMCChains.Chains`:

```@example 1
samples_transformed = reduce(vcat, fun.transform.(samples)')
varnames = Pathfinder.flattened_varnames_list(model)
chns = MCMCChains.Chains(samples_transformed, varnames)
```

See [Initializing HMC with Pathfinder](@ref) for further examples.
52 changes: 36 additions & 16 deletions docs/src/index.md
@@ -2,28 +2,48 @@
CurrentModule = Pathfinder
```

# Pathfinder
# Pathfinder.jl: Parallel quasi-Newton variational inference

Pathfinder[^Zhang2021] is a variational method for initializing Markov chain Monte Carlo (MCMC) methods.
This package implements Pathfinder, [^Zhang2021] a variational method for initializing Markov chain Monte Carlo (MCMC) methods.

## Introduction
## Single-path Pathfinder

When using MCMC to draw samples from some target distribution, there is often a lengthy warm-up phase with 2 goals:
1. converge to the _typical set_ (the region of the target distribution where the bulk of the probability mass is located)
2. adapt any tunable parameters of the MCMC sampler (optional)
Single-path Pathfinder ([`pathfinder`](@ref)) attempts to return draws in or near the typical set, usually with many fewer gradient evaluations.
Pathfinder uses the [limited-memory BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) (L-BFGS) optimizer to construct a _maximum a posteriori_ (MAP) estimate of a target posterior distribution ``p``.
It then uses the trace of the optimization to construct a sequence of multivariate normal approximations to the target distribution, returning the approximation that maximizes the [evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) (ELBO) -- equivalently, minimizes the [Kullback-Leibler](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) divergence from the target distribution -- as well as draws from the distribution.

Typically (2) requires a lengthy exploration of the typical set.
An example is the widely used windowed adaptation scheme of Hamiltonian Monte Carlo (HMC), where a step size and mass matrix are adapted.
For posteriors with complex geometry, the adaptation phase can require many evaluations of the gradient of the log density function of the target distribution.
## Multi-path Pathfinder

Pathfinder attempts to return samples in or near the typical set with many fewer gradient evaluations.
Pathfinder uses [L-BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) to construct a _maximum a posteriori_ (MAP) estimate of a target distribution ``p``.
It then uses the trace of the optimization to construct a sequence of multivariate normal approximations to the target distribution, returning the approximation that maximizes the evidence lower bound (ELBO), as well as draws from the distribution.
The covariance of the multivariate normal approximation can be used to instantiate the mass matrix adaptation in HMC.

Its extension, multi-path Pathfinder, runs Pathfinder multiple times.
Multi-path Pathfinder ([`multipathfinder`](@ref)) consists of running Pathfinder multiple times.
It returns a uniformly-weighted mixture model of the multivariate normal approximations of the individual runs.
It also uses importance resampling to return samples that better approximate the target distribution.
It also uses importance resampling to return samples that better approximate the target distribution and assess the quality of the approximation.

## Uses

### Using the Pathfinder draws

!!! note "Folk theorem of statistical computing"
    When you have computational problems, often there’s a problem with your model.

Visualizing posterior draws is a common way to diagnose problems with a model.
However, problematic models often tend to be slow to warm-up.
Even if the draws returned by Pathfinder are only approximations to the posterior, they can sometimes still be used to diagnose basic issues such as highly correlated parameters, parameters with very different posterior variances, and multimodality.
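For example, a minimal sketch (the toy target and the use of StatsPlots' `corrplot` here are illustrative assumptions):

```julia
using Pathfinder, StatsPlots

# Toy target: a standard multivariate normal in 5 dimensions.
logp(x) = -sum(abs2, x) / 2
result = pathfinder(logp; dim=5)

# Pairwise scatter plots of the draws (stored as columns of result.draws) can reveal
# correlated parameters, very different scales, or multimodality.
corrplot(permutedims(result.draws))
```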

### Initializing MCMC

Pathfinder can be used to initialize MCMC.
This is especially useful when sampling with Hamiltonian Monte Carlo.
See [Initializing HMC with Pathfinder](@ref) for details.

## Integration with the Julia ecosystem

Pathfinder uses several packages for extended functionality:

- [GalacticOptim.jl](https://galacticoptim.sciml.ai/stable/): This allows the L-BFGS optimizer to be replaced with any of the many GalacticOptim-compatible optimizers and supports use of callbacks. Note that any changes made to Pathfinder using these features would be experimental.
- [Transducers.jl](https://juliafolds.github.io/Transducers.jl/stable/): parallelization support
- [Distributions.jl](https://juliastats.org/Distributions.jl/stable/)/[PDMats.jl](https://github.com/JuliaStats/PDMats.jl): fits can be used anywhere a `Distribution` can be used
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl): selecting the AD package used to differentiate the provided log-density function.
- [ProgressLogging.jl](https://julialogging.github.io/ProgressLogging.jl/stable/): In Pluto, Juno, and VSCode, nested progress bars are shown. In the REPL, use TerminalLoggers.jl to get progress bars.

[^Zhang2021]: Lu Zhang, Bob Carpenter, Andrew Gelman, Aki Vehtari (2021).
    Pathfinder: Parallel quasi-Newton variational inference.
    arXiv: [2108.03782](https://arxiv.org/abs/2108.03782) [stat.ML]
19 changes: 19 additions & 0 deletions docs/src/lib/internals.md
@@ -0,0 +1,19 @@
# Internals

Documentation for `Pathfinder.jl`'s internal functions.

See the [Public Documentation](@ref) section for documentation of the public interface.

## Index

```@index
Pages = ["internals.md"]
```

## Internal Interface

```@autodocs
Modules = [Pathfinder]
Public = false
Private = true
```
29 changes: 29 additions & 0 deletions docs/src/lib/public.md
@@ -0,0 +1,29 @@
# Public Documentation

Documentation for `Pathfinder.jl`'s public interface.

See the [Internals](@ref) section for documentation of internal functions.

## Index

```@index
Pages = ["public.md"]
```

## Public Interface

```@autodocs
Modules = [Pathfinder]
Pages = ["singlepath.jl"]
Order = [:function, :type]
Public = true
Private = false
```

```@autodocs
Modules = [Pathfinder]
Pages = ["multipath.jl"]
Order = [:function, :type]
Public = true
Private = false
```
88 changes: 0 additions & 88 deletions docs/src/multipathfinder.md

This file was deleted.

55 changes: 0 additions & 55 deletions docs/src/pathfinder.md

This file was deleted.

2 changes: 1 addition & 1 deletion src/inverse_hessian.jl
@@ -80,8 +80,8 @@ E &= I \\circ R\\\\
D &= \\begin{pmatrix}
0 & -R^{-1}\\\\
-R^{-\\mathrm{T}} & R^\\mathrm{-T} (E + Y^\\mathrm{T} H_0 Y ) R^\\mathrm{-1}\\\\
\\end{pmatrix}\\
\\end{pmatrix}
H &= H_0 + B D B^\\mathrm{T}
\\end{align}
```