We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 121a238 commit 4d7835eCopy full SHA for 4d7835e
Project.toml
@@ -1,6 +1,6 @@
1
name = "FiniteDifferences"
2
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3
-version = "0.12.21"
+version = "0.12.22"
4
5
[deps]
6
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
src/grad.jl
@@ -70,7 +70,7 @@ Compute an adjoint with any types of arguments `x` for which [`to_vec`](@ref) is
70
function j′vp(fdm, f, ȳ, x)
71
x_vec, vec_to_x = to_vec(x)
72
ȳ_vec, _ = to_vec(ȳ)
73
- return (vec_to_x(_j′vp(fdm, first ∘ to_vec ∘ f ∘ vec_to_x, ȳ_vec, x_vec)), )
+ return (vec_to_x(_j′vp(fdm, x -> first(to_vec(f(vec_to_x(x)))), ȳ_vec, x_vec)), )
74
end
75
76
j′vp(fdm, f, ȳ, xs...) = j′vp(fdm, xs->f(xs...), ȳ, xs)[1]
test/grad.jl
@@ -146,9 +146,10 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec
146
x, y = randn(rng, T, N), randn(rng, T, M)
147
z̄ = randn(rng, T, N + M)
148
xy = vcat(x, y)
149
- x̄ȳ_manual = j′vp(fdm, xy->sin.(xy), z̄, xy)[1]
150
- x̄ȳ_auto = j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y))[1]
151
- x̄ȳ_multi = j′vp(fdm, (x, y)->sin.(vcat(x, y)), z̄, x, y)
+ # Type inference: https://github.com/JuliaDiff/FiniteDifferences.jl/issues/199
+ x̄ȳ_manual = @inferred(j′vp(fdm, xy->sin.(xy), z̄, xy))[1]
+ x̄ȳ_auto = @inferred(j′vp(fdm, x->sin.(vcat(x[1], x[2])), z̄, (x, y)))[1]
152
+ x̄ȳ_multi = @inferred(j′vp(fdm, (x, y)->sin.(vcat(x, y)), z̄, x, y))
153
@test x̄ȳ_manual ≈ vcat(x̄ȳ_auto...)
154
@test x̄ȳ_manual ≈ vcat(x̄ȳ_multi...)
155
0 commit comments