NNODE training fails with autodiff=true #725

Open · sathvikbhagavan opened this issue Aug 20, 2023 · 7 comments

@sathvikbhagavan (Member)

MWE:

Running one of the tests in NNODE_tests.jl:

using Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import OptimizationOptimisers

Random.seed!(100)

# Run a solve on scalars
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
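
For context on the two calls below: the autodiff flag controls how NNODE computes the time derivative of the trial solution inside the physics loss. A rough sketch of the two paths, using phi, θ, and ϵ as illustrative stand-ins rather than NeuralPDE's actual internals:

using ForwardDiff

# Hypothetical trial solution; in NNODE this wraps the neural network.
phi(t, θ) = θ[1] * sin(t)

# autodiff = false: numerical derivative via a small finite-difference step.
dudt_fd(phi, t, θ; ϵ = sqrt(eps(typeof(t)))) = (phi(t + ϵ, θ) - phi(t, θ)) / ϵ

# autodiff = true: forward-mode AD through the trial solution in t.
dudt_ad(phi, t, θ) = ForwardDiff.derivative(t -> phi(t, θ), t)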

This works:

sol = solve(prob, NeuralPDE.NNODE(chain, opt), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)

This errors out:

sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)

Stacktrace:

julia> sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
WARNING: both DomainSets and SciMLBase export "islinear"; uses of it in module NeuralPDE must be qualified
WARNING: both DomainSets and SciMLBase export "isconstant"; uses of it in module NeuralPDE must be qualified
WARNING: both DomainSets and SciMLBase export "issquare"; uses of it in module NeuralPDE must be qualified
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).typeof(f) = NeuralPDE.var"#150#151"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/4rucm/src/lib/forward.jl:150
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32

Closest candidates are:
  convert(::Type{T}, ::Unitful.Gain) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/logarithm.jl:62
  convert(::Type{T}, ::Unitful.Level) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/logarithm.jl:22
  convert(::Type{T}, ::Unitful.Quantity) where T<:Real
   @ Unitful ~/.julia/packages/Unitful/PMWWU/src/conversion.jl:139
  ...

Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:347
  [2] copyto!
    @ ./broadcast.jl:934 [inlined]
  [3] materialize!
    @ ./broadcast.jl:884 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(identity), Tuple{Base.RefValue{Nothing}}})
    @ Base.Broadcast ./broadcast.jl:881
  [5] (::OptimizationZygoteExt.var"#20#29"{OptimizationZygoteExt.var"#19#28"{OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}}})(::Vector{Float32}, ::Vector{Float32})
    @ OptimizationZygoteExt ~/.julia/packages/Optimization/72eCu/ext/OptimizationZygoteExt.jl:56
  [6] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/wD0eI/src/OptimizationOptimisers.jl:65 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Optimization/72eCu/src/utils.jl:37 [inlined]
  [8] __solve(cache::Optimization.OptimizationCache{OptimizationFunction{false, ADTypes.AutoZygote, OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, OptimizationZygoteExt.var"#20#29"{OptimizationZygoteExt.var"#19#28"{OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}}}, OptimizationZygoteExt.var"#23#32"{OptimizationZygoteExt.var"#19#28"{OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}}}, 
OptimizationZygoteExt.var"#27#36", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam{Float64}, Base.Iterators.Cycle{Tuple{Optimization.NullData}}, Bool, NeuralPDE.var"#176#180"{Float32, Bool}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/wD0eI/src/OptimizationOptimisers.jl:63
  [9] solve!(cache::Optimization.OptimizationCache{OptimizationFunction{false, ADTypes.AutoZygote, OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, OptimizationZygoteExt.var"#20#29"{OptimizationZygoteExt.var"#19#28"{OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}}}, OptimizationZygoteExt.var"#23#32"{OptimizationZygoteExt.var"#19#28"{OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}}}, 
OptimizationZygoteExt.var"#27#36", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Optimization.ReInitCache{Vector{Float32}, SciMLBase.NullParameters}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam{Float64}, Base.Iterators.Cycle{Tuple{Optimization.NullData}}, Bool, NeuralPDE.var"#176#180"{Float32, Bool}})
    @ SciMLBase ~/.julia/packages/SciMLBase/kTUaf/src/solve.jl:162
 [10] solve(::OptimizationProblem{true, OptimizationFunction{true, ADTypes.AutoZygote, NeuralPDE.var"#total_loss#179"{Nothing, NeuralPDE.var"#loss#161"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}}}, Float32, Float32, Nothing}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Bool, SciMLBase.NullParameters, Bool, StepRangeLen{Float32, Float64, Float64, Int64}}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::Optimisers.Adam{Float64}; kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, NamedTuple{(:callback, :maxiters), Tuple{NeuralPDE.var"#176#180"{Float32, Bool}, Int64}}})
    @ SciMLBase ~/.julia/packages/SciMLBase/kTUaf/src/solve.jl:83
 [11] __solve(::ODEProblem{Float32, Tuple{Float32, Float32}, false, SciMLBase.NullParameters, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::NNODE{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Optimisers.Adam{Float64}, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing, Nothing}; dt::Float32, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:455
 [12] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:356 [inlined]
 [13] #solve_call#33
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:511 [inlined]
 [14] solve_call
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:481 [inlined]
 [15] solve_up(prob::ODEProblem{Float32, Tuple{Float32, Float32}, false, SciMLBase.NullParameters, ODEFunction{false, SciMLBase.AutoSpecialize, var"#3#4", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Float32, p::SciMLBase.NullParameters, args::NNODE{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Optimisers.Adam{Float64}, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing, Nothing}; kwargs::Base.Pairs{Symbol, Real, NTuple{4, Symbol}, NamedTuple{(:dt, :verbose, :abstol, :maxiters), Tuple{Float32, Bool, Float32, Int64}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:972
 [16] solve_up
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:945 [inlined]
 [17] #solve#39
    @ ~/.julia/packages/DiffEqBase/DEv7n/src/solve.jl:882 [inlined]
 [18] top-level scope
    @ REPL[13]:1

@sathvikbhagavan (Member Author)

@ChrisRackauckas is this a known issue?

@ChrisRackauckas (Member)

It wasn't but now it is.

@sathvikbhagavan (Member Author)

Updated MWE:

using Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import OptimizationOptimisers

Random.seed!(100)

# Run a solve on scalars
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(linear, u0, tspan)
chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
            abstol = 1.0f-10, maxiters = 200)

no longer fails with the original error; instead it hits the check added in #783, which errors out whenever autodiff is true.

Removing that check gives me this:

julia> sol = solve(prob, NeuralPDE.NNODE(chain, opt; autodiff=true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Optimisers.Restructure{Chain{Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}, @NamedTuple{weight::Int64, bias::Int64, σ::Tuple{}}}}}, Float32, Float32, Nothing}, Vector{Float32}}
└ @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/forward.jl:150
Current loss is: 121.3777458193083, Iteration: 1
Current loss is: 121.3777458193083, Iteration: 2
⋮
Current loss is: 121.3777458193083, Iteration: 200
Current loss is: 121.3777458193083, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
  0.0
  0.006315714
  0.011539376
  0.015674114
  0.018725103
  0.020699587
  0.021606745
  0.021457678
  ⋮
 -0.007863174
 -0.015861165
 -0.024744594
 -0.034488622
 -0.04506795
 -0.05645652
 -0.06862798

The loss is constant and the NNODE is not being trained.
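
The constant loss is consistent with the Zygote warning above: Zygote's rule for ForwardDiff.jacobian cannot track gradients through a closure's captured fields, so the loss gradient with respect to the network parameters comes back empty. A minimal illustration outside NeuralPDE (loss and θ here are made up for the example; the exact return value may differ across Zygote versions):

using Zygote, ForwardDiff

θ = [2.0]
# θ only enters the loss through the closure passed to ForwardDiff.jacobian.
loss(θ) = sum(ForwardDiff.jacobian(t -> θ .* sin.(t), [0.3]))

Zygote.gradient(loss, θ)  # warns and returns (nothing,): no gradient w.r.t. θ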

sathvikbhagavan changed the title from "NNODE training errors when dt is given with autodiff=true" to "NNODE training with autodiff=true" on Jan 22, 2024
sathvikbhagavan changed the title from "NNODE training with autodiff=true" to "NNODE training fails with autodiff=true" on Jan 22, 2024

@ChrisRackauckas (Member)

The Flux type conversion drops duals, so that's something to start with removing. Let's start by transforming everything to Lux first, clean up and delete code, then isolate.

@sathvikbhagavan (Member Author)

Now that the Flux removal is done in #789, I revisited this to see what was happening.

With [email protected],

julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
Current loss is: 133.45653590504156, Iteration: 1
Current loss is: 133.45653590504156, Iteration: 2
⋮
Current loss is: 133.45653590504156, Iteration: 200
Current loss is: 133.45653590504156, Iteration: 201
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.05f0:1.0f0
u: 21-element Vector{Float32}:
  0.0
 -0.026961738
 -0.054784633
 -0.08346676
 -0.1130049
 -0.14339462
 -0.17463014
 -0.20670454
 -0.23960975
 -0.27333638
 -0.30787417
 -0.3432115
 -0.37933594
 -0.41623402
 -0.4538912
 -0.49229237
 -0.53142136
 -0.5712613
 -0.61179453
 -0.6530033
 -0.6948684

where the loss remains constant.

But with [email protected], I get an error:

julia> sol = solve(prob, NeuralPDE.NNODE(luxchain, opt, autodiff = true), dt = 1 / 20.0f0, verbose = true,
                   abstol = 1.0f-10, maxiters = 200)
┌ Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).typeof(f) = NeuralPDE.var"#163#164"{NeuralPDE.ODEPhi{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{true, typeof(sigmoid_fast), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}, layer_2::Lux.Dense{true, typeof(identity), typeof(WeightInitializers.glorot_uniform), typeof(WeightInitializers.zeros32)}}, Nothing}, Float32, Float32, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:10, Axis(weight = ViewAxis(1:5, ShapedAxis((5, 1), NamedTuple())), bias = ViewAxis(6:10, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(11:16, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5), NamedTuple())), bias = ViewAxis(6:6, ShapedAxis((1, 1), NamedTuple())))))}}}}
└ @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/forward.jl:150
ERROR: MethodError: no method matching zero(::Type{ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{}}}})

Closest candidates are:
  zero(::Type{Union{}}, Any...)
   @ Base number.jl:310
  zero(::Type{Dates.Time})
   @ Dates ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Dates/src/types.jl:440
  zero(::Type{Pkg.Resolve.FieldValue})
   @ Pkg ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Pkg/src/Resolve/fieldvalues.jl:38
  ...

Stacktrace:
  [1] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{}})(::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}}, ::ComponentArrays.ComponentVector{Float32, Vector{…}, Tuple{…}})
    @ OptimizationZygoteExt ~/.julia/packages/Optimization/79XSq/ext/OptimizationZygoteExt.jl:93
  [2] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/Optimization/79XSq/src/utils.jl:41 [inlined]
  [4] __solve(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
  [5] solve!(cache::Optimization.OptimizationCache{OptimizationFunction{…}, Optimization.ReInitCache{…}, Nothing, Nothing, Nothing, Nothing, Nothing, Optimisers.Adam, Base.Iterators.Cycle{…}, Bool, NeuralPDE.var"#192#196"{…}})
    @ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:179
  [6] solve(::OptimizationProblem{true, OptimizationFunction{…}, ComponentArrays.ComponentVector{…}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, @Kwargs{}}, ::Optimisers.Adam; kwargs::@Kwargs{callback::NeuralPDE.var"#192#196"{}, maxiters::Int64})
    @ SciMLBase ~/.julia/packages/SciMLBase/slQep/src/solve.jl:96
  [7] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Float32, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:489
  [8] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:373 [inlined]
  [9] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:609
 [10] solve_call
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:567 [inlined]
 [11] #solve_up#42
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1058 [inlined]
 [12] solve_up
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:1044 [inlined]
 [13] #solve#40
    @ DiffEqBase ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
 [14] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.

which is because of SciML/Optimization.jl#679. This traces back to the ForwardDiff.jacobian call used to compute the derivatives in the equation.
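
The MethodError itself can be seen without NeuralPDE: the gradient path ends up calling zero on the ComponentVector type rather than on an instance. A small illustration, assuming only that ComponentArrays is available:

using ComponentArrays

x = ComponentVector(a = 1.0f0, b = [2.0f0, 3.0f0])
zero(x)          # fine: elementwise zero of an existing ComponentVector
zero(typeof(x))  # MethodError: no method matching zero(::Type{ComponentVector{...}})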

@ChrisRackauckas (Member)

Instead of using ForwardDiff.jacobian, we could do the dual evaluation directly. It's the same as this:

https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/src/derivative_wrappers.jl#L84-L103

        T = typeof(ForwardDiff.Tag(NeuralPDETag(), eltype(t)))
        tdual = Dual{T, eltype(df), 1}(t, ForwardDiff.Partials((one(typeof(t)),)))
        first.(ForwardDiff.partials.(phi(tdual, θ)))

and add a struct NeuralPDETag end. Doing it this way keeps the math intact and removes the higher-level interface, so we differentiate directly. Since this definition is completely non-mutating, it should just work.
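
A self-contained sketch of that idea follows. It is only a sketch: dphi_dt is a hypothetical helper name, and phi and θ stand in for the ODEPhi trial solution and its parameters.

using ForwardDiff

struct NeuralPDETag end

# d/dt of phi(t, θ) at a scalar time t, computed by seeding one dual number
# instead of going through the ForwardDiff.jacobian interface.
function dphi_dt(phi, t, θ)
    T = typeof(ForwardDiff.Tag(NeuralPDETag(), typeof(t)))
    tdual = ForwardDiff.Dual{T, typeof(t), 1}(t, ForwardDiff.Partials((one(typeof(t)),)))
    u = phi(tdual, θ)                  # non-mutating forward pass at the dual time point
    first.(ForwardDiff.partials.(u))   # extract the single partial: du/dt
end

# Toy usage with a stand-in trial function:
# phi(t, θ) = θ[1] .* sin.(t)
# dphi_dt(phi, 0.3f0, [2.0f0])  # ≈ 2cos(0.3)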

@sathvikbhagavan (Member Author)

Ok, will try this.
