diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 0d25447dd..5e009b53e 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -31,6 +31,7 @@ jobs:
           - SDE3
         version:
           - '1'
+          - '1.11'
           - 'lts'
     steps:
       - uses: actions/checkout@v4
diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl
index 7a003b4d3..16e57ee24 100644
--- a/src/SciMLSensitivity.jl
+++ b/src/SciMLSensitivity.jl
@@ -1,5 +1,8 @@
 module SciMLSensitivity
 
+# Enzyme is not compatible with Julia 1.12+
+const ENZYME_ENABLED = VERSION < v"1.12"
+
 using ADTypes: ADTypes, AutoEnzyme, AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff,
                AutoTracker, AutoZygote
 using Accessors: @reset
@@ -45,7 +48,9 @@ using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit,
 # AD Backends
 using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent,
                       AbstractThunk, AbstractTangent
-using Enzyme: Enzyme
+@static if ENZYME_ENABLED
+    using Enzyme: Enzyme
+end
 using FiniteDiff: FiniteDiff
 using ForwardDiff: ForwardDiff
 using Tracker: Tracker, TrackedArray
@@ -86,6 +91,8 @@ include("sde_tools.jl")
 
 export extract_local_sensitivities
 
+export ENZYME_ENABLED
+
 export ODEForwardSensitivityFunction, ODEForwardSensitivityProblem, SensitivityFunction,
        ODEAdjointProblem, AdjointSensitivityIntegrand,
        SDEAdjointProblem, RODEAdjointProblem, SensitivityAlg,
diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl
index a2af3e485..7b35d61c9 100644
--- a/src/concrete_solve.jl
+++ b/src/concrete_solve.jl
@@ -1285,105 +1285,113 @@ function DiffEqBase._concrete_solve_adjoint(
         p)
 end
 
-function DiffEqBase._concrete_solve_adjoint(
-        prob::Union{SciMLBase.AbstractDiscreteProblem,
-            SciMLBase.AbstractODEProblem,
-            SciMLBase.AbstractDAEProblem,
-            SciMLBase.AbstractDDEProblem,
-            SciMLBase.AbstractSDEProblem,
-            SciMLBase.AbstractSDDEProblem,
-            SciMLBase.AbstractRODEProblem
-        },
-        alg, sensealg::EnzymeAdjoint,
-        u0, p, originator::SciMLBase.ADOriginator,
-        args...; kwargs...)
-    kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
-    du0 = Enzyme.make_zero(u0)
-    dp = Enzyme.make_zero(p)
-    mode = sensealg.mode
-
-    # Force no FunctionWrappers for Enzyme
-    _prob = remake(prob, f = f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)) )
-
-    diff_func = (u0,
-        p) -> solve(_prob, alg, args...; u0 = u0, p = p,
-        sensealg = SensitivityADPassThrough(),
-        kwargs_filtered...)
-
-    splitmode = if mode isa Enzyme.ForwardMode
-        error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
-    elseif mode === nothing || mode isa Enzyme.ReverseMode
-        Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
-    end
+# Only define the EnzymeAdjoint methods when Enzyme itself was loaded (ENZYME_ENABLED).
+@static if ENZYME_ENABLED
+    function DiffEqBase._concrete_solve_adjoint(
+            prob::Union{SciMLBase.AbstractDiscreteProblem,
+                SciMLBase.AbstractODEProblem,
+                SciMLBase.AbstractDAEProblem,
+                SciMLBase.AbstractDDEProblem,
+                SciMLBase.AbstractSDEProblem,
+                SciMLBase.AbstractSDDEProblem,
+                SciMLBase.AbstractRODEProblem
+            },
+            alg, sensealg::EnzymeAdjoint,
+            u0, p, originator::SciMLBase.ADOriginator,
+            args...; kwargs...)
+        kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
+        du0 = Enzyme.make_zero(u0)
+        dp = Enzyme.make_zero(p)
+        mode = sensealg.mode
 
