Skip to content

Commit

Permalink
fix doc example
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 1, 2024
1 parent 8d74889 commit bceb510
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions docs/src/state_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Base.vec(state)
This function takes the state and returns a vector of the parameter values stored in the state.

```julia
state = StateType(state, logp)
state = StateType(state::StateType, logp)
```

This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.
Expand All @@ -42,7 +42,17 @@ x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})

We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data.

```julia
```@example gibbs_example
using AbstractMCMC: AbstractMCMC, LogDensityProblems # hide
using Distributions # hide
using Random # hide
using AbstractMCMC: AbstractMCMC # hide
using AbstractPPL: AbstractPPL # hide
using BangBang: constructorof # hide
using AbstractPPL: AbstractPPL
```

```@example gibbs_example
function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64})
# mu is the mean
# tau2 is the variance
Expand Down Expand Up @@ -71,7 +81,7 @@ end

To make using `LogDensityProblems` interface, we create a simple type for this model.

```julia
```@example gibbs_example
abstract type AbstractHierNormal end
struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal
Expand All @@ -89,15 +99,15 @@ end

where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and

```julia
```@example gibbs_example
function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple)
return ConditionedHierNormal(hn.data, conditioned_values)
end
```

then we can simply write down the `LogDensityProblems` interface for this model.

```julia
```@example gibbs_example
function LogDensityProblems.logdensity(
hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars},
params::AbstractVector,
Expand Down Expand Up @@ -132,7 +142,7 @@ Although the interface doesn't force the sampler to implement `Transition` and `

Here we define some bare minimum types to represent the transitions and states.

```julia
```@example gibbs_example
struct MHTransition{T}
params::Vector{T}
end
Expand All @@ -145,7 +155,7 @@ end

Next we define the `state` interface functions mentioned at the beginning of this section.

```julia
```@example gibbs_example
# Interface 1: LogDensityProblems.logdensity
# This function takes the logdensity function and the state (state is defined by the sampler package)
# and returns the logdensity. It allows for optional recomputation of the log probability.
Expand All @@ -168,15 +178,14 @@ Base.vec(state::MHState) = state.params
# Interface 3: constructorof and MHState(state::MHState, logp::Float64)
# This function allows the state to be updated with a new log probability.
BangBang.constructorof(state::MHState{T}) where {T} = MHState
function MHState(state::MHState, logp::Float64)
return MHState(state.params, logp)
end
```

It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state.

```julia
```@example gibbs_example
"""
RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
Expand Down Expand Up @@ -264,7 +273,7 @@ end

At last, we can proceed to implement a very simple Gibbs sampler.

```julia
```@example gibbs_example
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
"Maps variables to their samplers."
sampler_map::T
Expand Down Expand Up @@ -440,7 +449,7 @@ The Gibbs sampler operates in two main phases:
b. Update current variable:
- Recompute the log probability of the current state, as other variables may have changed:
- Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
- Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability.
- Update the state with `sub_state = constructorof(typeof(sub_state))(sub_state, logp)` to incorporate the new log probability.
- Perform a sampling step for the current conditional probability problem:
- Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
- Update the global trace:
Expand Down

0 comments on commit bceb510

Please sign in to comment.