diff --git a/src/sampler.jl b/src/sampler.jl index b9f7b8c41..1bbcbcb9a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -87,7 +87,7 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... ) # Sample initial values. - vi = typed_varinfo(rng, model, initialsampler(spl), DefaultContext()) + vi = VarInfo(rng, model, initialsampler(spl), DefaultContext()) # Update the parameters if provided. if initial_params !== nothing diff --git a/src/varinfo.jl b/src/varinfo.jl index 11fe08d0f..458a55362 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -163,34 +163,6 @@ function has_varnamedvector(vi::VarInfo) (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end -""" - untyped_varinfo([rng, ]model[, sampler, context]) - -Return an untyped `VarInfo` instance for the model `model`. -""" -function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), - metadata::Union{Metadata,VarNamedVector}=Metadata(), -) - varinfo = VarInfo(metadata) - return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) -end -function untyped_varinfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... -) - return untyped_varinfo(Random.default_rng(), model, args...) -end - -""" - typed_varinfo([rng, ]model[, sampler, context]) - -Return a typed `VarInfo` instance for the model `model`. -""" -typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) - function VarInfo( rng::Random.AbstractRNG, model::Model, @@ -198,7 +170,9 @@ function VarInfo( context::AbstractContext=DefaultContext(), metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(rng, model, sampler, context, metadata) + varinfo = VarInfo(metadata) + untyped_varinfo = last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) + return TypedVarInfo(untyped_varinfo) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)