-    forward,
-    reverse = Enzyme.autodiff_thunk(
-        splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
-        Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
-    tape, result,
-    shadow_result = forward(
-        Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))
-
-    function enzyme_sensitivity_backpass(Δ)
-        if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
-            for (x, y) in zip(shadow_result.u, Δ.u)
-                x .= y
-            end
-        else
-            error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
+        # Force no FunctionWrappers for Enzyme
+        _prob = remake(prob, f = ODEFunction{isinplace(prob), SciMLBase.FullSpecialize}(unwrapped_f(prob.f)))
+
+        diff_func = (u0,
+            p) -> solve(_prob, alg, args...; u0 = u0, p = p,
+            sensealg = SensitivityADPassThrough(),
+            kwargs_filtered...)
+
+        splitmode = if mode isa Enzyme.ForwardMode
+            error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
+        elseif mode === nothing || mode isa Enzyme.ReverseMode
+            Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
         end
-        reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
-        if originator isa SciMLBase.TrackerOriginator ||
-           originator isa SciMLBase.ReverseDiffOriginator
-            (NoTangent(), NoTangent(), du0, dp, NoTangent(),
-                ntuple(_ -> NoTangent(), length(args))...)
-        else
-            (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
-                ntuple(_ -> NoTangent(), length(args))...)
+
+        forward,
+        reverse = Enzyme.autodiff_thunk(
+            splitmode, Enzyme.Const{typeof(diff_func)}, Enzyme.Duplicated,
+            Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
+        tape, result,
+        shadow_result = forward(
+            Enzyme.Const(diff_func), Enzyme.Duplicated(copy(u0), du0), Enzyme.Duplicated(copy(p), dp))
+
+        function enzyme_sensitivity_backpass(Δ)
+            if (Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray)
+                for (x, y) in zip(shadow_result.u, Δ.u)
+                    x .= y
+                end
+            else
+                error("typeof(Δ) = $(typeof(Δ)) is not currently handled in EnzymeAdjoint. Please open an issue with an MWE to add support")
+            end
+            reverse(Enzyme.Const(diff_func), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp), tape)
+            if originator isa SciMLBase.TrackerOriginator ||
+               originator isa SciMLBase.ReverseDiffOriginator
+                (NoTangent(), NoTangent(), du0, dp, NoTangent(),
+                    ntuple(_ -> NoTangent(), length(args))...)
+            else
+                (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
+                    ntuple(_ -> NoTangent(), length(args))...)
+            end
         end
+        result, enzyme_sensitivity_backpass
     end
-    result, enzyme_sensitivity_backpass
-end
 
-# NOTE: This is needed to prevent a method ambiguity error
-function DiffEqBase._concrete_solve_adjoint(
-        prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
-        u0, p, originator::SciMLBase.ADOriginator,
-        args...; kwargs...)
-    kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
-
-    du0 = make_zero(u0)
-    dp = make_zero(p)
-    mode = sensealg.mode
+    # NOTE: This is needed to prevent a method ambiguity error
+    function DiffEqBase._concrete_solve_adjoint(
+            prob::AbstractNonlinearProblem, alg, sensealg::EnzymeAdjoint,
+            u0, p, originator::SciMLBase.ADOriginator,
+            args...; kwargs...)
+        kwargs_filtered = NamedTuple(filter(x -> x[1] != :sensealg, kwargs))
 
-    f = (u0,
-        p) -> solve(prob, alg, args...; u0 = u0, p = p,
-        sensealg = SensitivityADPassThrough(),
-        kwargs_filtered...)
+        du0 = Enzyme.make_zero(u0)
+        dp = Enzyme.make_zero(p)
+        mode = sensealg.mode
 
-    splitmode = if mode isa Forward
-        error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
-    elseif mode === nothing || mode === Reverse
-        ReverseSplitWithPrimal
-    end
+        f = (u0,
+            p) -> solve(prob, alg, args...; u0 = u0, p = p,
+            sensealg = SensitivityADPassThrough(),
+            kwargs_filtered...)
 
-    forward,
-    reverse = autodiff_thunk(splitmode, Const{typeof(f)}, Duplicated,
-        Duplicated{typeof(u0)}, Duplicated{typeof(p)})
-    tape, result, shadow_result = forward(Const(f), Duplicated(u0, du0), Duplicated(p, dp))
+        splitmode = if mode isa Enzyme.ForwardMode
+            error("EnzymeAdjoint currently only allows mode=Reverse. File an issue if this is necessary.")
+        elseif mode === nothing || mode isa Enzyme.ReverseMode
+            Enzyme.ReverseSplitWithPrimal
+        end
 
