Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using JET.jl to determine if typed varinfo is okay #728

Open
wants to merge 45 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
361c45e
fixed calls to `to_linked_internal_transform`
torfjelde Nov 27, 2024
545cfab
fixed incorrect call to `acclogp_assume!!`
torfjelde Nov 28, 2024
abd432f
added `determine_varinfo` and an implementation using JET for this
torfjelde Nov 28, 2024
5cd9009
Merge remote-tracking branch 'origin/torfjelde/minor-bugfixes' into t…
torfjelde Nov 28, 2024
d503c3c
made filtering for errors only in the tilde pipeline optional
torfjelde Nov 28, 2024
acb2cb0
formatting
torfjelde Nov 28, 2024
902641f
fixed incorrect comment
torfjelde Nov 28, 2024
d93006b
added test for the branch we were currently imssing
torfjelde Nov 28, 2024
64ff18a
formatting
torfjelde Nov 28, 2024
90c2df0
Merge branch 'master' into torfjelde/minor-bugfixes
torfjelde Nov 28, 2024
a94dbd5
Merge branch 'torfjelde/minor-bugfixes' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
67723d6
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 28, 2024
3d8ad44
renamed `determine_varinfo` to `determine_suitable_varinfo` with
torfjelde Nov 29, 2024
c06b080
removed now-redundant init used with Requires.jl, since this is no
torfjelde Nov 29, 2024
d1a5bab
`determine_suitable_varinfo` now only performs checks using the
torfjelde Nov 29, 2024
5370e55
formatting
torfjelde Nov 29, 2024
dd408ee
updated error hint
torfjelde Nov 29, 2024
c253e9b
added def of `untyped_varinfo` which takes just `model` and `context`
torfjelde Nov 29, 2024
891b46a
fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet`
torfjelde Nov 29, 2024
686ed9f
explicitly call `typed_varinfo` when we want such a thing rather than
torfjelde Nov 29, 2024
d7d785a
`typed_varinfo` and `untyped_varinfo` handles wrapping passed context
torfjelde Nov 29, 2024
dda56ec
use `determine_suitable_varinfo` in `LogDensityFunction` when not con…
torfjelde Nov 29, 2024
46ea18c
formatting
torfjelde Nov 29, 2024
c20ede3
formatting
torfjelde Nov 29, 2024
0b3c36e
fixed a bug in `DynamicPPLJETExt.is_tilde_instance`
torfjelde Nov 29, 2024
f76658a
updated docs
torfjelde Nov 29, 2024
690b017
Update docs/src/internals/varinfo.md
torfjelde Nov 29, 2024
97258f3
added back def of `untyped_varinfo` that shouldn't have been removed +
torfjelde Nov 29, 2024
95bb3a9
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 29, 2024
155ce66
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 29, 2024
4998d08
minor codestyle improvement
torfjelde Nov 29, 2024
5c27677
temporary hack to debug what's happening
torfjelde Nov 30, 2024
99d4df7
more debugging
torfjelde Nov 30, 2024
3b9a9eb
use the `target_modules` kwarg in `report_call` instead of manually
torfjelde Nov 30, 2024
99fb153
formatting
torfjelde Nov 30, 2024
3588597
more debugging
torfjelde Nov 30, 2024
7a302e5
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
040cb54
more debugging
torfjelde Nov 30, 2024
889c370
Merge branch 'master' into torfjelde/determine-varinfo
torfjelde Nov 30, 2024
c98fe49
more debugging: try with new bijectors.jl
torfjelde Nov 30, 2024
123b644
formatting
torfjelde Nov 30, 2024
37fabb0
removed the hacky debugging stuff used for the CI
torfjelde Nov 30, 2024
33e5b98
Merge remote-tracking branch 'origin/torfjelde/determine-varinfo' int…
torfjelde Nov 30, 2024
7ddec2c
removed now-redudant filtering methods since we use JET's own filters
torfjelde Nov 30, 2024
b6b4bff
bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass
torfjelde Nov 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand All @@ -29,6 +30,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -37,6 +39,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
Expand All @@ -55,6 +58,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
Expand All @@ -72,6 +76,7 @@ julia = "1.10"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
4 changes: 1 addition & 3 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
typeof(varinfo_untyped.metadata)
```

Expand Down
52 changes: 52 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
else
JET.report_call(f, argtypes)
end
return length(JET.get_reports(result)) == 0, result
end

function DynamicPPL._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)
issuccess = true

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.is_suitable_varinfo(model, context, varinfo; only_ddpl)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
end
end

end
44 changes: 21 additions & 23 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,30 +196,28 @@
include("debug_utils.jl")
using .DebugUtils

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
# Better error message if users forget to load the AD package
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
exc.f === _determine_varinfo_jet &&
length(argtypes) >= 2 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext
requires_jet |=
exc.f === is_suitable_varinfo &&
length(argtypes) >= 3 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext &&
argtypes[3] <: AbstractVarInfo
if requires_jet
print(

Check warning on line 215 in src/DynamicPPL.jl

View check run for this annotation

Codecov / codecov/patch

src/DynamicPPL.jl#L215

Added line #L215 was not covered by tests
io,
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
)
end
end
end
end

Expand Down
12 changes: 9 additions & 3 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,16 @@ function LogDensityFunction(
return LogDensityFunction(varinfo, model, SamplingContext(sampler, context))
end

function LogDensityFunction(model::Model, context::Union{Nothing,AbstractContext}=nothing)
# Determine the suitable varinfo for the given model and context.
varinfo = determine_suitable_varinfo(
model, context === nothing ? leafcontext(model.context) : context
)
return LogDensityFunction(varinfo, model, context)
end

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::Union{Nothing,AbstractContext}=nothing,
model::Model, varinfo::AbstractVarInfo, context::Union{Nothing,AbstractContext}=nothing
)
return LogDensityFunction(varinfo, model, context)
end
Expand Down
57 changes: 57 additions & 0 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,60 @@
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
end
end

