Skip to content

Commit 891b46a

Browse files
committed
fixed incorrect call to untyped_varinfo in _determine_varinfo_jet
1 parent c253e9b commit 891b46a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ext/DynamicPPLJETExt.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ end
5151
function DynamicPPL._determine_varinfo_jet(
5252
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_tilde::Bool=true
5353
)
54-
# First we try with the typed varinfo.
55-
varinfo = if DynamicPPL.hassampler(context)
56-
# Don't need to add sampling context for this to work.
57-
DynamicPPL.typed_varinfo(model, context)
54+
# We need a sampling context in the stack to initialize the varinfo.
55+
sampling_context = if DynamicPPL.hassampler(context)
56+
context
5857
else
59-
# Need a sampling context to initialize the varinfo.
6058
DynamicPPL.typed_varinfo(model, DynamicPPL.SamplingContext(context))
6159
end
60+
# First we try with the typed varinfo.
61+
varinfo = DynamicPPL.typed_varinfo(model, sampling_context)
6262
issuccess = true
6363

6464
# Let's make sure that both evaluation and sampling doesn't result in type errors.
@@ -78,7 +78,7 @@ function DynamicPPL._determine_varinfo_jet(
7878
else
7979
# Warn the user that we can't use the type stable one.
8080
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
81-
DynamicPPL.untyped_varinfo(model)
81+
DynamicPPL.untyped_varinfo(model, sampling_context)
8282
end
8383
end
8484

0 commit comments

Comments
 (0)