-    function enzyme_sensitivity_backpass(Δ)
-        reverse(Const(f), Duplicated(u0, du0), Duplicated(p, dp), Δ, tape)
-        if originator isa SciMLBase.TrackerOriginator ||
-           originator isa SciMLBase.ReverseDiffOriginator
-            (NoTangent(), NoTangent(), du0, dp, NoTangent(),
-                ntuple(_ -> NoTangent(), length(args))...)
-        else
-            (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
-                ntuple(_ -> NoTangent(), length(args))...)
+        forward,
+        reverse = Enzyme.autodiff_thunk(
+            splitmode, Enzyme.Const{typeof(f)}, Enzyme.Duplicated,
+            Enzyme.Duplicated{typeof(u0)}, Enzyme.Duplicated{typeof(p)})
+        tape, result, shadow_result = forward(
+            Enzyme.Const(f), Enzyme.Duplicated(u0, du0), Enzyme.Duplicated(p, dp))
+
+        function enzyme_sensitivity_backpass(Δ)
+            # NOTE(review): the return is annotated `Duplicated`; confirm against the Enzyme
+            # `autodiff_thunk` docs that the split reverse thunk accepts `Δ` as an argument.
+            reverse(Enzyme.Const(f), Enzyme.Duplicated(u0, du0),
+                Enzyme.Duplicated(p, dp), Δ, tape)
+            if originator isa SciMLBase.TrackerOriginator ||
+               originator isa SciMLBase.ReverseDiffOriginator
+                (NoTangent(), NoTangent(), du0, dp, NoTangent(),
+                    ntuple(_ -> NoTangent(), length(args))...)
+            else
+                (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(),
+                    ntuple(_ -> NoTangent(), length(args))...)
+            end
         end
+        result, enzyme_sensitivity_backpass
     end
-    sol, enzyme_sensitivity_backpass
 end
 
 const ENZYME_TRACKED_REAL_ERROR_MESSAGE = """
diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl
index f63b8f0ec..124378ea9 100644
--- a/src/sensitivity_algorithms.jl
+++ b/src/sensitivity_algorithms.jl
@@ -801,10 +801,21 @@ EnzymeAdjoint(mode = nothing)
 
 Currently fails on almost every solver.
 """
-struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
-       AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
-    mode::M
-    EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
+@static if ENZYME_ENABLED
+    struct EnzymeAdjoint{M <: Union{Nothing, Enzyme.EnzymeCore.Mode}} <:
+           AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
+        mode::M
+        EnzymeAdjoint(mode = nothing) = new{typeof(mode)}(mode)
+    end
+else
+    # Dummy type for Julia 1.12+ - Enzyme is not loaded on this version
+    struct EnzymeAdjoint{M <: Nothing} <:
+           AbstractAdjointSensitivityAlgorithm{nothing, true, nothing}
+        mode::M
+        function EnzymeAdjoint(mode = nothing)
+            error("EnzymeAdjoint is not supported on Julia 1.12+. Please use a different sensitivity algorithm.")
+        end
+    end
 end
 
 """
diff --git a/test/adjoint.jl b/test/adjoint.jl
index c6b4ded54..acc1269d4 100644
--- a/test/adjoint.jl
+++ b/test/adjoint.jl
@@ -139,16 +139,18 @@ easy_res11 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
     reltol = 1e-14,
     sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP(true)))
-_,
-easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
-    abstol = 1e-14,
-    reltol = 1e-14,
-    sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
-_,
-easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
-    abstol = 1e-14,
-    reltol = 1e-14,
-    sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+@static if SciMLSensitivity.ENZYME_ENABLED
+    _,
+    easy_res12 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
+        abstol = 1e-14,
+        reltol = 1e-14,
+        sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+    _,
+    easy_res13 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
+        abstol = 1e-14,
+        reltol = 1e-14,
+        sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+end
 _,
 easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
@@ -179,11 +181,13 @@ easy_res143 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
     reltol = 1e-14,
     sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true)))
-_,
-easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
-    abstol = 1e-14,
-    reltol = 1e-14,
-    sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+@static if SciMLSensitivity.ENZYME_ENABLED
+    _,
+    easy_res144 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
+        abstol = 1e-14,
+        reltol = 1e-14,
+        sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+end
 _,
 easy_res145 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
@@ -212,11 +216,13 @@ easy_res143k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
     reltol = 1e-14,
     sensealg = GaussKronrodAdjoint(autojacvec = ReverseDiffVJP(true)))
-_,
-easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
-    abstol = 1e-14,
-    reltol = 1e-14,
-    sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+@static if SciMLSensitivity.ENZYME_ENABLED
+    _,
+    easy_res144k = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
+        abstol = 1e-14,
+        reltol = 1e-14,
+        sensealg = GaussKronrodAdjoint(autojacvec = SciMLSensitivity.EnzymeVJP()))
+end
 _,
 easy_res145k = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discrete = dg,
     abstol = 1e-14,
@@ -1049,34 +1055,40 @@ function dynamics!(du, u, p, t)
     du[2] = -u[2] + tanh(p[3] * u[1] + p[4] * u[2])
 end
 
-function backsolve_grad(sol, lqr_params, checkpointing)
-    bwd_sol = solve(
-        ODEAdjointProblem(sol,
-            BacksolveAdjoint(autojacvec = EnzymeVJP(),
-                checkpointing = checkpointing),
-            Tsit5(),
-            nothing, nothing, nothing, nothing, nothing,
-            (x, lqr_params, t) -> cost(x, lqr_params)),
-        Tsit5(),
-        dense = false,
-        save_everystep = false)
-
-    bwd_sol.u[end][1:(end - x_dim)]
-    #fwd_sol, bwd_sol
+@static if SciMLSensitivity.ENZYME_ENABLED
+    function backsolve_grad(sol, lqr_params, checkpointing)
+        bwd_sol = solve(
+            ODEAdjointProblem(sol,
+                BacksolveAdjoint(autojacvec = EnzymeVJP(),
+                    checkpointing = checkpointing),
+                Tsit5(),
+                nothing, nothing, nothing, nothing, nothing,
+                (x, lqr_params, t) -> cost(x, lqr_params)),
+            Tsit5(),
+            dense = false,
+            save_everystep = false)
+
+        bwd_sol.u[end][1:(end - x_dim)]
+        #fwd_sol, bwd_sol
+    end
 end
 
+# The forward solve stays outside the Enzyme guard: the adjoint_sensitivities call
+# after the guarded test still uses fwd_sol when Enzyme is disabled.
 x0 = ones(x_dim)
 fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), params),
     Tsit5(), abstol = 1e-9, reltol = 1e-9,
     u0 = x0,
     p = params,
     dense = false,
     save_everystep = true)
 
-backsolve_results = backsolve_grad(fwd_sol, params, false)
-backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)
-
-@test backsolve_results != backsolve_checkpointing_results
+@static if SciMLSensitivity.ENZYME_ENABLED
+    backsolve_results = backsolve_grad(fwd_sol, params, false)
+    backsolve_checkpointing_results = backsolve_grad(fwd_sol, params, true)
+
+    @test backsolve_results != backsolve_checkpointing_results
+end
 
 int_u0,
 int_p = adjoint_sensitivities(fwd_sol, Tsit5(),
diff --git a/test/autodiff_events.jl b/test/autodiff_events.jl
index c799e0f2c..2b7438aa8 100644
--- a/test/autodiff_events.jl
+++ b/test/autodiff_events.jl
@@ -1,5 +1,5 @@
 using SciMLSensitivity
-using OrdinaryDiffEq, Calculus, Test
+using OrdinaryDiffEq, OrdinaryDiffEqCore, Calculus, Test
 using Zygote
 
 function f(du, u, p, t)
@@ -56,11 +56,11 @@ g4 = Zygote.gradient(θ -> test_f2(θ, ReverseDiffAdjoint(), PIController(7 // 5
     p)
 g6 = Zygote.gradient(
     θ -> test_f2(θ, ForwardDiffSensitivity(),
-        OrdinaryDiffEq.PredictiveController(), TRBDF2()),
+        OrdinaryDiffEqCore.PredictiveController(), TRBDF2()),
     p)
 @test_broken g7 = Zygote.gradient(
     θ -> test_f2(θ, ReverseDiffAdjoint(),
-        OrdinaryDiffEq.PredictiveController(),
+        OrdinaryDiffEqCore.PredictiveController(),
         TRBDF2()),
     p)
 
diff --git a/test/concrete_solve_derivatives.jl b/test/concrete_solve_derivatives.jl
index 0a7b14c4a..988939e55 100644
--- a/test/concrete_solve_derivatives.jl
+++ b/test/concrete_solve_derivatives.jl
@@ -93,15 +93,17 @@ dp7 = Zygote.gradient(
         sensealg = MooncakeAdjoint())),
     u0,
     p)
-du08,
-dp8 = Zygote.gradient(
-    (u0,
-        p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
-        abstol = 1e-14, reltol = 1e-14,
-        saveat = 0.1,
-        sensealg = EnzymeAdjoint())),
-    u0,
-    p)
+@static if SciMLSensitivity.ENZYME_ENABLED
+    du08,
+    dp8 = Zygote.gradient(
+        (u0,
+            p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
+            abstol = 1e-14, reltol = 1e-14,
+            saveat = 0.1,
+            sensealg = EnzymeAdjoint())),
+        u0,
+        p)
+end
 
 @test ū0≈du01 rtol=1e-12
 @test ū0 == du02
@@ -110,7 +112,9 @@ dp8 = Zygote.gradient(
 #@test ū0 ≈ du05 rtol=1e-12
 @test ū0≈du06 rtol=1e-12
 @test_broken ū0≈du07 rtol=1e-12
-@test ū0≈du08 rtol=1e-12
+@static if SciMLSensitivity.ENZYME_ENABLED
+    @test ū0≈du08 rtol=1e-12
+end
 @test adj≈dp1' rtol=1e-12
 @test adj == dp2'
 @test adj≈dp3' rtol=1e-12
@@ -118,7 +122,9 @@ dp8 = Zygote.gradient(
 #@test adj ≈ dp5' rtol=1e-12
 @test adj≈dp6' rtol=1e-12
 @test_broken adj≈dp7' rtol=1e-12
-@test adj≈dp8' rtol=1e-12
+@static if SciMLSensitivity.ENZYME_ENABLED
+    @test adj≈dp8' rtol=1e-12
+end
 
 ###
 ### Direct from prob
@@ -406,15 +412,17 @@ dp7 = Zygote.gradient(
         sensealg = MooncakeAdjoint())),
     u0,
     p)
-du08,
-dp8 = Zygote.gradient(
-    (u0,
-        p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
-        abstol = 1e-14, reltol = 1e-14,
-        saveat = 0.1,
-        sensealg = EnzymeAdjoint())),
-    u0,
-    p)
+@static if SciMLSensitivity.ENZYME_ENABLED
+    du08,
+    dp8 = Zygote.gradient(
+        (u0,
+            p) -> sum(solve(proboop, Tsit5(), u0 = u0, p = p,
+            abstol = 1e-14, reltol = 1e-14,
+            saveat = 0.1,
+            sensealg = EnzymeAdjoint())),
+        u0,
+        p)
+end
 du09,
 dp9 = Zygote.gradient(
     (u0,
@@ -441,7 +449,9 @@ dp10 = Zygote.gradient(
 #@test ū0 ≈ du05 rtol=1e-12
 @test ū0≈du06 rtol=1e-12
 @test_broken ū0≈du07 rtol=1e-12
-@test ū0≈du08 rtol=1e-12
+@static if SciMLSensitivity.ENZYME_ENABLED
+    @test ū0≈du08 rtol=1e-12
+end
 @test ū0≈du09 rtol=1e-12
 @test ū0≈du010 rtol=1e-12
 @test adj≈dp1' rtol=1e-12
@@ -451,7 +461,9 @@ dp10 = Zygote.gradient(
 #@test adj ≈ dp5' rtol=1e-12
 @test adj≈dp6' rtol=1e-12
 @test_broken adj≈dp7' rtol=1e-12
-@test adj≈dp8' rtol=1e-12
+@static if SciMLSensitivity.ENZYME_ENABLED
+    @test adj≈dp8' rtol=1e-12
+end
 @test adj≈dp9' rtol=1e-12
 @test adj≈dp10' rtol=1e-12
 
diff --git a/test/runtests.jl b/test/runtests.jl
index cfb90641e..13848db88 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -80,7 +80,9 @@ end
 
 if GROUP == "All" || GROUP == "Core6"
     @testset "Core 6" begin
-        @time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
+        if SciMLSensitivity.ENZYME_ENABLED
+            @time @safetestset "Enzyme Closures" include("enzyme_closure.jl")
+        end
         @time @safetestset "Complex Matrix FiniteDiff Adjoint" include("complex_matrix_finitediff.jl")
         @time @safetestset "Null Parameters" include("null_parameters.jl")
         @time @safetestset "Forward Mode Prob Kwargs" include("forward_prob_kwargs.jl")