"""
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)

Check if the `model` supports evaluation using the provided `context` and `varinfo`.

!!! warning
Loading JET.jl is required before calling this function.

# Arguments
- `model`: The model to to verify the support for.
- `context`: The context to use for the model evaluation.
- `varinfo`: The varinfo to verify the support for.

# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.

# Returns
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
- `report`: The result of `report_call` from JET.jl.
"""
function is_suitable_varinfo end

# Internal hook for JET.jl to overload.
function _determine_varinfo_jet end

"""
determine_suitable_varinfo(model[, context]; verbose::Bool=false, only_ddpl::Bool=true)

Return a suitable varinfo for the given `model`.

See also: [`DynamicPPL.is_suitable_varinfo`](@ref).

!!! warning
For full functionality, this requires JET.jl to be loaded.
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.

# Arguments
- `model`: The model for which to determine the varinfo.
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.

# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
"""
function determine_suitable_varinfo(
model::Model, context::AbstractContext=SamplingContext(); only_ddpl::Bool=true
)
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
_determine_varinfo_jet(model, context; only_ddpl)
else
# Warn the user.
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."

Check warning on line 262 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L262

Added line #L262 was not covered by tests
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
typed_varinfo(model, context)

Check warning on line 264 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L264

Added line #L264 was not covered by tests
end
end
7 changes: 4 additions & 3 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ function default_varinfo(
context::AbstractContext,
)
init_sampler = initialsampler(sampler)
return VarInfo(rng, model, init_sampler, context)
varinfo = determine_suitable_varinfo(model, SamplingContext(rng, init_sampler, context))
return varinfo
end

function AbstractMCMC.sample(
Expand Down Expand Up @@ -126,7 +127,7 @@ By default, `data` is returned.
loadstate(data) = data

"""
default_chaintype(sampler)
default_chain_type(sampler)

Default type of the chain of posterior samples from `sampler`.
"""
Expand All @@ -140,7 +141,7 @@ Return the sampler that is used for generating the initial parameters when sampl

By default, it returns an instance of [`SampleFromPrior`](@ref).
"""
initialsampler(spl::Sampler) = SampleFromPrior()
initialsampler(spl) = SampleFromPrior()

function set_values!!(
varinfo::AbstractVarInfo,
Expand Down
13 changes: 11 additions & 2 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,23 @@ function untyped_varinfo(
context::AbstractContext=DefaultContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
varinfo = VarInfo(metadata)
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
return untyped_varinfo(model, SamplingContext(rng, sampler, context), metadata)
end
function untyped_varinfo(
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return untyped_varinfo(Random.default_rng(), model, args...)
end
function untyped_varinfo(
model::Model,
context::AbstractContext,
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
varinfo = VarInfo(metadata)
return last(
evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context))
)
end

"""
typed_varinfo([rng, ]model[, sampler, context])
Expand Down
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand All @@ -31,7 +32,7 @@ ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14, 0.15"
Bijectors = "0.15.1"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
Expand Down
75 changes: 75 additions & 0 deletions test/ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
@testset "DynamicPPLJETExt.jl" begin
@testset "determine_suitable_varinfo" begin
@model function demo1()
x ~ Bernoulli()
if x
y ~ Normal()
else
z ~ Normal()
end
end
model = demo1()
@test DynamicPPL.determine_suitable_varinfo(model) isa DynamicPPL.UntypedVarInfo

@model demo2() = x ~ Normal()
@test DynamicPPL.determine_suitable_varinfo(demo2()) isa DynamicPPL.TypedVarInfo

@model function demo3()
# Just making sure that nothing strange happens when type inference fails.
x = Vector(undef, 1)
x[1] ~ Bernoulli()
if x[1]
y ~ Normal()
else
z ~ Normal()
end
end
@test DynamicPPL.determine_suitable_varinfo(demo3()) isa DynamicPPL.UntypedVarInfo

# Evaluation works (and it would even do so in practice), but sampling
# fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`.
@model function demo4()
x ~ Bernoulli()
if x
y ~ Normal()
else
y ~ Cauchy() # different distibution, but same transformation
end
end
@test DynamicPPL.determine_suitable_varinfo(demo4()) isa DynamicPPL.UntypedVarInfo

# In this model, the type error occurs in the user code rather than in DynamicPPL.
@model function demo5()
x ~ Normal()
xs = Any[]
push!(xs, x)
# `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the
# correct `zero` method. As a result, this code will run, but JET will raise this is an issue.
return sum(xs)
end
# Should pass if we're only checking the tilde statements.
@test DynamicPPL.determine_suitable_varinfo(demo5()) isa DynamicPPL.TypedVarInfo
# Should fail if we're including errors in the model body.
@test DynamicPPL.determine_suitable_varinfo(demo5(); only_ddpl=false) isa
DynamicPPL.UntypedVarInfo
end

@testset "demo models" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
# Use debug logging below.
varinfo = DynamicPPL.DynamicPPL.determine_suitable_varinfo(model)
# They should all result in typed.
@test varinfo isa DynamicPPL.TypedVarInfo
# But let's also make sure that they're not lying.
f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo
)
JET.test_call(f_eval, argtypes_eval)

f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, DynamicPPL.SamplingContext()
)
JET.test_call(f_sample, argtypes_sample)
end
end
end
Loading