Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using JET.jl to determine if typed varinfo is okay #728

Merged
merged 65 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
361c45e
fixed calls to `to_linked_internal_transform`
torfjelde Nov 27, 2024
545cfab
fixed incorrect call to `acclogp_assume!!`
torfjelde Nov 28, 2024
abd432f
added `determine_varinfo` and an implementation using JET for this
torfjelde Nov 28, 2024
5cd9009
Merge remote-tracking branch 'origin/torfjelde/minor-bugfixes' into t…
torfjelde Nov 28, 2024
d503c3c
made filtering for errors only in the tilde pipeline optional
torfjelde Nov 28, 2024
acb2cb0
formatting
torfjelde Nov 28, 2024
902641f
fixed incorrect comment
torfjelde Nov 28, 2024
d93006b
added test for the branch we were currently imssing
torfjelde Nov 28, 2024
64ff18a
formatting
torfjelde Nov 28, 2024
90c2df0
Merge branch 'master' into torfjelde/minor-bugfixes
torfjelde Nov 28, 2024
a94dbd5
Merge branch 'torfjelde/minor-bugfixes' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
67723d6
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
3d8ad44
renamed `determine_varinfo` to `determine_suitable_varinfo` with
torfjelde Nov 29, 2024
c06b080
removed now-redundant init used with Requires.jl, since this is no
torfjelde Nov 29, 2024
d1a5bab
`determine_suitable_varinfo` now only performs checks using the
torfjelde Nov 29, 2024
5370e55
formatting
torfjelde Nov 29, 2024
dd408ee
updated error hint
torfjelde Nov 29, 2024
c253e9b
added def of `untyped_varinfo` which takes just `model` and `context`
torfjelde Nov 29, 2024
891b46a
fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet`
torfjelde Nov 29, 2024
686ed9f
explicitly call `typed_varinfo` when we want such a thing rather than
torfjelde Nov 29, 2024
d7d785a
`typed_varinfo` and `untyped_varinfo` handles wrapping passed context
torfjelde Nov 29, 2024
dda56ec
use `determine_suitable_varinfo` in `LogDensityFunction` when not con…
torfjelde Nov 29, 2024
46ea18c
formatting
torfjelde Nov 29, 2024
c20ede3
formatting
torfjelde Nov 29, 2024
0b3c36e
fixed a bug in `DynamicPPLJETExt.is_tilde_instance`
torfjelde Nov 29, 2024
f76658a
updated docs
torfjelde Nov 29, 2024
690b017
Update docs/src/internals/varinfo.md
torfjelde Nov 29, 2024
97258f3
added back def of `untyped_varinfo` that shouldn't have been removed +
torfjelde Nov 29, 2024
95bb3a9
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 29, 2024
155ce66
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 29, 2024
4998d08
minor codestyle improvement
torfjelde Nov 29, 2024
5c27677
temporary hack to debug what's happening
torfjelde Nov 30, 2024
99d4df7
more debugging
torfjelde Nov 30, 2024
3b9a9eb
use the `target_modules` kwarg in `report_call` instead of manually
torfjelde Nov 30, 2024
99fb153
formatting
torfjelde Nov 30, 2024
3588597
more debugging
torfjelde Nov 30, 2024
7a302e5
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
040cb54
more debugging
torfjelde Nov 30, 2024
889c370
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 30, 2024
c98fe49
more debugging: try with new bijectors.jl
torfjelde Nov 30, 2024
123b644
formatting
torfjelde Nov 30, 2024
37fabb0
removed the hacky debugging stuff used for the CI
torfjelde Nov 30, 2024
33e5b98
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
7ddec2c
removed now-redudant filtering methods since we use JET's own filters
torfjelde Nov 30, 2024
b6b4bff
bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass
torfjelde Nov 30, 2024
e07ecdb
moved the JET.jl-dependent experimental `determine_varinfo` into a
torfjelde Dec 3, 2024
8ba8f82
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Dec 3, 2024
9ec1556
forgot to add the experimenta.jl file in previous commit
torfjelde Dec 3, 2024
599488b
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Dec 3, 2024
fa155a4
reverted changes to `default_varinfo` and `LogDensityFunction`
torfjelde Dec 3, 2024
8496968
added a bunch of docs for introduced and existing methods
torfjelde Dec 4, 2024
fd82871
added doctests to `determine_suitable_varinfo`
torfjelde Dec 4, 2024
bb87ba0
added JET.jl as a dep to docs
torfjelde Dec 4, 2024
62c5cd1
fixed referencing in docs
torfjelde Dec 4, 2024
55dc91e
fixed docstring
torfjelde Dec 4, 2024
ae51778
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Dec 4, 2024
a692ec3
fixed doctest
torfjelde Dec 5, 2024
d5eb404
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Dec 5, 2024
17b6ec9
Update Project.toml
torfjelde Dec 5, 2024
bfa88b2
applied suggestions from @mhauru
torfjelde Dec 5, 2024
82578cf
fixed doctests
torfjelde Dec 5, 2024
3aad34f
finally fixed doctests
torfjelde Dec 6, 2024
da3eefe
removed unnecessary `typed_varinfo` and `untyped_varinfo` methods
torfjelde Dec 6, 2024
325c5f9
added filter to ignore source of warnings in doctest
torfjelde Dec 6, 2024
4a17e82
Merge branch 'master' into torfjelde/determine-varinfo
Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
mhauru marked this conversation as resolved.
Show resolved Hide resolved
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand All @@ -29,6 +30,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -37,6 +39,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
Expand All @@ -55,6 +58,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
Expand All @@ -72,6 +76,7 @@ julia = "1.10"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
90 changes: 90 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

"""
is_tilde_instance(x)

Return `true` if `x` is a method instance of a tilde function, otherwise `false`.
"""
is_tilde_instance(x) = false
is_tilde_instance(frame::JET.VirtualFrame) = is_tilde_instance(frame.linfo)
is_tilde_instance(mi::Core.MethodInstance) = is_tilde_instance(mi.specTypes.parameters[1])
is_tilde_instance(::Type{typeof(DynamicPPL.tilde_assume!!)}) = true
is_tilde_instance(::Type{typeof(DynamicPPL.tilde_observe!!)}) = true
is_tilde_instance(::Type{typeof(DynamicPPL.dot_tilde_assume!!)}) = true
is_tilde_instance(::Type{typeof(DynamicPPL.dot_tilde_observe!!)}) = true

"""
report_has_error_in_tilde(report)

Return `true` if the given error `report` contains a tilde function in its frames, otherwise `false`.

This is used to filter out reports that occur outside of the tilde pipeline, in an attempt to avoid
warning the user about DynamicPPL doing something wrong when it is in fact an issue with the user's code.
"""
function report_has_error_in_tilde(report)
frames = report.vst
return any(is_tilde_instance, frames)
end

function DynamicPPL.determine_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext();
verbose::Bool=false,
only_tilde::Bool=true,
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model)
issuccess = true
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

# Let's make sure that both evaluation and sampling doesn't result in type errors.
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
result_eval = JET.report_call(f_eval, argtypes_eval)
reports_eval = JET.get_reports(result_eval)
if only_tilde
reports_eval = filter(report_has_error_in_tilde, reports_eval)
end
# If we get reports => we had issues so we use the untyped varinfo.
issuccess &= length(reports_eval) == 0
if issuccess
# Evaluation succeeded, let's try sampling.
f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, DynamicPPL.SamplingContext(context)
)
result_sample = JET.report_call(f_sample, argtypes_sample)
reports_sample = JET.get_reports(result_sample)
if only_tilde
reports_sample = filter(report_has_error_in_tilde, reports_sample)
end
# If we get reports => we had issues so we use the untyped varinfo.
issuccess &= length(reports_sample) == 0
if !issuccess && verbose
# Show the user the issues.
@warn "Sampling with typed varinfo failed with the following issues:"
for report in reports_sample
@warn report
end
end
elseif verbose
# Show the user the issues.
@warn "Evaluaton with typed varinfo failed with the following issues:"
for report in reports_eval
@warn report
end
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model)
end
end

end
6 changes: 3 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
f = to_linked_internal_transform(vi, dist)
f = to_linked_internal_transform(vi, vn, dist)
push!!(vi, vn, f(r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
Expand Down Expand Up @@ -500,7 +500,7 @@ end
# HACK: These methods are only used in the `get_and_set_val!` methods below.
# FIXME: Remove these.
function _link_broadcast_new(vi, vn, dist, r)
b = to_linked_internal_transform(vi, dist)
b = to_linked_internal_transform(vi, vn, dist)
return b(r)
end

Expand Down Expand Up @@ -591,7 +591,7 @@ function get_and_set_val!(
push!!.((vi,), vns, _link_broadcast_new.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
# FIXME: This is not great.
acclogp_assume!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
acclogp!!(vi, sum(logabsdetjac.(link_transform.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
Expand Down
20 changes: 20 additions & 0 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,23 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
end
end

"""
determine_varinfo(model[, context]; verbose::Bool=false)

Return a suitable varinfo for the given `model`.

This method uses JET.jl in an attempt to determine if the model is compatible with typed varinfo.
If it is, a typed varinfo is returned, otherwise an untyped varinfo is returned.

!!! warning
This requires loading JET.jl before calling this function.

# Arguments
- `model`: The model for which to determine the varinfo.
- `context`: The context to use for the evaluation and sampling. Default: `DefaultContext()`.

# Keyword Arguments
- `verbose`: If `true`, the user will be warned about the issues that caused the fallback to untyped varinfo.
"""
function determine_varinfo end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down
78 changes: 78 additions & 0 deletions test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
@testset "DynamicPPLJETExt.jl" begin
@testset "determine_varinfo" begin
@model function demo1()
x ~ Bernoulli()
if x
y ~ Normal()
else
z ~ Normal()
end
end
model = demo1()
@test DynamicPPL.determine_varinfo(model; verbose=true) isa
DynamicPPL.UntypedVarInfo

@model demo2() = x ~ Normal()
@test DynamicPPL.determine_varinfo(demo2()) isa DynamicPPL.TypedVarInfo

@model function demo3()
# Just making sure that nothing strange happens when type inference fails.
x = Vector(undef, 1)
x[1] ~ Bernoulli()
if x[1]
y ~ Normal()
else
z ~ Normal()
end
end
@test DynamicPPL.determine_varinfo(demo3(); verbose=true) isa
DynamicPPL.UntypedVarInfo

# Evaluation works (and it would even do so in practice), but sampling
# fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
@model function demo4()
x ~ Bernoulli()
if x
y ~ Normal()
else
y ~ Cauchy() # different distibution, but same transformation => should work
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
end
@test DynamicPPL.determine_varinfo(demo4(); verbose=true) isa
DynamicPPL.UntypedVarInfo

# In this model, the type error occurs in the user code rather than in DynamicPPL.
@model function demo5()
x ~ Normal()
xs = Any[]
push!(xs, x)
# `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the
# correct `zero` method. As a result, this code will run, but JET will raise this is an issue.
return sum(xs)
end
# Should pass if we're only checking the tilde statements.
@test DynamicPPL.determine_varinfo(demo5(); verbose=true) isa
DynamicPPL.TypedVarInfo
# Should fail if we're including errors in the model body.
@test DynamicPPL.determine_varinfo(demo5(); verbose=true, only_tilde=false) isa
DynamicPPL.UntypedVarInfo
end

@testset "demo models" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
varinfo = DynamicPPL.DynamicPPL.determine_varinfo(model)
# They should all result in typed.
@test varinfo isa DynamicPPL.TypedVarInfo
# But let's also make sure that they're not lying.
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo
)
JET.test_call(f_eval, argtypes_eval)

f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ using Test
using Distributions
using LinearAlgebra # Diagonal

using JET: JET

using Combinatorics: combinations

using DynamicPPL: getargs_dottilde, getargs_tilde, Selector
Expand Down Expand Up @@ -71,6 +73,7 @@ include("test_util.jl")

@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
end

@testset "ad" begin
Expand Down
Loading