From d06d34cd2529fefe9b18849bee9642133e84bc8d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 28 Jul 2025 14:47:11 -0600 Subject: [PATCH 1/5] Test against Enzyme --- docs/src/api.md | 3 ++- src/Turing.jl | 3 ++- test/ad.jl | 12 ++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 0b8351eb3..604718b0e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -93,9 +93,10 @@ See the [AD guide](https://turinglang.org/docs/tutorials/docs-10-using-turing-au | Exported symbol | Documentation | Description | |:----------------- |:------------------------------------ |:---------------------- | +| `AutoEnzyme` | [`ADTypes.AutoEnzyme`](@extref) | Enzyme.jl backend | | `AutoForwardDiff` | [`ADTypes.AutoForwardDiff`](@extref) | ForwardDiff.jl backend | -| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | | `AutoMooncake` | [`ADTypes.AutoMooncake`](@extref) | Mooncake.jl backend | +| `AutoReverseDiff` | [`ADTypes.AutoReverseDiff`](@extref) | ReverseDiff.jl backend | ### Debugging diff --git a/src/Turing.jl b/src/Turing.jl index 0cdbe2458..9a7da6872 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -23,7 +23,7 @@ using Printf: Printf using Random: Random using LinearAlgebra: I -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake +using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoMooncake, AutoEnzyme const DEFAULT_ADTYPE = ADTypes.AutoForwardDiff() @@ -121,6 +121,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoMooncake, + AutoEnzyme, # Debugging - Turing setprogress!, # Distributions diff --git a/test/ad.jl b/test/ad.jl index dcfe4ef46..1fc8245ec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -20,6 +20,14 @@ if INCLUDE_MOONCAKE using Mooncake: Mooncake end +const INCLUDE_ENZYME = !IS_PRERELEASE + +if INCLUDE_ENZYME + import Pkg + Pkg.add("Enzyme") + using Enzyme: Enzyme +end + """Element types that are always valid for a VarInfo regardless of ADType.""" const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) @@ -191,6 +199,10 @@ ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)] if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end +if INCLUDE_ENZYME + push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward))) + push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse))) +end # Check that ADTypeCheckContext itself works as expected. @testset "ADTypeCheckContext" begin From 035485e595613eaee193b751bc9913cad0492061 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Jul 2025 16:29:43 -0500 Subject: [PATCH 2/5] Update ad.jl --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 1fc8245ec..c717fb46b 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -200,8 +200,8 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Forward))) - push!(ADTYPES, AutoEnzyme(; mode = set_runtime_activity(Reverse))) + push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward))) + push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))) end # Check that ADTypeCheckContext itself works as expected. From e77976565e647b182c2ba75416f32c513a66a2ed Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 28 Jul 2025 17:32:08 -0500 Subject: [PATCH 3/5] Update ad.jl --- test/ad.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/ad.jl b/test/ad.jl index c717fb46b..7d5f8afe3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -47,6 +47,9 @@ eltypes_by_adtype = Dict( if INCLUDE_MOONCAKE eltypes_by_adtype[AutoMooncake] = (Mooncake.CoDual,) end +if INCLUDE_ENZYME + eltypes_by_adtype[AutoEnzyme] = () +end """ AbstractWrongADBackendError From d221a170e988eecd36eff9ef30000c0030bbac5c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 Aug 2025 23:01:17 +0100 Subject: [PATCH 4/5] Fix dictionary type --- test/ad.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 7d5f8afe3..bd70eeaec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -32,7 +32,7 @@ end const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) """A dictionary mapping ADTypes to the element types they use.""" -eltypes_by_adtype = Dict( +eltypes_by_adtype = Dict{Type,Tuple}( AutoForwardDiff => (ForwardDiff.Dual,), AutoReverseDiff => ( ReverseDiff.TrackedArray, @@ -203,8 +203,8 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward))) - push!(ADTYPES, AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse))) + push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))) + push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))) end # Check that ADTypeCheckContext itself works as expected. From 17371061c3586bb5f20ef7b644d039af62a5f20b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 Aug 2025 23:02:06 +0100 Subject: [PATCH 5/5] mark function as const for good measure --- test/ad.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index bd70eeaec..4422135c0 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -203,8 +203,20 @@ if INCLUDE_MOONCAKE push!(ADTYPES, AutoMooncake(; config=nothing)) end if INCLUDE_ENZYME - push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Forward))) - push!(ADTYPES, AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Forward), + function_annotation=Enzyme.Const, + ), + ) + push!( + ADTYPES, + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) end # Check that ADTypeCheckContext itself works as expected.