Skip to content

Commit ed9b61d

Browse files
refactor: interpolation with higher dim arrays
1 parent caf0b76 commit ed9b61d

File tree

2 files changed

+47
-40
lines changed

2 files changed

+47
-40
lines changed

src/interpolation_caches.jl

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -173,29 +173,24 @@ Extrapolation extends the last cubic polynomial on each side.
173173
for a test based on the normalized standard deviation of the difference with respect
174174
to the straight line (see [`looks_linear`](@ref)). Defaults to 1e-2.
175175
"""
176-
struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T, N} <:
176+
struct AkimaInterpolation{uType, tType, IType, pType, T, N} <:
177177
AbstractInterpolation{T, N}
178178
u::uType
179179
t::tType
180180
I::IType
181-
b::bType
182-
c::cType
183-
d::dType
181+
p::pType
184182
extrapolate::Bool
185183
iguesser::Guesser{tType}
186184
cache_parameters::Bool
187185
linear_lookup::Bool
188186
function AkimaInterpolation(
189-
u, t, I, b, c, d, extrapolate, cache_parameters, assume_linear_t)
187+
u, t, I, p, extrapolate, cache_parameters, assume_linear_t)
190188
linear_lookup = seems_linear(assume_linear_t, t)
191189
N = get_output_dim(u)
192-
new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c),
193-
typeof(d), eltype(u), N}(u,
190+
new{typeof(u), typeof(t), typeof(I), typeof(p), eltype(u), N}(u,
194191
t,
195192
I,
196-
b,
197-
c,
198-
d,
193+
p,
199194
extrapolate,
200195
Guesser(t),
201196
cache_parameters,
@@ -208,30 +203,11 @@ function AkimaInterpolation(
208203
u, t; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2)
209204
u, t = munge_data(u, t)
210205
linear_lookup = seems_linear(assume_linear_t, t)
211-
n = length(t)
212-
dt = diff(t)
213-
m = Array{eltype(u)}(undef, n + 3)
214-
m[3:(end - 2)] = diff(u) ./ dt
215-
m[2] = 2m[3] - m[4]
216-
m[1] = 2m[2] - m[3]
217-
m[end - 1] = 2m[end - 2] - m[end - 3]
218-
m[end] = 2m[end - 1] - m[end - 2]
219-
220-
b = 0.5 .* (m[4:end] .+ m[1:(end - 3)])
221-
dm = abs.(diff(m))
222-
f1 = dm[3:(n + 2)]
223-
f2 = dm[1:n]
224-
f12 = f1 + f2
225-
ind = findall(f12 .> 1e-9 * maximum(f12))
226-
b[ind] = (f1[ind] .* m[ind .+ 1] .+
227-
f2[ind] .* m[ind .+ 2]) ./ f12[ind]
228-
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
229-
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2
230-
206+
p = AkimaParameterCache(u, t)
231207
A = AkimaInterpolation(
232-
u, t, nothing, b, c, d, extrapolate, cache_parameters, linear_lookup)
208+
u, t, nothing, p, extrapolate, cache_parameters, linear_lookup)
233209
I = cumulative_integral(A, cache_parameters)
234-
AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters, linear_lookup)
210+
AkimaInterpolation(u, t, I, p, extrapolate, cache_parameters, linear_lookup)
235211
end
236212

237213
"""

src/interpolation_methods.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,14 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu
8888
end
8989

