Skip to content

Commit

Permalink
Simplify varinfo constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Oct 17, 2024
1 parent 632ab09 commit cc2e84d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
)
# Sample initial values.
vi = typed_varinfo(rng, model, initialsampler(spl), DefaultContext())
vi = VarInfo(rng, model, initialsampler(spl), DefaultContext())

# Update the parameters if provided.
if initial_params !== nothing
Expand Down
32 changes: 3 additions & 29 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,42 +163,16 @@ function has_varnamedvector(vi::VarInfo)
(vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata)))
end

"""
untyped_varinfo([rng, ]model[, sampler, context])
Return an untyped `VarInfo` instance for the model `model`.
"""
function untyped_varinfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
varinfo = VarInfo(metadata)
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
end
function untyped_varinfo(
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return untyped_varinfo(Random.default_rng(), model, args...)
end

"""
typed_varinfo([rng, ]model[, sampler, context])
Return a typed `VarInfo` instance for the model `model`.
"""
typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...))

function VarInfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
return typed_varinfo(rng, model, sampler, context, metadata)
varinfo = VarInfo(metadata)
untyped_varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
return TypedVarInfo(untyped_varinfo)
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)

Expand Down

0 comments on commit cc2e84d

Please sign in to comment.