diff --git a/Project.toml b/Project.toml index d6c26eb..a0a1ebc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.32" +version = "0.12.33" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/methods.jl b/src/methods.jl index f3d8ea7..bcc993f 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -371,17 +371,27 @@ function estimate_step( return _limit_step(m, x, step, acc) end +function finite_or_zero(fs::AbstractArray{<:Number}) + ifelse.(isfinite.(fs), fs, zero(fs)) +end + +function finite_or_zero(fs::AbstractArray{<:AbstractArray}) + finite_or_zero.(fs) +end + function _estimate_magnitudes( m::FiniteDifferenceMethod{P,Q}, f::TF, x::T, ) where {P,Q,TF,T<:AbstractFloat} step = first(estimate_step(m, f, x)) fs = _eval_function(m, f, x, step) + fs = finite_or_zero(fs) # Estimate magnitude of `∇f` in a neighbourhood of `x`. ∇fs = SVector{3}( _compute_estimate(m, fs, x, step, m.coefs_neighbourhood[1]), _compute_estimate(m, fs, x, step, m.coefs_neighbourhood[2]), _compute_estimate(m, fs, x, step, m.coefs_neighbourhood[3]) ) + ∇fs = finite_or_zero(∇fs) ∇f_magnitude = maximum(maximum.(abs, ∇fs)) # Estimate magnitude of `f` in a neighbourhood of `x`. f_magnitude = maximum(maximum.(abs, fs)) diff --git a/test/grad.jl b/test/grad.jl index 095777e..ac9673c 100644 --- a/test/grad.jl +++ b/test/grad.jl @@ -217,3 +217,26 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec @test [real(ȳ), imag(ȳ)] ≈ Jy'z̄_vec end end + +using LinearAlgebra + +function partial_nan_returning(x) + return Float64[NaN, x] +end + +randvar = 1 +function partial_nondet_returning(x) + global randvar + y = Float64[randvar, x] + randvar += 1 + return y +end + +@testset "jvp: Estimate step correctly for when some terms are nan/infinite" begin + fdm = FiniteDifferences.central_fdm(5, 1) + res = jvp(fdm, partial_nan_returning, (3.1, 2.7)) + @test res[2] ≈ 2.7 + + res = jvp(fdm, partial_nondet_returning, (3.1, 2.7)) + @test res[2] ≈ 2.7 +end