9090
function _interpolate(
91-
A::LagrangeInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
91+
A::LagrangeInterpolation{<:AbstractArray}, t::Number, iguess)
9292
idx = get_idx(A, t, iguess)
9393
findRequiredIdxs!(A, t, idx)
9494
ax = axes(A.u)[1:(end - 1)]
9595
if A.t[A.idxs[1]] == t
9696
return A.u[ax..., A.idxs[1]]
9797
end
98-
N1 = zero(A.u[ax..., 1])
98+
N = zero(A.u[ax..., 1])
9999
D = zero(A.t[1])
100100
tmp = D
101101
for i in 1:length(A.idxs)
@@ -113,15 +113,22 @@ function _interpolate(
113113
end
114114
tmp = inv((t - A.t[A.idxs[i]]) * mult)
115115
D += tmp
116-
@. N1 += (tmp * A.u[ax..., A.idxs[i]])
116+
@. N += (tmp * A.u[ax..., A.idxs[i]])
117117
end
118-
N1 / D
118+
N / D
119119
end
120120

121121
function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
122122
idx = get_idx(A, t, iguess)
123123
wj = t - A.t[idx]
124-
@evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx]
124+
@evalpoly wj A.u[idx] A.p.b[idx] A.p.c[idx] A.p.d[idx]
125+
end
126+
127+
function _interpolate(A::AkimaInterpolation{<:AbstractArray}, t::Number, iguess)
128+
idx = get_idx(A, t, iguess)
129+
wj = t - A.t[idx]
130+
ax = axes(A.u)[1:(end - 1)]
131+
@. @evalpoly wj A.u[ax..., idx] A.p.b[ax..., idx] A.p.c[ax..., idx] A.p.d[ax..., idx]
125132
end
126133

127134
# ConstantInterpolation Interpolation
@@ -137,7 +144,7 @@ function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, igu
137144
end
138145

139146
function _interpolate(
140-
A::ConstantInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
147+
A::ConstantInterpolation{<:AbstractArray}, t::Number, iguess)
141148
if A.dir === :left
142149
# :left means that value to the left is used for interpolation
143150
idx = get_idx(A, t, iguess; lb = 1, ub_shift = 0)
@@ -158,7 +165,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
158165
end
159166

160167
function _interpolate(
161-
A::QuadraticSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
168+
A::QuadraticSpline{<:AbstractArray}, t::Number, iguess)
162169
idx = get_idx(A, t, iguess)
163170
ax = axes(A.u)[1:(end - 1)]
164171
Cᵢ = A.u[ax..., idx]
@@ -179,7 +186,7 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
179186
I + C + D
180187
end
181188

182-
function _interpolate(A::CubicSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
189+
function _interpolate(A::CubicSpline{<:AbstractArray}, t::Number, iguess)
183190
idx = get_idx(A, t, iguess)
184191
Δt₁ = t - A.t[idx]
185192
Δt₂ = A.t[idx + 1] - t
@@ -238,6 +245,18 @@ function _interpolate(
238245
out
239246
end
240247

248+
function _interpolate(
249+
A::CubicHermiteSpline{<:AbstractArray}, t::Number, iguess)
250+
idx = get_idx(A, t, iguess)
251+
Δt₀ = t - A.t[idx]
252+
Δt₁ = t - A.t[idx + 1]
253+
ax = axes(A.u)[1:(end - 1)]
254+
out = A.u[ax..., idx] .+ Δt₀ .* A.du[ax..., idx]
255+
c₁, c₂ = get_parameters(A, idx)
256+
out .+= Δt₀^2 .* (c₁ .+ Δt₁ .* c₂)
257+
out
258+
end
259+
241260
# Quintic Hermite Spline
242261
function _interpolate(
243262
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, t::Number, iguess)
@@ -249,3 +268,15 @@ function _interpolate(
249268
out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁))
250269
out
251270
end
271+
272+
function _interpolate(
273+
A::QuinticHermiteSpline{<:AbstractArray}, t::Number, iguess)
274+
idx = get_idx(A, t, iguess)
275+
Δt₀ = t - A.t[idx]
276+
Δt₁ = t - A.t[idx + 1]
277+
ax = axes(A.u)[1:(end - 1)]
278+
out = A.u[ax..., idx] + Δt₀ * (A.du[ax..., idx] + A.ddu[ax..., idx] * Δt₀ / 2)
279+
c₁, c₂, c₃ = get_parameters(A, idx)
280+
out .+= Δt₀^3 .* (c₁ .+ Δt₁ .* (c₂ .+ c₃ .* Δt₁))
281+
out
282+
end

0 commit comments

Comments
 (0)