Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDifferences"
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
version = "0.12.32"
version = "0.12.33"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
9 changes: 9 additions & 0 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,20 @@ function estimate_step(
return _limit_step(m, x, step, acc)
end

function finite_or_zero(fs::AbstractArray{<:Number})
ifelse.(isfinite.(fs), fs, zero(fs))
end

function finite_or_zero(fs::AbstractArray{<:AbstractArray})
finite_or_zero.(fs)
end

function _estimate_magnitudes(
m::FiniteDifferenceMethod{P,Q}, f::TF, x::T,
) where {P,Q,TF,T<:AbstractFloat}
step = first(estimate_step(m, f, x))
fs = _eval_function(m, f, x, step)
fs = finite_or_zero(fs)
# Estimate magnitude of `∇f` in a neighbourhood of `x`.
∇fs = SVector{3}(
_compute_estimate(m, fs, x, step, m.coefs_neighbourhood[1]),
Expand Down
30 changes: 30 additions & 0 deletions test/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,33 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec
@test [real(ȳ), imag(ȳ)] ≈ Jy'z̄_vec
end
end

using LinearAlgebra

function partial_nan_returning(x)
y = Matrix{Float64}(undef, 5, 5)
y .= NaN
y = Hermitian(y)
y .= x
return parent(y)
end

randvar = 1
function partial_nondet_returning(x)
global randvar
y = Matrix{Float64}(undef, 5, 5)
y .= randvar
randvar += 1
y = Hermitian(y)
y .= x
return parent(y)
end

@testset "jvp: Estimate step correctly for when some terms are nan/infinite" begin
fdm = FiniteDifferences.central_fdm(5, 1)
res = jvp(fdm, partial_nan_returning, (3.1, 2.7))
@test all(Hermitian(res) .≈ 2.7)

res = jvp(fdm, partial_nondet_returning, (3.1, 2.7))
@test all(Hermitian(res) .≈ 2.7)
end
Loading