Skip to content

Commit 8253801

Browse files
committed
Fix step size calculation if some results are non-deterministic (inf derivative) or nan
1 parent 87e0a26 commit 8253801

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

src/methods.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,20 @@ function estimate_step(
371371
return _limit_step(m, x, step, acc)
372372
end
373373

374+
function finite_or_zero(fs::AbstractArray{<:Number})
375+
ifelse.(isfinite.(fs), fs, zero(fs))
376+
end
377+
378+
function finite_or_zero(fs::AbstractArray{<:AbstractArray})
379+
finite_or_zero.(fs)
380+
end
381+
374382
function _estimate_magnitudes(
375383
m::FiniteDifferenceMethod{P,Q}, f::TF, x::T,
376384
) where {P,Q,TF,T<:AbstractFloat}
377385
step = first(estimate_step(m, f, x))
378386
fs = _eval_function(m, f, x, step)
387+
fs = finite_or_zero(fs)
379388
# Estimate magnitude of `∇f` in a neighbourhood of `x`.
380389
∇fs = SVector{3}(
381390
_compute_estimate(m, fs, x, step, m.coefs_neighbourhood[1]),

test/grad.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,35 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec
217217
@test [real(ȳ), imag(ȳ)] Jy'z̄_vec
218218
end
219219
end
220+
221+
using LinearAlgebra
222+
223+
function partial_nan_returning(x)
224+
y = Matrix{Float64}(undef, 5, 5)
225+
y .= NaN
226+
y = Hermitian(y)
227+
y .= x
228+
return parent(y)
229+
end
230+
231+
randvar = 1
232+
function partial_nondet_returning(x)
233+
global randvar
234+
y = Matrix{Float64}(undef, 5, 5)
235+
y .= randvar
236+
randvar += 1
237+
y = Hermitian(y)
238+
y .= x
239+
return parent(y)
240+
end
241+
242+
@testset "jvp: Estimate step correctly for when some terms are nan/infinite" begin
243+
fdm = FiniteDifferences.central_fdm(5, 1)
244+
res = jvp(fdm, partial_nan_returning, 3.1, 2.7)
245+
@show res
246+
@test Hermitian(res) .≈ 2.7
247+
248+
res = jvp(fdm, partial_nondet_returning, 3.1, 2.7)
249+
@show res
250+
@test Hermitian(res) .≈ 2.7
251+
end

0 commit comments

Comments
 (0)