Add getparameters and setparameters!! (#86)
* added state_from_transition, parameters and setparameters!!

* Update src/AbstractMCMC.jl

Co-authored-by: David Widmann <[email protected]>

* renamed state_from_transition to updatestate!!

* adhere to julia convention

* added docs

* fixed docs

* fixed docs

* added example for why updatestate!! is useful

* improved MixtureState example

* further improvements to docs

* renamed parameters and setparameters!! to values and setvalues!!

* fixed typo in docs

* fixed documenting values

* improved and fixed some bugs in docs

* fixed typo in docs

* renamed values and setvalues!! to realize and realize!!

* added model to updatestate!!

* Apply suggestions from code review

Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update docs/src/api.md

Co-authored-by: Xianda Sun <[email protected]>

* Apply suggestions from code review

Co-authored-by: Xianda Sun <[email protected]>

* Update docs/src/api.md

Co-authored-by: Xianda Sun <[email protected]>

* version bump

---------

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Xianda Sun <[email protected]>
5 people authored Oct 12, 2024
1 parent fc8cfa6 commit 467b076
Showing 3 changed files with 163 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
-version = "5.4.0"
+version = "5.5.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
141 changes: 141 additions & 0 deletions docs/src/api.md
@@ -113,3 +113,144 @@ For chains of this type, AbstractMCMC defines the following two methods.
AbstractMCMC.chainscat
AbstractMCMC.chainsstack
```

## Interacting with states of samplers

To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
```@docs
AbstractMCMC.getparams
AbstractMCMC.setparams!!
```
These methods are also useful for implementing samplers that wrap inner samplers, e.g. a mixture of samplers.
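
As a minimal sketch of what such an implementation might look like, the block below defines both methods for a made-up state type. `ToyState` and its fields are hypothetical and exist only for illustration; a real sampler would define the methods for its own state type.

```julia
using AbstractMCMC

# Hypothetical sampler state (not part of AbstractMCMC): the current
# parameter values plus some sampler-specific tuning information.
struct ToyState
    params::Vector{Float64}
    stepsize::Float64
end

# Expose the parameter values as a vector of reals.
AbstractMCMC.getparams(state::ToyState) = state.params

# `ToyState` is immutable, so we return a new state that carries the
# new parameters and keeps everything else unchanged.
AbstractMCMC.setparams!!(state::ToyState, params) = ToyState(params, state.stepsize)

state = ToyState([1.0, 2.0], 0.1)
state_new = AbstractMCMC.setparams!!(state, [3.0, 4.0])
roundtrip = AbstractMCMC.setparams!!(state, AbstractMCMC.getparams(state))
```

Here `state_new` holds the new parameters while `stepsize` is carried over, and `roundtrip` holds the same parameters as `state`, which is the consistency property the docstring of `setparams!!` asks for.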

### Example: `MixtureSampler`

In a `MixtureSampler` we need two things:
- `components`: collection of samplers.
- `weights`: collection of weights representing the probability of choosing the corresponding sampler.

```julia
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
    components::C
    weights::W
end
```

To implement the state, we need to keep track of two things:
- `index`: the index of the sampler used in this `step`.
- `states`: the current states of _all_ the components.

We need to keep track of the states of _all_ components rather than just the state of the sampler we used previously.
The reason is that many samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.


```julia
struct MixtureState{S}
    index::Int
    states::S
end
```
The `step` for a `MixtureSampler` is defined by the following generative process:
```math
\begin{aligned}
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
\end{aligned}
```
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) come into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})``, since ``X_{t - 1}`` could be coming from a different sampler.
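
To make the generative process concrete, here is a small sketch that uses only the `Random` standard library. The helper names `sample_index` and `mixture_step`, and the two random-walk kernels, are assumptions made for illustration; they are not part of AbstractMCMC.

```julia
using Random

# Draw an index i with probability weights[i] by inverting the CDF.
# Assumes the weights are non-negative and sum to one.
function sample_index(rng::AbstractRNG, weights)
    u = rand(rng)
    acc = 0.0
    for (i, w) in enumerate(weights)
        acc += w
        u < acc && return i
    end
    return length(weights)  # guard against floating-point round-off
end

# One step of the mixture: pick a kernel, then apply it to x_prev.
function mixture_step(rng::AbstractRNG, kernels, weights, x_prev)
    i = sample_index(rng, weights)
    return i, kernels[i](rng, x_prev)
end

# Two hypothetical kernels: random walks with a small and a large step.
kernels = (
    (rng, x) -> x + 0.1 * randn(rng),
    (rng, x) -> x + 2.0 * randn(rng),
)
weights = [0.1, 0.9]

rng = MersenneTwister(42)
i, x = mixture_step(rng, kernels, weights, 0.0)
```

In this toy setting the "state" is just the value `x`, so nothing needs to be converted between kernels. With real samplers each kernel has its own state type, which is where `getparams` and `setparams!!` are needed.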

If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` the previous component we sampled from, then this translates into the following piece of code:

```julia
# Update the corresponding state, i.e. `state.states[i]`, using
# the parameters stored in the state from the previous iteration.
state_current = AbstractMCMC.setparams!!(
    state.states[i],
    AbstractMCMC.getparams(state.states[i_prev]),
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
    rng, model, sampler_current, state_current;
    kwargs...
)
```
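
The hand-off between two components with different state types can be sketched as follows. `RandomWalkState` and `MomentumState` are hypothetical stand-ins for the states of two different samplers, not types from any real package.

```julia
using AbstractMCMC

# Hypothetical component states (illustration only). Each one stores the
# parameters plus auxiliary information specific to its sampler.
struct RandomWalkState
    params::Vector{Float64}
    stepsize::Float64
end

struct MomentumState
    params::Vector{Float64}
    momentum::Vector{Float64}
end

AbstractMCMC.getparams(s::RandomWalkState) = s.params
AbstractMCMC.getparams(s::MomentumState) = s.params

# Copy the incoming vector so the two states do not alias the same array.
AbstractMCMC.setparams!!(s::RandomWalkState, params) = RandomWalkState(copy(params), s.stepsize)
AbstractMCMC.setparams!!(s::MomentumState, params) = MomentumState(copy(params), s.momentum)

states = (
    RandomWalkState([0.0, 0.0], 0.5),
    MomentumState([1.0, -1.0], [0.3, 0.2]),
)

# The previous step used component 2; this step uses component 1.
i_prev, i = 2, 1
state_current = AbstractMCMC.setparams!!(
    states[i],
    AbstractMCMC.getparams(states[i_prev]),
)
```

After the hand-off, `state_current` is still a `RandomWalkState` with its own `stepsize`, but its parameters are those of the component used in the previous step.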

The full [`AbstractMCMC.step`](@ref) implementation would then be something like:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
    # Sample the component to use in this `step`.
    # NOTE: `Categorical` is from Distributions.jl; passing `rng` keeps sampling reproducible.
    i = rand(rng, Categorical(sampler.weights))
    sampler_current = sampler.components[i]

    # Update the corresponding state, i.e. `state.states[i]`, using
    # the parameters stored in the state from the previous iteration.
    i_prev = state.index
    state_current = AbstractMCMC.setparams!!(
        state.states[i],
        AbstractMCMC.getparams(state.states[i_prev]),
    )

    # Take a `step` for this sampler using the updated state.
    transition, state_current = AbstractMCMC.step(
        rng, model, sampler_current, state_current;
        kwargs...
    )

    # Create the new states.
    # NOTE: Code below will result in `states_new` being a `Vector`.
    # If we wanted to allow usage of alternative containers, e.g. `Tuple`,
    # it would be better to use something like `@set states[i] = state_current`
    # where `@set` is from Setfield.jl.
    states_new = map(1:length(state.states)) do j
        if j == i
            # Replace the i-th state with the new one.
            state_current
        else
            # Otherwise we just carry over the previous ones.
            state.states[j]
        end
    end

    # Create the new `MixtureState`.
    state_new = MixtureState(i, states_new)

    return transition, state_new
end
```
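
The `NOTE` in the code above mentions that mapping over `1:length(state.states)` always produces a `Vector`. As one possible alternative that needs no extra packages, the sketch below uses `Base.setindex` for tuples and a copy for vectors. The helper name `replace_at` is an assumption made for illustration.

```julia
# Return a container like `xs` with the i-th entry replaced by `x`,
# without mutating `xs`.

# Tuples: `Base.setindex` returns a new tuple.
replace_at(xs::Tuple, i::Integer, x) = Base.setindex(xs, x, i)

# Vectors: copy, then overwrite. `x` must be convertible to `eltype(xs)`.
function replace_at(xs::AbstractVector, i::Integer, x)
    ys = copy(xs)
    ys[i] = x
    return ys
end

states_tuple = (1, 2.0, "a")
states_vector = [1, 2, 3]

new_tuple = replace_at(states_tuple, 2, 3.0)
new_vector = replace_at(states_vector, 2, 10)
```

With a helper like this, `states_new = replace_at(state.states, i, state_current)` would keep a `Tuple` of states as a `Tuple`, which can help type stability when the components have different state types.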

And for the initial [`AbstractMCMC.step`](@ref) we have:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
    # Initialize every state.
    transitions_and_states = map(sampler.components) do spl
        AbstractMCMC.step(rng, model, spl; kwargs...)
    end

    # Sample the component to use in this `step`.
    # NOTE: `Categorical` is from Distributions.jl; passing `rng` keeps sampling reproducible.
    i = rand(rng, Categorical(sampler.weights))
    # Extract the corresponding transition.
    transition = first(transitions_and_states[i])
    # Extract states.
    states = map(last, transitions_and_states)
    # Create new `MixtureState`.
    state = MixtureState(i, states)

    return transition, state
end
```

Suppose we then wanted to use this with one of the packages that implement AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl). We would then only have to implement `getparams` and `setparams!!` for its sampler state.


To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do

```julia
sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9])
transition, state = AbstractMCMC.step(rng, model, sampler)
while ...
    transition, state = AbstractMCMC.step(rng, model, sampler, state)
end
```
21 changes: 21 additions & 0 deletions src/AbstractMCMC.jl
@@ -80,6 +80,27 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
"""
struct MCMCSerial <: AbstractMCMCEnsemble end

"""
    getparams(state[; kwargs...])

Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`.
"""
function getparams end

"""
    setparams!!(state, params)

Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`.

This function should follow the `BangBang` interface: mutate `state` in-place if possible and
return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters.

Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In other
words, the sampler should implement a consistent transformation between its internal representation
and the vector representation of the parameter values.
"""
function setparams!! end

include("samplingstats.jl")
include("logging.jl")
include("interface.jl")

2 comments on commit 467b076

@sunxd3 commented on 467b076 Oct 15, 2024

@JuliaRegistrator register

Release notes:

This PR introduces two functions on sampler states, getparams and setparams!!. getparams retrieves the values of the parameters as a vector. setparams!! is the counterpart of getparams.

This is part of the effort to make interaction between samplers easier. The motivation for this PR is that a "high-order" MCMC sampler (e.g. Gibbs) can use getparams and setparams!! without having to know the exact type of the sampler state, thus paving the way to a more general sampler design and implementation.

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/117303

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v5.5.0 -m "<description of version>" 467b07621975d81856792d0788f0e490c3ecaaa5
git push origin v5.5.0