Skip to content

Commit

Permalink
typed_varinfo and untyped_varinfo handles wrapping passed context
Browse files Browse the repository at this point in the history
in sampling context now so no need to handle this explicitly elsewhere
  • Loading branch information
torfjelde committed Nov 29, 2024
1 parent 686ed9f commit d7d785a
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,8 @@ end
function DynamicPPL._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_tilde::Bool=true
)
# We need a sampling context in the stack to initialize the varinfo.
sampling_context = if DynamicPPL.hassampler(context)
context
else
DynamicPPL.typed_varinfo(model, DynamicPPL.SamplingContext(context))
end
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, sampling_context)
varinfo = DynamicPPL.typed_varinfo(model, context)
issuccess = true

# Let's make sure that both evaluation and sampling doesn't result in type errors.
Expand All @@ -78,7 +72,7 @@ function DynamicPPL._determine_varinfo_jet(
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, sampling_context)
DynamicPPL.untyped_varinfo(model, context)
end
end

Expand Down

0 comments on commit d7d785a

Please sign in to comment.