diff --git a/Project.toml b/Project.toml index 5dde3a427..ebc70b5ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30.4" +version = "0.30.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -14,6 +14,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" diff --git a/docs/src/api.md b/docs/src/api.md index 638f6f3ee..d7531af0f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -212,6 +212,19 @@ And some which might be useful to determine certain properties of the model base DynamicPPL.has_static_constraints ``` +For determining whether one might have type instabilities in the model, the following can be useful + +```@docs +DynamicPPL.DebugUtils.model_warntype +DynamicPPL.DebugUtils.model_typed +``` + +Interally, the type-checking methods make use of the following method for construction of the call with the argument types: + +```@docs +DynamicPPL.DebugUtils.gen_evaluator_call_with_types +``` + ## Advanced ### Variable names diff --git a/src/debug_utils.jl b/src/debug_utils.jl index a614508bf..dcd3fcc37 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -5,6 +5,7 @@ using ..DynamicPPL: broadcast_safe, AbstractContext, childcontext using Random: Random using Accessors: Accessors +using InteractiveUtils: InteractiveUtils using DocStringExtensions using Distributions @@ -678,4 +679,83 @@ function has_static_constraints( return all_the_same(transforms) end +""" + gen_evaluator_call_with_types(model[, varinfo, context]) + +Generate the evaluator call and the types of the arguments. + +# Arguments +- `model::Model`: The model whose evaluator is of interest. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. +- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). + +# Returns +A 2-tuple with the following elements: +- `f`: This is either `model.f` or `Core.kwcall`, depending on whether + the model has keyword arguments. +- `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator. +""" +function gen_evaluator_call_with_types( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=DefaultContext(), +) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + return if isempty(kwargs) + (model.f, Base.typesof(args...)) + else + (Core.kwcall, Tuple{typeof(kwargs),Core.Typeof(model.f),map(Core.Typeof, args)...}) + end +end + +""" + model_warntype(model[, varinfo, context]; optimize=true) + +Check the type stability of the model's evaluator, warning about any potential issues. + +This simply calls `@code_warntype` on the model's evaluator, filling in internal arguments where needed. + +# Arguments +- `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. +- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). + +# Keyword Arguments +- `optimize::Bool`: Whether to generate optimized code. Default: `false`. +""" +function model_warntype( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=DefaultContext(); + optimize::Bool=false, +) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize) +end + +""" + model_typed(model[, varinfo, context]; optimize=true) + +Return the type inference for the model's evaluator. + +This simply calls `@code_typed` on the model's evaluator, filling in internal arguments where needed. + +# Arguments +- `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. +- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). + +# Keyword Arguments +- `optimize::Bool`: Whether to generate optimized code. Default: `true`. +""" +function model_typed( + model::Model, + varinfo::AbstractVarInfo=VarInfo(model), + context::AbstractContext=DefaultContext(); + optimize::Bool=true, +) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize)) +end + end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index b1897aa9b..50bb5d4be 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -186,4 +186,18 @@ @test check_model(model; error_on_failure=true) end end + + @testset "model_warntype & model_codetyped" begin + @model demo_without_kwargs(x) = y ~ Normal(x, 1) + @model demo_with_kwargs(x; z=1) = y ~ Normal(x, z) + + for model in [demo_without_kwargs(1.0), demo_with_kwargs(1.0)] + codeinfo, retype = DynamicPPL.DebugUtils.model_typed(model) + @test codeinfo isa Core.CodeInfo + @test retype <: Tuple + + # Just make sure the following is runnable. + @test (DynamicPPL.DebugUtils.model_warntype(model); true) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 099c96f78..a832a0f08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,8 @@ include("test_util.jl") include("pointwise_logdensities.jl") include("lkj.jl") + + include("debug_utils.jl") end @testset "compat" begin