From bceb510234edba4c55c4f3830c567f23bacdc879 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 1 Oct 2024 18:25:30 +0100 Subject: [PATCH] fix doc example --- docs/src/state_interface.md | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md index a7affe5..60a72af 100644 --- a/docs/src/state_interface.md +++ b/docs/src/state_interface.md @@ -17,7 +17,7 @@ Base.vec(state) This function takes the state and returns a vector of the parameter values stored in the state. ```julia -state = StateType(state, logp) +state = StateType(state::StateType, logp) ``` This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. @@ -42,7 +42,17 @@ x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2}) We can write the log joint probability function as follows, where for the sake of simplicity for the following steps, we will assume that the `mu` and `tau2` parameters are one-element vectors. And `x` is the data. -```julia +```@example gibbs_example +using AbstractMCMC: AbstractMCMC, LogDensityProblems # hide +using Distributions # hide +using Random # hide +using AbstractMCMC: AbstractMCMC # hide +using AbstractPPL: AbstractPPL # hide +using BangBang: constructorof # hide +using AbstractPPL: AbstractPPL +``` + +```@example gibbs_example function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) # mu is the mean # tau2 is the variance @@ -71,7 +81,7 @@ end To make using `LogDensityProblems` interface, we create a simple type for this model. -```julia +```@example gibbs_example abstract type AbstractHierNormal end struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal @@ -89,7 +99,7 @@ end where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and -```julia +```@example gibbs_example function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple) return ConditionedHierNormal(hn.data, conditioned_values) end @@ -97,7 +107,7 @@ end then we can simply write down the `LogDensityProblems` interface for this model. -```julia +```@example gibbs_example function LogDensityProblems.logdensity( hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars}, params::AbstractVector, @@ -132,7 +142,7 @@ Although the interface doesn't force the sampler to implement `Transition` and ` Here we define some bare minimum types to represent the transitions and states. -```julia +```@example gibbs_example struct MHTransition{T} params::Vector{T} end @@ -145,7 +155,7 @@ end Next we define the `state` interface functions mentioned at the beginning of this section. -```julia +```@example gibbs_example # Interface 1: LogDensityProblems.logdensity # This function takes the logdensity function and the state (state is defined by the sampler package) # and returns the logdensity. It allows for optional recomputation of the log probability. @@ -168,7 +178,6 @@ Base.vec(state::MHState) = state.params # Interface 3: constructorof and MHState(state::MHState, logp::Float64) # This function allows the state to be updated with a new log probability. -BangBang.constructorof(state::MHState{T}) where {T} = MHState function MHState(state::MHState, logp::Float64) return MHState(state.params, logp) end @@ -176,7 +185,7 @@ end It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. -```julia +```@example gibbs_example """ RandomWalkMH{T} <: AbstractMCMC.AbstractSampler @@ -264,7 +273,7 @@ end At last, we can proceed to implement a very simple Gibbs sampler. -```julia +```@example gibbs_example struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler "Maps variables to their samplers." sampler_map::T @@ -440,7 +449,7 @@ The Gibbs sampler operates in two main phases: b. Update current variable: - Recompute the log probability of the current state, as other variables may have changed: - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability. - - Update the state with `sub_state = sub_state(logp)` to incorporate the new log probability. + - Update the state with `sub_state = constructorof(typeof(sub_state))(sub_state, logp)` to incorporate the new log probability. - Perform a sampling step for the current conditional probability problem: - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state. - Update the global trace: