Skip to content

Commit

Permalink
Make AdvancedMH compatible with AbstractMCMC 5 (#92)
Browse files Browse the repository at this point in the history
* Make AdvancedMH compatible with AbstractMCMC 5

* Fix typo

* Change is breaking
  • Loading branch information
devmotion authored Oct 28, 2023
1 parent 8ddb81e commit 3749df0
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.7.6"
version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
AdvancedMHStructArraysExt = "StructArrays"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5"
DiffResults = "1"
Distributions = "0.20 - 0.25"
LinearAlgebra = "1.6 - 1.11"
Random = "1.6 - 1.11"
Distributions = "0.25"
FillArrays = "1"
ForwardDiff = "0.10"
LogDensityProblems = "2"
MCMCChains = "5, 6"
MCMCChains = "6.0.4"
Requires = "1"
StructArrays = "0.6"
julia = "1.6"
LinearAlgebra = "1.6"
Random = "1.6"

[extras]
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,21 @@ AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/Turi

```julia
# Sample 4 chains from the posterior serially, without thread or process parallelism.
chain = sample(model, RWMH(init_params), MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple threads.
chain = sample(model, RWMH(init_params), MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple processes.
chain = sample(model, RWMH(init_params), MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
```

## Metropolis-adjusted Langevin algorithm (MALA)

AdvancedMH.jl also offers an implementation of [MALA](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) if the `ForwardDiff` and `DiffResults` packages are available.

A `MALA` sampler can be constructed by `MALA(proposal)` where `proposal` is a function that
takes the gradient computed at the current sample. It is required to specify an initial sample `init_params` when calling `sample`.
takes the gradient computed at the current sample. It is required to specify an initial sample `initial_params` when calling `sample`.

```julia
# Import the package.
Expand Down Expand Up @@ -180,7 +180,7 @@ model = DensityModel(density)
spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior.
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain = sample(model, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```

### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
Expand All @@ -192,5 +192,5 @@ Using our implementation of the `LogDensityProblems.jl` interface above:
```julia
using LogDensityProblemsAD
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
sample(model_with_ad, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
sample(model_with_ad, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```
6 changes: 3 additions & 3 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ spl = MetropolisHastings(proposal)
When using `MetropolisHastings` with the function `sample`, the following keyword
arguments are allowed:
- `init_params` defines the initial parameterization for your model. If
- `initial_params` defines the initial parameterization for your model. If
none is given, the initial parameters will be drawn from the sampler's proposals.
- `param_names` is a vector of strings to be assigned to parameters. This is only
used if `chain_type=Chains`.
Expand Down Expand Up @@ -77,10 +77,10 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModelOrLogDensityModel,
sampler::MHSampler;
init_params=nothing,
initial_params=nothing,
kwargs...
)
params = init_params === nothing ? propose(rng, sampler, model) : init_params
params = initial_params === nothing ? propose(rng, sampler, model) : initial_params
transition = AdvancedMH.transition(sampler, model, params)
return transition, transition
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ include("util.jl")
val = [0.4, 1.2]

# Sample from the posterior.
chain1 = sample(model, spl1, 10, init_params = val)
chain1 = sample(model, spl1, 10, initial_params = val)

@test chain1[1].params == val
end
Expand Down Expand Up @@ -265,7 +265,7 @@ include("util.jl")
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior with initial parameters.
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])

@test mean(chain1.μ) 0.0 atol=0.1
@test mean(chain1.σ) 1.0 atol=0.1
Expand All @@ -276,7 +276,7 @@ include("util.jl")
admodel,
spl1,
100000;
init_params=ones(2),
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"]
)
Expand Down

2 comments on commit 3749df0

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/94285

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.0 -m "<description of version>" 3749df0075ae59fc8bc4ef576a73ce770b75eec0
git push origin v0.8.0

Please sign in to comment.