From 0f289206b22da5feee03218c157afb28706ed8ec Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 06:01:50 +0000 Subject: [PATCH] Define rand defaults for AbstractProbabilisticProgram (#79) This PR adds a 3-arg form of `rand` (suggested by @devmotion in https://github.com/TuringLang/DynamicPPL.jl/issues/466#issuecomment-1436670214) to the interface for `AbstractProbabilisticProgram` and implements the default 1- and 2-arg methods that dispatch to this. Currently tests fail because this breaks the fallbacks for `GraphPPL.Model`, which expects `rand` to forward to its `rand!` method. I'm not certain how we want to define the interface for this `Model`. Co-authored-by: Xianda Sun --- Project.toml | 2 +- src/abstractprobprog.jl | 20 ++++++++++++++++++++ src/graphinfo.jl | 10 +++++++--- test/abstractprobprog.jl | 37 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 test/abstractprobprog.jl diff --git a/Project.toml b/Project.toml index 5d8c9da..9d9a654 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.6.2" +version = "0.6.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index a6773fb..30b8b35 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -1,5 +1,6 @@ using AbstractMCMC using DensityInterface +using Random """ @@ -60,3 +61,22 @@ m = decondition(condition(m, obs)) should hold for generative models `m` and arbitrary `obs`. """ function condition end + + +""" + rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T + +Draw a sample from the joint distribution of the model specified by the probabilistic program. + +The sample will be returned as format specified by `T`. +""" +Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram) +function Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram) + return rand(rng, NamedTuple, model) +end +function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T} + return rand(Random.default_rng(), T, model) +end +function Base.rand(model::AbstractProbabilisticProgram) + return rand(Random.default_rng(), NamedTuple, model) +end diff --git a/src/graphinfo.jl b/src/graphinfo.jl index 4c415db..3bd7fcc 100644 --- a/src/graphinfo.jl +++ b/src/graphinfo.jl @@ -444,9 +444,9 @@ function Random.rand!(m::AbstractPPL.GraphPPL.Model{T}) where T end """ - rand!(rng::AbstractRNG, m::Model) + rand(m::Model) -Draw random samples from the model and mutate the node values. +Draw random samples from the model and return the samples as NamedTuple. # Examples @@ -470,11 +470,15 @@ julia> rand(m) (μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368) ``` """ -function Random.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind} +function Base.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind} m = deepcopy(sm[]) get_model_values(rand!(rng, m)) end +function Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, m::Model) + rand(rng, Random.SamplerTrivial(m)) +end + """ logdensityof(m::Model) diff --git a/test/abstractprobprog.jl b/test/abstractprobprog.jl new file mode 100644 index 0000000..00230be --- /dev/null +++ b/test/abstractprobprog.jl @@ -0,0 +1,37 @@ +using AbstractPPL +using Random +using Test + +mutable struct RandModel <: AbstractProbabilisticProgram + rng + T +end + +function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where {T} + model.rng = rng + model.T = T + return nothing +end + +@testset "AbstractProbabilisticProgram" begin + @testset "rand defaults" begin + model = RandModel(nothing, nothing) + rand(model) + @test model.rng == Random.default_rng() + @test model.T === NamedTuple + rngs = [Random.default_rng(), Random.MersenneTwister(42)] + Ts = [NamedTuple, Dict] + @testset for T in Ts + model = RandModel(nothing, nothing) + rand(T, model) + @test model.rng == Random.default_rng() + @test model.T === T + end + @testset for rng in rngs + model = RandModel(nothing, nothing) + rand(rng, model) + @test model.rng === rng + @test model.T === NamedTuple + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e210b72..3707090 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ using Test @testset "AbstractPPL.jl" begin include("deprecations.jl") include("varname.jl") + include("abstractprobprog.jl") include("graphinfo/graphinfo.jl") @testset "doctests" begin DocMeta.setdocmeta!(