-
-
Notifications
You must be signed in to change notification settings - Fork 204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Training Strategies to NNDAE #876
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, now that you have written the strategies for DAEs, next step is to refactor commonalities with this and NNODE such that we don't repeat code. Next is to actually try out the problem in #721 using what is implemented here.
src/dae_solve.jl
Outdated
@@ -47,6 +47,25 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = | |||
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs) | |||
end | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/dae_solve.jl
Outdated
end | ||
return loss | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/ode_solve.jl
Outdated
@@ -304,6 +304,7 @@ function generate_loss( | |||
return loss | |||
end | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/dae_solve.jl
Outdated
end | ||
|
||
function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ, | ||
autodiff::Bool,differential_vars::AbstractVector) where {C, T, U <: AbstractVector} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
autodiff::Bool,differential_vars::AbstractVector) where {C, T, U <: AbstractVector} | |
autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: AbstractVector} |
src/dae_solve.jl
Outdated
if autodiff | ||
ForwardDiff.jacobian(t -> phi(t, θ), t) | ||
else | ||
(phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this use only differential variables? See other methods of dfdx
Hi, I am trying to refactor the But for the following code:
I tried resolving this issue by addressing the prompt from the above error message:
But I still get the following error:
So, I am unsure of how to proceed. |
I tried a different strategy by directly changing the code in ode_solve.jl and dae_solve.jl instead of creating a new script with the refactored code to overcome the last issue. But I am getting the following error: |
b22a055
to
6f4580f
Compare
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Add any other context about the problem here.