Skip to content

Commit

Permalink
Merge pull request #176 from devmotion/fix_forwarddiff_interpolant
Browse files Browse the repository at this point in the history
Enable autodifferentiation of in-place interpolants
  • Loading branch information
ChrisRackauckas authored Aug 27, 2017
2 parents cc6306c + f919b6b commit 22b5257
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,17 @@ function ode_addsteps!{calcVal,calcVal2,calcVal3}(k,t,uprev,u,dt,f,cache,always_
end

@inline function ode_interpolant{TI}(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idxs,T::Type{Val{TI}})
# determine output type
# required for calculation of time derivatives with autodifferentiation
oneunit_Θ = oneunit(Θ)
S = promote_type(typeof(oneunit_Θ * oneunit(eltype(y₀))), # Θ*y₀
typeof(oneunit_Θ * oneunit(eltype(y₁))), # Θ*y₁
typeof(oneunit_Θ * oneunit(dt) * oneunit(eltype(k[1])))) # Θ*dt*k

if typeof(idxs) <: Void
out = similar(y₀)
out = similar(y₀, S)
else
!(typeof(idxs) <: Number) && (out = similar(y₀,indices(idxs)))
!(typeof(idxs) <: Number) && (out = similar(y₀, S, indices(idxs)))
end
if typeof(idxs) <: Number
return ode_interpolant!(nothing,Θ,dt,y₀,y₁,k,cache,idxs,T)
Expand Down

0 comments on commit 22b5257

Please sign in to comment.