diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 7ec4e8b07..acb719368 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -578,7 +578,7 @@ function DiffEqBase._concrete_solve_adjoint( (Δu[i] isa NoTangent || eltype(Δu) <: NoTangent) && return if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa Tangent - x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δu[i] : Δ[i] + x = Δ isa AbstractVectorOfArray ? Δu.u[i] : (Δ isa Tangent ? Δu[i] : Δ[i]) if _save_idxs isa Number _out[_save_idxs] = x[_save_idxs] elseif _save_idxs isa Colon @@ -1681,8 +1681,8 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem, function adjoint_sensitivity_backpass(Δ) function df(_out, u, p, t, i) - if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray - x = Δ isa AbstractVectorOfArray ? Δ.u[i] : Δ[i] + if Δ isa AbstractArray{<:AbstractArray} Δ isa AbstractVectorOfArray || Δ isa Tangent + x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? unthunk(Δ.u[i]) : Δ[i] if _save_idxs isa Number _out[_save_idxs] = x[_save_idxs] elseif _save_idxs isa Colon diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index a15fe6cb2..2a978269c 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -689,7 +689,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, ytmp = _tmp5 end - tmp1 .= 0 # should be removed for dλ + Enzyme.make_zero!(tmp1) # should be removed for dλ vec(ytmp) .= vec(y) #if dgrad !== nothing @@ -707,13 +707,18 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, #if dy !== nothing # tmp3 = dy #else - tmp3 .= 0 + Enzyme.make_zero!(tmp3) #end vec(tmp4) .= vec(λ) isautojacvec = get_jacvec(sensealg) + # Correctness over speed + # TODO: Get a fix for `make_zero!` to allow reusing zero'd memory + # https://github.com/EnzymeAD/Enzyme.jl/issues/2400 + _tmp6 = Enzyme.make_zero(_tmp6) + if inplace_sensitivity(S) if W === nothing Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6), diff --git a/src/gauss_adjoint.jl b/src/gauss_adjoint.jl index 11379bd7a..53d351e6f 100644 --- a/src/gauss_adjoint.jl +++ b/src/gauss_adjoint.jl @@ -500,7 +500,13 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand) tmp3, tmp4, tmp6 = paramjac_config vtmp4 = vec(tmp4) vtmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) + + # Correctness over speed + # TODO: Get a fix for `make_zero!` to allow reusing zero'd memory + # https://github.com/EnzymeAD/Enzyme.jl/issues/2400 + tmp6 = Enzyme.make_zero(tmp6) + Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 692bd3cb8..3ee07ae7e 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -300,7 +300,12 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand) tmp3, tmp4, tmp6 = paramjac_config vtmp4 = vec(tmp4) vtmp4 .= λ - out .= 0 + Enzyme.make_zero!(out) + + # Correctness over speed + # TODO: Get a fix for `make_zero!` to allow reusing zero'd memory + # https://github.com/EnzymeAD/Enzyme.jl/issues/2400 + tmp6 = Enzyme.make_zero(tmp6) Enzyme.autodiff( Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), diff --git a/test/hybrid_de.jl b/test/hybrid_de.jl index 763fce3a7..f31360b25 100644 --- a/test/hybrid_de.jl +++ b/test/hybrid_de.jl @@ -51,4 +51,4 @@ end res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()), ps), Adam(0.05); callback = cba, maxiters = 200) -@test loss_n_ode(res.u, nothing) < 0.4 +@test loss_n_ode(res.u, nothing) < 0.5