diff --git a/Project.toml b/Project.toml index 97c90709..a30156a4 100644 --- a/Project.toml +++ b/Project.toml @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.2.0" +version = "5.3.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -20,11 +21,13 @@ TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] +AbstractPPL = "0.8" BangBang = "0.3.19, 0.4" ConsoleProgressMonitor = "0.1" FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" +MCMCChains = "6" ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" @@ -32,10 +35,13 @@ Transducers = "0.4.30" julia = "1.6" [extras] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FillArrays", "IJulia", "Statistics", "Test"] +test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "MCMCChains", "Statistics", "Test"] diff --git a/design_notes/on_gibbs_implementation.md b/design_notes/on_gibbs_implementation.md new file mode 100644 index 00000000..78a42e60 --- /dev/null +++ b/design_notes/on_gibbs_implementation.md @@ -0,0 +1,59 @@ +# On `AbstractMCMC` Interface Supporting `Gibbs` + +This is written at Oct 1st, 2024. Version of packages described in this passage are: + +* `Turing.jl`: 0.34.1 + +In this passage, `Gibbs` refers to `Experimental.Gibbs`. + +## Current Implementation of `Gibbs` in `Turing` + +Here I describe the current implementation of `Gibbs` in `Turing` and the interface it requires from its sampler states. + +### Interface 1: `getparams` + +From the [definition of `GibbsState`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L244-L248), we can see that a `vi::DynamicPPL.AbstractVarInfo` field is used to keep track of the names and values of parameters and the log density. The `states` field collects the sampler-specific *state*s. + +(The *link*ing of *varinfo*s is omitted in this discussion.) +A local `VarInfo` is initially created with `DynamicPPL.subset(::VarInfo, ::Vector{<:VarName})` to make the conditioned model. After the Gibbs step, an updated `varinfo` is obtained by calling `Turing.Inference.varinfo` on the sampler state. + +For samplers and their states defined in `Turing` (including `DynamicHMC`, as `DynamicNUTSState` is defined by `Turing` in the package extension), we (à la `Turing.jl` package) assume that the *state*s all have a field called `vi`. Then `varinfo(_some_sampler_state_)` is simply `varinfo(state) = state.vi` (defined in [`src/mcmc/gibbs.jl`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/gibbs.jl#L97)). (`GibbsState` conforms to this assumption.) 
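+
+For concreteness, the assumed convention can be pictured as follows (a schematic only; `SomeSamplerState` is a made-up name, not actual `Turing` source):
+
+```julia
+# Schematic: a Turing-defined sampler state carries its VarInfo in a field named `vi`,
+# so the accessor used by Gibbs reduces to a field lookup.
+struct SomeSamplerState{V<:DynamicPPL.AbstractVarInfo}
+    vi::V
+end
+
+Turing.Inference.varinfo(state::SomeSamplerState) = state.vi
+```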
+
+For `ExternalSamplers`, we currently only support `AdvancedHMC` and `AdvancedMH`. The mechanism is as follows: at the end of the `step` call with an external sampler, [`transition_to_turing` and `state_to_turing` are called](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/abstractmcmc.jl#L147). These two functions then call `getparams` on the sampler state of the external sampler. The `getparams` methods for `AdvancedHMC.HMCState` and `AdvancedMH.Transition` (`AdvancedMH` uses `Transition` as its state) are defined in `abstractmcmc.jl`.
+
+Thus, the first interface emerges: `getparams`. As `getparams` is designed to be implemented by a sampler that works with the `LogDensityProblems` interface, it makes sense for `getparams` to return a vector of `Real`s. The log density problem should then be responsible for performing the transformation between its underlying representation and the vector of `Real`s.
+
+It's worth noting that:
+
+* `getparams` is not a function specific to `Gibbs`. It is required for the current support of external samplers.
+* There is another [`getparams`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/Inference.jl#L328-L351) in `Turing.jl` that takes *model* and *varinfo*, then returns a `NamedTuple`.
+
+### Interface 2: `recompute_logp!!`
+
+Consider a model with multiple groups of variables, say $\theta_1, \theta_2, \ldots, \theta_k$. At the beginning of the $t$-th Gibbs step, the model parameters in the `GibbsState` are typically updated and different from the $(t-1)$-th step. The `GibbsState` maintains $k$ sub-states, one for each variable group, denoted as $\text{state}_{t,1}, \text{state}_{t,2}, \ldots, \text{state}_{t,k}$.
+
+The parameter values in each sub-state, i.e., $\theta_{t,i}$ in $\text{state}_{t,i}$, are always in sync with the corresponding values in the `GibbsState`. At the end of the $t$-th Gibbs step, $\text{state}_{t,i}$ will store the log density of the $i$-th variable group conditioned on all other variable groups at their values from step $t$, denoted as $\log p(\theta_{t,i} \mid \theta_{t,-i})$. This log density is equal to the joint log density of the whole model evaluated at the current parameter values $(\theta_{t,1}, \ldots, \theta_{t,k})$ (up to an additive constant that does not depend on $\theta_{t,i}$).
+
+However, the log density stored in each sub-state is in general not equal to the log density needed for the next Gibbs step at $t+1$, i.e., $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$. This is because the values of the other variable groups $\theta_{-i}$ will have been updated in the Gibbs step from $t$ to $t+1$, changing the conditioning set. Therefore, the log density typically needs to be recomputed at each Gibbs step to account for the updated values of the conditioning variables.
+
+Only in certain special cases can the recomputation be skipped. For example, in a Metropolis-Hastings step where the proposal is rejected for all other variable groups, i.e., $\theta_{t+1,-i} = \theta_{t,-i}$, the log density $\log p(\theta_{t,i} \mid \theta_{t,-i})$ remains valid and doesn't need to be recomputed.
+
+The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. It takes an updated conditioned log density function $\log p(\cdot \mid \theta_{t+1,-i})$ and the parameter values $\theta_{t,i}$ stored in $\text{state}_{t,i}$, and computes the updated log density $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$.
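+
+To make the contract concrete, here is a minimal sketch of what such a recomputation could look like for a generic sampler state. The helper names `getparams` and `setlogp!!` are illustrative placeholders, not existing `AbstractMCMC` API:
+
+```julia
+# Sketch only: refresh the log density stored in `state` after the other variable
+# groups have moved, i.e. evaluate log p(θ_{t,i} | θ_{t+1,-i}) at the stored parameters.
+function recompute_logp!!(cond_logdensity, state)
+    logp = LogDensityProblems.logdensity(cond_logdensity, getparams(state))
+    return setlogp!!(state, logp)  # a new state of the same type carrying the new logp
+end
+```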
+
+## Proposed Interface
+
+The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases.
+
+Here we propose alternatives that achieve the same functionality as `getparams` and `recompute_logp!!` without introducing new interface functions.
+
+For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that `Base.vec` is officially defined for `AbstractArray`, so it is debatable whether we should extend it to state types that do not implement the rest of the `AbstractArray` interface.
+
+For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define an outer constructor for their state types, `StateType(state::StateType, logp)`, that takes an existing `state` and a log probability value `logp` and returns a new state of the same type with the updated log probability.
+
+While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector as the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface.
+
+An example demonstrating these interfaces is provided in `docs/src/state_interface.md`.
+
+## A More Standalone `Gibbs` Implementation
+
+`AbstractMCMC.Gibbs` should not manage a `variable name → sampler` map but rather a `range → sampler` map, i.e., it should maintain only a flat vector of parameter values, while a higher-level interface such as `AbstractPPL` / `DynamicPPL` manages both the variable names and the transformations.
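+
+As a purely illustrative sketch of the `range → sampler` idea (nothing below exists in `AbstractMCMC` today; the constructor and argument types are hypothetical):
+
+```julia
+# Hypothetical: the Gibbs sampler only sees index ranges into a flat parameter vector;
+# mapping variable names to ranges (and any transformations) is handled by
+# AbstractPPL / DynamicPPL on top of this.
+gibbs = Gibbs(
+    1:1 => RandomWalkMH(0.3),                  # updates θ[1:1]
+    2:2 => IndependentMH(InverseGamma(1, 1)),  # updates θ[2:2]
+)
+```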
diff --git a/docs/Project.toml b/docs/Project.toml
index f74dfb58..040a68b0 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,8 @@
 [deps]
+BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
+LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/docs/make.jl b/docs/make.jl
index 9395d2a0..adec1df9 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -8,7 +8,7 @@ makedocs(;
     sitename="AbstractMCMC",
     format=Documenter.HTML(),
     modules=[AbstractMCMC],
-    pages=["Home" => "index.md", "api.md", "design.md"],
+    pages=["Home" => "index.md", "api.md", "design.md", "state_interface.md"],
     checkdocs=:exports,
 )
diff --git a/docs/src/state_interface.md b/docs/src/state_interface.md
new file mode 100644
index 00000000..26efb9d4
--- /dev/null
+++ b/docs/src/state_interface.md
@@ -0,0 +1,506 @@
+# Interface For Sampler `state` and Gibbs Sampling
+
+We encourage sampler packages to implement the following interface functions for the `state` type(s) they maintain:
+
+```julia
+LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::StateType; recompute_logp=true)
+```
+
+This function takes the logdensity model and the state, and returns the log probability of the state.
+If `recompute_logp` is `true`, it should recompute the log probability of the state.
+Otherwise, it should use the log probability stored in the state, if available.
+
+```julia
+Base.vec(state)
+```
+
+This function takes the state and returns a vector of the parameter values stored in the state.
+
+```julia
+state = StateType(state::StateType, logp)
+```
+
+This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.
+
+These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement.
+The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
+We will demonstrate this in the following sections.
+
+## Using the `state` Interface for block sampling within Gibbs
+
+In this section, we demonstrate how a model package may use this `state` interface to support a Gibbs sampler that performs block sampling with different inference algorithms.
+
+We consider a simple hierarchical model with a normal likelihood and unknown mean and variance parameters.
+
+```math
+\begin{align}
+\mu &\sim \text{Normal}(0, 1) \\
+\tau^2 &\sim \text{InverseGamma}(1, 1) \\
+x_i &\sim \text{Normal}(\mu, \sqrt{\tau^2})
+\end{align}
+```
+
+We can write the log joint probability function as follows. For simplicity in the later steps, we assume that the `mu` and `tau2` parameters are one-element vectors; `x` is the data.
+ +```@example gibbs_example +using AbstractMCMC: AbstractMCMC # hide +using LogDensityProblems # hide +using Distributions # hide +using Random # hide +using AbstractMCMC: AbstractMCMC # hide +using AbstractPPL: AbstractPPL # hide +using BangBang: constructorof # hide +function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) + # mu is the mean + # tau2 is the variance + # x is data + + # μ ~ Normal(0, 1) + # τ² ~ InverseGamma(1, 1) + # xᵢ ~ Normal(μ, √τ²) + + logp = 0.0 + mu = only(mu) + tau2 = only(tau2) + + mu_prior = Normal(0, 1) + logp += logpdf(mu_prior, mu) + + tau2_prior = InverseGamma(1, 1) + logp += logpdf(tau2_prior, tau2) + + obs_prior = Normal(mu, sqrt(tau2)) + logp += sum(logpdf(obs_prior, xi) for xi in x) + + return logp +end +``` + +To make using `LogDensityProblems` interface, we create a simple type for this model. + +```@example gibbs_example +abstract type AbstractHierNormal end + +struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal + data::Tdata +end + +struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <: + AbstractHierNormal + data::Tdata + + " The variable to be conditioned on and its value" + conditioned_values::Tconditioned_vars +end +``` + +where `ConditionedHierNormal` is a type that represents the model conditioned on some variables, and + +```@example gibbs_example +function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple) + return ConditionedHierNormal(hn.data, conditioned_values) +end +``` + +then we can simply write down the `LogDensityProblems` interface for this model. + +```@example gibbs_example +function LogDensityProblems.logdensity( + hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars}, + params::AbstractVector, +) where {Tdata,Tconditioned_vars} + variable_to_condition = only(fieldnames(Tconditioned_vars)) + data = hier_normal_model.data + conditioned_values = hier_normal_model.conditioned_values + + if variable_to_condition == :mu + return log_joint(; mu=conditioned_values.mu, tau2=params, x=data.x) + elseif variable_to_condition == :tau2 + return log_joint(; mu=params, tau2=conditioned_values.tau2, x=data.x) + else + error("Unsupported conditioning variable: $variable_to_condition") + end +end + +function LogDensityProblems.capabilities(::HierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedHierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end +``` + +### Implementing A Sampler with `AbstractMCMC` Interface + +To illustrate the `AbstractMCMC` interface, we will first implement two very simple Metropolis-Hastings samplers: random walk and static proposal. + +Although the interface doesn't force the sampler to implement `Transition` and `State` types, in practice, it has been the convention to do so. + +Here we define some bare minimum types to represent the transitions and states. + +```@example gibbs_example +struct MHTransition{T} + params::Vector{T} +end + +struct MHState{T} + params::Vector{T} + logp::Float64 +end +``` + +Next we define the `state` interface functions mentioned at the beginning of this section. + +```@example gibbs_example +# Interface 1: LogDensityProblems.logdensity +# This function takes the logdensity function and the state (state is defined by the sampler package) +# and returns the logdensity. It allows for optional recomputation of the log probability. +# If recomputation is not needed, it returns the stored log probability from the state. 
+function LogDensityProblems.logdensity( + logdensity_model::AbstractMCMC.LogDensityModel, state::MHState; recompute_logp=true +) + logdensity_function = logdensity_model.logdensity + return if recompute_logp + AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) + else + state.logp + end +end + +# Interface 2: Base.vec +# This function takes a state and returns a vector of the parameter values stored in the state. +# It is part of the interface for interacting with the state object. +Base.vec(state::MHState) = state.params + +# Interface 3: constructorof and MHState(state::MHState, logp::Float64) +# This function allows the state to be updated with a new log probability. +function MHState(state::MHState, logp::Float64) + return MHState(state.params, logp) +end +``` + +It is very simple to implement the samplers according to the `AbstractMCMC` interface, where we can use `LogDensityProblems.logdensity` to easily read the log probability of the current state. + +```@example gibbs_example +abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end + +""" + RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + +A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ +is the standard deviation of the proposal distribution. +""" +struct RandomWalkMH{T} <: AbstractMHSampler + σ::T +end + +""" + IndependentMH{T} <: AbstractMCMC.AbstractSampler + +A Metropolis-Hastings sampler with an independent proposal distribution. +""" +struct IndependentMH{T} <: AbstractMHSampler + proposal_dist::T +end + +# the first step of the sampler +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::AbstractMHSampler, + args...; + initial_params, + kwargs..., +) + logdensity_function = logdensity_model.logdensity + transition = MHTransition(initial_params) + state = MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_function, initial_params)), + ) + + return transition, state +end + +@inline get_proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = + MvNormal(current_params, sampler.σ) +@inline get_proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = + sampler.proposal_dist + +# the subsequent steps of the sampler +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::AbstractMHSampler, + state::MHState, + args...; + kwargs..., +) + logdensity_function = logdensity_model.logdensity + current_params = state.params + proposal_dist = get_proposal_dist(sampler, current_params) + proposed_params = rand(rng, proposal_dist) + logp_proposal = only( + LogDensityProblems.logdensity(logdensity_function, proposed_params) + ) + + if log(rand(rng)) < + compute_log_acceptance_ratio(sampler, state, proposed_params, logp_proposal) + return MHTransition(proposed_params), MHState(proposed_params, logp_proposal) + else + return MHTransition(current_params), MHState(current_params, state.logp) + end +end + +function compute_log_acceptance_ratio( + ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 +) + return min(0, logp_proposal - state.logp) +end + +function compute_log_acceptance_ratio( + sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64 +) where {T} + return min( + 0, + logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) - + logpdf(sampler.proposal_dist, proposal), + ) +end +``` + +At last, we can proceed to implement a 
very simple Gibbs sampler. + +```@example gibbs_example +struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler + "Maps variables to their samplers." + sampler_map::T +end + +struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple} + "Contains the values of all parameters up to the last iteration." + trace::TraceNT + + "Maps parameters to their sampler-specific MCMC states." + mcmc_states::StateNT + + "Maps parameters to their sizes." + variable_sizes::SizeNT +end + +struct GibbsTransition{ValuesNT<:NamedTuple} + "Realizations of the parameters, this is considered a \"sample\" in the MCMC chain." + values::ValuesNT +end + +""" + update_trace(trace::NamedTuple, gibbs_state::GibbsState) + +Update the trace with the values from the MCMC states of the sub-problems. +""" +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) + sub_state = gibbs_state.mcmc_states[parameter_variable] + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] + ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) + trace = merge(trace, unflattened_sub_state_params) + end + return trace +end + +function error_if_not_fully_initialized( + initial_params::NamedTuple{ParamNames}, sampler::Gibbs{<:NamedTuple{SamplerNames}} +) where {ParamNames,SamplerNames} + if Set(ParamNames) != Set(SamplerNames) + throw( + ArgumentError( + "initial_params must contain all parameters in the model, expected $(SamplerNames), got $(ParamNames)", + ), + ) + end +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::Gibbs{Tsamplingmap}; + initial_params::NamedTuple, + kwargs..., +) where {Tsamplingmap} + error_if_not_fully_initialized(initial_params, sampler) + + model_parameter_names = fieldnames(Tsamplingmap) + results = map(model_parameter_names) do parameter_variable + sub_sampler = sampler.sampler_map[parameter_variable] + + variables_to_be_conditioned_on = setdiff( + model_parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) + ) + + # LogDensityProblems' `logdensity` function expects a single vector of real numbers + # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values + # and unflatten after the sampling step + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ) + sub_state = last( + AbstractMCMC.step( + rng, + sub_logdensity_model, + sub_sampler; + initial_params=flattened_sub_problem_parameters_values, + kwargs..., + ), + ) + (sub_state, size(initial_params[parameter_variable])) + end + + mcmc_states_tuple = first.(results) + variable_sizes_tuple = last.(results) + + gibbs_state = GibbsState( + initial_params, + NamedTuple{Tuple(model_parameter_names)}(mcmc_states_tuple), + NamedTuple{Tuple(model_parameter_names)}(variable_sizes_tuple), + ) + + trace = update_trace(NamedTuple(), gibbs_state) + return GibbsTransition(trace), gibbs_state +end + +# subsequent 
steps
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    logdensity_model::AbstractMCMC.LogDensityModel,
+    sampler::Gibbs{Tsamplingmap},
+    gibbs_state::GibbsState;
+    kwargs...,
+) where {Tsamplingmap}
+    trace = gibbs_state.trace
+    mcmc_states = gibbs_state.mcmc_states
+    variable_sizes = gibbs_state.variable_sizes
+
+    model_parameter_names = fieldnames(Tsamplingmap)
+    mcmc_states = map(model_parameter_names) do parameter_variable
+        sub_sampler = sampler.sampler_map[parameter_variable]
+        sub_state = mcmc_states[parameter_variable]
+        variables_to_be_conditioned_on = setdiff(
+            model_parameter_names, (parameter_variable,)
+        )
+        conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
+            Tuple([trace[g] for g in variables_to_be_conditioned_on])
+        )
+        cond_logdensity = AbstractPPL.condition(
+            logdensity_model.logdensity, conditioning_variables_values
+        )
+        cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity)
+
+        logp = LogDensityProblems.logdensity(
+            cond_logdensity_model, sub_state; recompute_logp=true
+        )
+        sub_state = constructorof(typeof(sub_state))(sub_state, logp)
+        sub_state = last(
+            AbstractMCMC.step(
+                rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...
+            ),
+        )
+        trace = update_trace(trace, gibbs_state)
+        sub_state
+    end
+    mcmc_states = NamedTuple{Tuple(model_parameter_names)}(mcmc_states)
+
+    return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes)
+end
+```
+
+We use a `NamedTuple` to store the mapping between variables and samplers. The order of its entries determines the order of the Gibbs sweep. A limitation is that exactly one sampler is required for each variable, which makes this less flexible than the `Gibbs` sampler in `Turing.jl`.
+
+We use `AbstractPPL.condition` to divide the full model into smaller conditional probability problems.
+Each conditional probability problem corresponds to a sampler and its state.
+
+The `Gibbs` sampler has the same interface as other samplers in `AbstractMCMC` (we don't implement the above state interface for `GibbsState` to keep it simple, but it can be implemented similarly).
+
+The Gibbs sampler operates in two main phases:
+
+1. Initialization:
+   - Set up initial states for each conditional probability problem.
+
+2. Iterative Sampling:
+   For each iteration, the sampler performs a sweep over all conditional probability problems:
+
+   a. Condition on other variables:
+      - Fix the values of all variables except the current one.
+   b. Update the current variable:
+      - Recompute the log probability of the current state, as the other variables may have changed:
+        - Use `LogDensityProblems.logdensity(cond_logdensity_model, sub_state)` to get the new log probability.
+        - Update the state with `sub_state = constructorof(typeof(sub_state))(sub_state, logp)` to incorporate the new log probability.
+      - Perform a sampling step for the current conditional probability problem:
+        - Use `AbstractMCMC.step(rng, cond_logdensity_model, sub_sampler, sub_state; kwargs...)` to generate a new state.
+      - Update the global trace:
+        - Extract parameter values from the new state using `Base.vec(new_sub_state)`.
+        - Incorporate these values into the overall Gibbs state trace.
+
+This process allows the Gibbs sampler to iteratively update each variable while conditioning on the others, gradually exploring the joint distribution of all variables.
+
+Now we can use the Gibbs sampler to sample from the hierarchical Normal model.
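+
+Under the hood, `sample` simply alternates the two `step` methods defined above. Conceptually, the loop looks roughly like the following sketch (placeholder names; keyword plumbing and progress logging omitted; this is not the actual `AbstractMCMC.sample` implementation):
+
+```julia
+# Rough sketch of how a chain of length `n_samples` is driven by `step`.
+transition, state = AbstractMCMC.step(rng, logdensity_model, gibbs_sampler; initial_params=initial_params)
+samples = [transition]
+for _ in 2:n_samples
+    transition, state = AbstractMCMC.step(rng, logdensity_model, gibbs_sampler, state)
+    push!(samples, transition)
+end
+```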
+
+First we generate some data,
+
+```@example gibbs_example
+N = 100 # Number of data points
+mu_true = 0.5 # True mean
+tau2_true = 2.0 # True variance
+
+x_data = rand(Normal(mu_true, sqrt(tau2_true)), N)
+```
+
+Then we create a `HierNormal` model with the data we just generated.
+
+```@example gibbs_example
+hn = HierNormal((x=x_data,))
+```
+
+Gibbs sampling allows us to use random walk MH for `mu` and independent MH with the prior as the proposal for `tau2`, because `tau2` is supported only on the positive real numbers.
+
+```@example gibbs_example
+samples = sample(
+    hn,
+    Gibbs((
+        mu=RandomWalkMH(0.3),
+        tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])),
+    )),
+    10000;
+    initial_params=(mu=[0.0], tau2=[1.0]),
+)
+```
+
+Then we can extract the samples and compute their means.
+
+```@example gibbs_example
+warmup = 5000
+thin = 10
+thinned_samples = samples[(warmup + 1):thin:end]
+mu_samples = [sample.values.mu for sample in thinned_samples]
+tau2_samples = [sample.values.tau2 for sample in thinned_samples]
+
+mu_mean = only(mean(mu_samples))
+tau2_mean = only(mean(tau2_samples))
+(mu_mean, tau2_mean)
+```
+
+which is close to the true values `(0.5, 2)`.
diff --git a/test/gibbs_example/gibbs.jl b/test/gibbs_example/gibbs.jl
new file mode 100644
index 00000000..17bc9552
--- /dev/null
+++ b/test/gibbs_example/gibbs.jl
@@ -0,0 +1,158 @@
+using AbstractMCMC: AbstractMCMC
+using AbstractPPL: AbstractPPL
+using BangBang: constructorof
+using Random
+
+struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
+    "Maps variables to their samplers."
+    sampler_map::T
+end
+
+struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple}
+    "Contains the values of all parameters up to the last iteration."
+    trace::TraceNT
+
+    "Maps parameters to their sampler-specific MCMC states."
+    mcmc_states::StateNT
+
+    "Maps parameters to their sizes."
+    variable_sizes::SizeNT
+end
+
+struct GibbsTransition{ValuesNT<:NamedTuple}
+    "Realizations of the parameters, this is considered a \"sample\" in the MCMC chain."
+    values::ValuesNT
+end
+
+"""
+    update_trace(trace::NamedTuple, gibbs_state::GibbsState)
+
+Update the trace with the values from the MCMC states of the sub-problems.
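+Each sub-state's parameter vector (obtained via `Base.vec`) is reshaped to the variable's original size before being merged into the trace.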
+""" +function update_trace( + trace::NamedTuple{trace_names}, gibbs_state::GibbsState{TraceNT,StateNT,SizeNT} +) where {trace_names,TraceNT,StateNT,SizeNT} + for parameter_variable in fieldnames(StateNT) + sub_state = gibbs_state.mcmc_states[parameter_variable] + sub_state_params_values = Base.vec(sub_state) + reshaped_sub_state_params_values = reshape( + sub_state_params_values, gibbs_state.variable_sizes[parameter_variable] + ) + unflattened_sub_state_params = NamedTuple{(parameter_variable,)}(( + reshaped_sub_state_params_values, + )) + trace = merge(trace, unflattened_sub_state_params) + end + return trace +end + +function error_if_not_fully_initialized( + initial_params::NamedTuple{ParamNames}, sampler::Gibbs{<:NamedTuple{SamplerNames}} +) where {ParamNames,SamplerNames} + if Set(ParamNames) != Set(SamplerNames) + throw( + ArgumentError( + "initial_params must contain all parameters in the model, expected $(SamplerNames), got $(ParamNames)", + ), + ) + end +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::Gibbs{Tsamplingmap}; + initial_params::NamedTuple, + kwargs..., +) where {Tsamplingmap} + error_if_not_fully_initialized(initial_params, sampler) + + model_parameter_names = fieldnames(Tsamplingmap) + results = map(model_parameter_names) do parameter_variable + sub_sampler = sampler.sampler_map[parameter_variable] + + variables_to_be_conditioned_on = setdiff( + model_parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in variables_to_be_conditioned_on]) + ) + + # LogDensityProblems' `logdensity` function expects a single vector of real numbers + # `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values + # and unflatten after the sampling step + flattened_sub_problem_parameters_values = vec(initial_params[parameter_variable]) + + sub_logdensity_model = AbstractMCMC.LogDensityModel( + AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ), + ) + sub_state = last( + AbstractMCMC.step( + rng, + sub_logdensity_model, + sub_sampler; + initial_params=flattened_sub_problem_parameters_values, + kwargs..., + ), + ) + (sub_state, size(initial_params[parameter_variable])) + end + + mcmc_states_tuple = first.(results) + variable_sizes_tuple = last.(results) + + gibbs_state = GibbsState( + initial_params, + NamedTuple{Tuple(model_parameter_names)}(mcmc_states_tuple), + NamedTuple{Tuple(model_parameter_names)}(variable_sizes_tuple), + ) + + trace = update_trace(NamedTuple(), gibbs_state) + return GibbsTransition(trace), gibbs_state +end + +# subsequent steps +function AbstractMCMC.step( + rng::Random.AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::Gibbs{Tsamplingmap}, + gibbs_state::GibbsState; + kwargs..., +) where {Tsamplingmap} + trace = gibbs_state.trace + mcmc_states = gibbs_state.mcmc_states + variable_sizes = gibbs_state.variable_sizes + + model_parameter_names = fieldnames(Tsamplingmap) + mcmc_states = map(model_parameter_names) do parameter_variable + sub_sampler = sampler.sampler_map[parameter_variable] + sub_state = mcmc_states[parameter_variable] + variables_to_be_conditioned_on = setdiff( + model_parameter_names, (parameter_variable,) + ) + conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}( + Tuple([trace[g] for g in variables_to_be_conditioned_on]) + ) + 
cond_logdensity = AbstractPPL.condition( + logdensity_model.logdensity, conditioning_variables_values + ) + cond_logdensity_model = AbstractMCMC.LogDensityModel(cond_logdensity) + + logp = LogDensityProblems.logdensity( + cond_logdensity_model, sub_state; recompute_logp=true + ) + sub_state = constructorof(typeof(sub_state))(sub_state, logp) + sub_state = last( + AbstractMCMC.step( + rng, cond_logdensity_model, sub_sampler, sub_state; kwargs... + ), + ) + trace = update_trace(trace, gibbs_state) + sub_state + end + mcmc_states = NamedTuple{Tuple(model_parameter_names)}(mcmc_states) + + return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes) +end diff --git a/test/gibbs_example/gibbs_test.jl b/test/gibbs_example/gibbs_test.jl new file mode 100644 index 00000000..da42fd34 --- /dev/null +++ b/test/gibbs_example/gibbs_test.jl @@ -0,0 +1,74 @@ +include("gibbs.jl") +include("mh.jl") +# include("gmm.jl") +include("hier_normal.jl") + +@testset "hierarchical normal with gibbs" begin + # generate data + N = 1000 # Number of data points + mu_true = 5 # True mean + tau2_true = 2.0 # True variance + x_data = rand(Distributions.Normal(mu_true, sqrt(tau2_true)), N) + + # Store the generated data in the HierNormal structure + hn = HierNormal((x=x_data,)) + + samples = sample( + hn, + Gibbs(( + mu=RandomWalkMH(0.3), + tau2=IndependentMH(product_distribution([InverseGamma(1, 1)])), + )), + 10000; + initial_params=(mu=[0.0], tau2=[1.0]), + ) + + warmup = 5000 + thin = 10 + thinned_samples = samples[(warmup + 1):thin:end] + mu_samples = [sample.values.mu for sample in thinned_samples] + tau2_samples = [sample.values.tau2 for sample in thinned_samples] + + mu_mean = only(mean(mu_samples)) + tau2_mean = only(mean(tau2_samples)) + + @test mu_mean ≈ mu_true rtol = 0.1 + @test tau2_mean ≈ tau2_true rtol = 0.1 +end + +# This is too difficult to sample, disable for now +# @testset "gmm with gibbs" begin +# w = [0.5, 0.5] +# μ = [-3.5, 0.5] +# mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w) + +# N = 60 +# x = rand(mixturemodel, N) + +# gmm = GMM((; x=x)) + +# samples = sample( +# gmm, +# Gibbs( +# ( +# z = IndependentMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), +# w = IndependentMH(Dirichlet(2, 1.0)), +# μ = RandomWalkMH(1), +# ), +# ), +# 100000; +# initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[-3.5, 0.5], w=[0.3, 0.7]), +# ) + +# z_samples = [sample.values.z for sample in samples][20001:end] +# μ_samples = [sample.values.μ for sample in samples][20001:end] +# w_samples = [sample.values.w for sample in samples][20001:end] + +# # thin these samples +# z_samples = z_samples[1:100:end] +# μ_samples = μ_samples[1:100:end] +# w_samples = w_samples[1:100:end] + +# mean(μ_samples) +# mean(w_samples) +# end diff --git a/test/gibbs_example/gmm.jl b/test/gibbs_example/gmm.jl new file mode 100644 index 00000000..7cfe26f7 --- /dev/null +++ b/test/gibbs_example/gmm.jl @@ -0,0 +1,71 @@ +abstract type AbstractGMM end + +struct GMM <: AbstractGMM + data::NamedTuple +end + +struct ConditionedGMM{conditioned_vars} <: AbstractGMM + data::NamedTuple + conditioned_values::NamedTuple{conditioned_vars} +end + +function log_joint(; μ, w, z, x) + # μ is mean of each component + # w is weights of each component + # z is assignment of each data point + # x is data + + K = 2 # assume we know the number of components + D = 2 # dimension of each data point + N = size(x, 2) # number of data points + logp = 0.0 + + μ_prior = MvNormal(zeros(K), I) + logp += 
logpdf(μ_prior, μ) + + w_prior = Dirichlet(K, 1.0) + logp += logpdf(w_prior, w) + + z_prior = Categorical(w) + + logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) + + obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ] + for i in 1:N + logp += logpdf(obs_priors[z[i]], x[:, i]) + end + + return logp +end + +function AbstractMCMC.condition(gmm::GMM, conditioned_values::NamedTuple) + return ConditionedGMM(gmm.data, conditioned_values) +end + +function LogDensityProblems.logdensity( + gmm::ConditionedGMM{names}, params::AbstractVector +) where {names} + if Set(names) == Set([:μ, :w]) # conditioned on μ, w, so params are z + return log_joint(; + μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x + ) + elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ + return log_joint(; + μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x + ) + elseif Set(names) == Set([:z, :μ]) # conditioned on z, μ, so params are w + return log_joint(; + μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x + ) + else + error("Unsupported conditioning configuration.") + end +end + +function LogDensityProblems.capabilities(::GMM) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedGMM) + return LogDensityProblems.LogDensityOrder{0}() +end diff --git a/test/gibbs_example/hier_normal.jl b/test/gibbs_example/hier_normal.jl new file mode 100644 index 00000000..2e3e381a --- /dev/null +++ b/test/gibbs_example/hier_normal.jl @@ -0,0 +1,70 @@ +using AbstractPPL: AbstractPPL + +abstract type AbstractHierNormal end + +struct HierNormal{Tdata<:NamedTuple} <: AbstractHierNormal + data::Tdata +end + +struct ConditionedHierNormal{Tdata<:NamedTuple,Tconditioned_vars<:NamedTuple} <: + AbstractHierNormal + data::Tdata + + " The variable to be conditioned on and its value" + conditioned_values::Tconditioned_vars +end + +# `mu` and `tau2` are length-1 vectors to make +function log_joint(; mu::Vector{Float64}, tau2::Vector{Float64}, x::Vector{Float64}) + # mu is the mean + # tau2 is the variance + # x is data + + # μ ~ Normal(0, 1) + # τ² ~ InverseGamma(1, 1) + # xᵢ ~ Normal(μ, √τ²) + + logp = 0.0 + mu = only(mu) + tau2 = only(tau2) + + mu_prior = Normal(0, 1) + logp += logpdf(mu_prior, mu) + + tau2_prior = InverseGamma(1, 1) + logp += logpdf(tau2_prior, tau2) + + obs_prior = Normal(mu, sqrt(tau2)) + logp += sum(logpdf(obs_prior, xi) for xi in x) + + return logp +end + +function AbstractPPL.condition(hn::HierNormal, conditioned_values::NamedTuple) + return ConditionedHierNormal(hn.data, conditioned_values) +end + +function LogDensityProblems.logdensity( + hier_normal_model::ConditionedHierNormal{Tdata,Tconditioned_vars}, + params::AbstractVector, +) where {Tdata,Tconditioned_vars} + variable_to_condition = only(fieldnames(Tconditioned_vars)) + data = hier_normal_model.data + conditioned_values = hier_normal_model.conditioned_values + + if variable_to_condition == :mu + return log_joint(; mu=conditioned_values.mu, tau2=params, x=data.x) + elseif variable_to_condition == :tau2 + return log_joint(; mu=params, tau2=conditioned_values.tau2, x=data.x) + else + error("Unsupported conditioning variable: $variable_to_condition") + end +end + +function LogDensityProblems.capabilities(::HierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedHierNormal) + return LogDensityProblems.LogDensityOrder{0}() +end diff --git 
a/test/gibbs_example/mh.jl b/test/gibbs_example/mh.jl new file mode 100644 index 00000000..4068268d --- /dev/null +++ b/test/gibbs_example/mh.jl @@ -0,0 +1,124 @@ +using AbstractMCMC: AbstractMCMC, LogDensityProblems +using Distributions +using Random + +abstract type AbstractMHSampler <: AbstractMCMC.AbstractSampler end + +struct MHState{T} + params::Vector{T} + logp::Float64 +end + +# Interface 3: outer constructor that takes a state and a logp +# This function allows the state to be updated with a new log probability. +function MHState(state::MHState, logp::Float64) + return MHState(state.params, logp) +end + +struct MHTransition{T} + params::Vector{T} +end + +# Interface 1: LogDensityProblems.logdensity +# This function takes the logdensity function and the state (state is defined by the sampler package) +# and returns the logdensity. It allows for optional recomputation of the log probability. +# If recomputation is not needed, it returns the stored log probability from the state. +function LogDensityProblems.logdensity( + logdensity_model::AbstractMCMC.LogDensityModel, state::MHState; recompute_logp=true +) + logdensity_function = logdensity_model.logdensity + return if recompute_logp + AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params) + else + state.logp + end +end + +# Interface 2: Base.vec +# This function takes a state and returns a vector of the parameter values stored in the state. +# It is part of the interface for interacting with the state object. +Base.vec(state::MHState) = state.params + +""" + RandomWalkMH{T} <: AbstractMCMC.AbstractSampler + +A random walk Metropolis-Hastings sampler with a normal proposal distribution. The field σ +is the standard deviation of the proposal distribution. +""" +struct RandomWalkMH{T} <: AbstractMHSampler + σ::T +end + +""" + IndependentMH{T} <: AbstractMCMC.AbstractSampler + +A Metropolis-Hastings sampler with an independent proposal distribution. 
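+The field `proposal_dist` is the proposal distribution, which does not depend on the current state.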
+""" +struct IndependentMH{T} <: AbstractMHSampler + proposal_dist::T +end + +# the first step of the sampler +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::AbstractMHSampler, + args...; + initial_params, + kwargs..., +) + logdensity_function = logdensity_model.logdensity + transition = MHTransition(initial_params) + state = MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_function, initial_params)), + ) + + return transition, state +end + +@inline get_proposal_dist(sampler::RandomWalkMH, current_params::Vector{Float64}) = + MvNormal(current_params, sampler.σ) +@inline get_proposal_dist(sampler::IndependentMH, current_params::Vector{T}) where {T} = + sampler.proposal_dist + +# the subsequent steps of the sampler +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::AbstractMHSampler, + state::MHState, + args...; + kwargs..., +) + logdensity_function = logdensity_model.logdensity + current_params = state.params + proposal_dist = get_proposal_dist(sampler, current_params) + proposed_params = rand(rng, proposal_dist) + logp_proposal = only( + LogDensityProblems.logdensity(logdensity_function, proposed_params) + ) + + if log(rand(rng)) < + compute_log_acceptance_ratio(sampler, state, proposed_params, logp_proposal) + return MHTransition(proposed_params), MHState(proposed_params, logp_proposal) + else + return MHTransition(current_params), MHState(current_params, state.logp) + end +end + +function compute_log_acceptance_ratio( + ::RandomWalkMH, state::MHState, ::Vector{Float64}, logp_proposal::Float64 +) + return min(0, logp_proposal - state.logp) +end + +function compute_log_acceptance_ratio( + sampler::IndependentMH, state::MHState, proposal::Vector{T}, logp_proposal::Float64 +) where {T} + return min( + 0, + logp_proposal - state.logp + logpdf(sampler.proposal_dist, state.params) - + logpdf(sampler.proposal_dist, proposal), + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..5ecd67d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,5 @@ include("utils.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") + include("gibbs_example/gibbs_test.jl") end