Gradient returns error when using ComplexF64 ODE and a custom struct for parameters #1146
Comments
There are a couple of compounding factors at play here. First, it's inconsistently defined. Second, when we were working on #1135 we were missing JuliaArrays/ArrayInterface.jl#456. Third, the way … should also be solved by #1147.
Hi @DhairyaLGandhi, I tried the branch from #1147, and I get a different error instead:

```julia
p = rand(T, 4)
Zygote.gradient(my_f, p)
```

ERROR: MethodError: no method matching recursive_copyto!(::Vector{ComplexF64}, ::NTuple{4, ComplexF64})
The function `recursive_copyto!` exists, but no method is defined for this combination of argument types.
Closest candidates are:
recursive_copyto!(::Tuple, ::Tuple)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:11
recursive_copyto!(::AbstractArray, ::AbstractArray)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:9
recursive_copyto!(::Any, ::Nothing)
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/parameters_handling.jl:16
...
Stacktrace:
[1] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:492
[2] GaussIntegrand
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:517 [inlined]
[3] (::SciMLSensitivity.var"#265#266"{…})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:558
[4] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/00gNi/src/integrating_sum.jl:50
[5] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/frOsk/src/callbacks.jl:615 [inlined]
[6] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/frOsk/src/callbacks.jl:631 [inlined]
[7] handle_callbacks!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:355
[8] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:243
[9] loopfooter!
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/integrators/integrator_utils.jl:207 [inlined]
[10] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:579
[11] #__solve#75
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:7 [inlined]
[12] __solve
@ ~/.julia/packages/OrdinaryDiffEqCore/2K6jv/src/solve.jl:1 [inlined]
[13] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:612
[14] solve_call
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:569 [inlined]
[15] #solve_up#53
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1092 [inlined]
[16] solve_up
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1078 [inlined]
[17] #solve#51
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1015 [inlined]
[18] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:578
[19] _adjoint_sensitivities
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/gauss_adjoint.jl:531 [inlined]
[20] #adjoint_sensitivities#63
@ ~/.julia/packages/SciMLSensitivity/qX1o7/src/sensitivity_interface.jl:401 [inlined]
[21] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#315"{…})(Δ::ODESolution{…})
@ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/qX1o7/src/concrete_solve.jl:627
[22] ZBack
@ ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:212 [inlined]
[23] (::Zygote.var"#kw_zpullback#56"{…})(dy::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/chainrules.jl:238
[24] #294
@ ~/.julia/packages/Zygote/nyzjS/src/lib/lib.jl:206 [inlined]
[25] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[26] #solve#51
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1015 [inlined]
[27] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[28] #294
@ ~/.julia/packages/Zygote/nyzjS/src/lib/lib.jl:206 [inlined]
[29] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[30] solve
@ ~/.julia/packages/DiffEqBase/frOsk/src/solve.jl:1005 [inlined]
[31] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[32] my_f
@ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/autodiff.jl:158 [inlined]
[33] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface2.jl:0
[34] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:91
[35] gradient(f::Function, args::Vector{ComplexF64})
@ Zygote ~/.julia/packages/Zygote/nyzjS/src/compiler/interface.jl:148
[36] top-level scope
@ ~/GitHub/Research/Undef/Autodiff QuantumToolbox/autodiff.jl:167
Some type information was truncated. Use `show(err)` to see complete types.

It's very strange because …
Yes, that's the third point from my comment. The adjoint is a `Tuple` since that's how the struct is stored in memory. We can add a dispatch to …
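The missing method in the `MethodError` above is `recursive_copyto!(::Vector{ComplexF64}, ::NTuple{4, ComplexF64})`. As a rough, standalone sketch of what such a dispatch would have to do, the function below copies the elements of a flat `Tuple` (such as a struct's adjoint) into a pre-allocated vector. It is not SciMLSensitivity's code: the name `recursive_copyto_sketch!` is hypothetical, and the sketch assumes the tuple is flat (no nested structs) and has the same length as the destination.

```julia
# Hypothetical stand-in for the missing dispatch; NOT SciMLSensitivity's implementation.
# Copies a flat Tuple (e.g. an NTuple{4, ComplexF64} adjoint) into an existing vector.
function recursive_copyto_sketch!(dest::AbstractVector, src::Tuple)
    length(dest) == length(src) || throw(DimensionMismatch(
        "destination has length $(length(dest)) but source has length $(length(src))"))
    for (i, x) in zip(eachindex(dest), src)
        dest[i] = x  # converted to eltype(dest) on assignment
    end
    return dest
end

out = zeros(ComplexF64, 4)
λ = (1.0 + 2.0im, 3.0 - 1.0im, 0.5im, 2.0 + 0.0im)  # an NTuple{4, ComplexF64}
recursive_copyto_sketch!(out, λ)
```

A real fix inside the package would presumably also need to recurse into nested tuples, as the name `recursive_copyto!` suggests; this sketch deliberately handles only the flat case from the stack trace.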
OK. But I still don't understand why the …
I don't know whether #1149 is also related to this; there I get a null gradient when using complex …
Describe the bug 🐞

Computing the gradient of an ODE of `Float64` type works when the parameters are either a `Vector` or a custom `struct` (using SciMLStructures.jl). However, it fails when I simply change the type of the ODE to `ComplexF64`.

It seems that, in the `ComplexF64` case, the parameters are converted to a `Vector`. But they are a custom `struct`, so `p.p1` doesn't work.

It works when using a `Vector` instead of a custom `struct`.

Expected behavior

The correct gradient is returned, as in the `Float64` case or in the `ComplexF64` case with `Vector` parameters.

Minimal Reproducible Example 👇
Definition of the custom struct
ODE Problem
Gradient Calculation (fails)
Error & Stacktrace ⚠️
Environment (please complete the following information):
`using Pkg; Pkg.status()`
`using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)`
`versioninfo()`
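The bug report's diagnosis is that, in the `ComplexF64` case, the parameters reach the user's code as a flat `Vector` rather than the custom `struct`, so field access such as `p.p1` breaks. The minimal sketch below reproduces that mismatch with Base Julia only. The struct `MyParams`, the right-hand side `rhs`, and the helpers `flatten`/`rebuild` are hypothetical stand-ins, because the issue's actual MRE code is not shown. The issue itself relies on SciMLStructures.jl to describe how its struct maps to and from a flat buffer, and those interface calls are not reproduced here.

```julia
# Hypothetical two-field parameter struct standing in for the issue's struct.
struct MyParams{T}
    p1::T
    p2::T
end

# A right-hand side written against the struct: it reads parameters by field name.
rhs(u, p::MyParams, t) = -p.p1 .* u .+ p.p2

# Hypothetical flatten/rebuild pair between the struct and a flat Vector.
flatten(p::MyParams) = [p.p1, p.p2]
rebuild(::Type{<:MyParams}, v::AbstractVector) = MyParams(v[1], v[2])

p = MyParams(1.0 + 0.5im, 2.0 - 1.0im)
v = flatten(p)              # Vector{ComplexF64} holding the same values as p
u = ComplexF64[1.0, 2.0im]

# The struct exposes `p1`; the flat vector does not, so `v.p1` would throw.
@show hasproperty(p, :p1) hasproperty(v, :p1)

# Rebuilding the struct before calling `rhs` restores field access.
du = rhs(u, rebuild(MyParams, v), 0.0)
```

In other words, code that receives parameters as a plain vector must rebuild the struct before any `p.p1`-style access; the issue suggests that this step is skipped somewhere on the `ComplexF64` path.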