Skip to content

Commit 2dad6a7

Browse files
mobledevmotion
andauthored
Handle Base.hypot with more than 2 arguments (#824)
* Handle Base.hypot with more than 2 arguments * Bump version * Apply suggestions from code review Co-authored-by: David Müller-Widmann <[email protected]> * Test mixed Real/Complex inputs to hypot --------- Co-authored-by: David Müller-Widmann <[email protected]>
1 parent 0923b1e commit 2dad6a7

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.72.5"
3+
version = "1.72.6"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/base.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,38 @@ function rrule(::typeof(hypot), z::Complex)
119119
return (Ω, hypot_pullback)
120120
end
121121

122+
# Note that `hypot` with two arguments has rules in `fastmath_able.jl`
123+
124+
function frule(
125+
(_, Δx, Δy, Δz, Δxs...),
126+
::typeof(hypot),
127+
x::Union{Real,Complex},
128+
y::Union{Real,Complex},
129+
z::Union{Real,Complex},
130+
xs::Union{Real,Complex}...,
131+
)
132+
Ω = hypot(x, y, z, xs...)
133+
n = ifelse(iszero(Ω), oneunit(Ω), Ω)
134+
∂Ω = sum(map(realdot, (x, y, z, xs...), (Δx, Δy, Δz, Δxs...))) / n
135+
return Ω, ∂Ω
136+
end
137+
138+
function rrule(
139+
::typeof(hypot),
140+
x::Union{Real,Complex},
141+
y::Union{Real,Complex},
142+
z::Union{Real,Complex},
143+
xs::Union{Real,Complex}...,
144+
)
145+
Ω = hypot(x, y, z, xs...)
146+
function hypot_pullback(ΔΩ)
147+
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
148+
return (NoTangent(), c * x, c * y, c * z, map(xi -> c * xi, xs)...)
149+
end
150+
return (Ω, hypot_pullback)
151+
end
152+
153+
122154
@scalar_rule fma(x, y, z) (y, x, true)
123155
@scalar_rule muladd(x, y, z) (y, x, true)
124156
@scalar_rule muladd(x::Union{Number, ZeroTangent}, y::Union{Number, ZeroTangent}, z::Union{Number, ZeroTangent}) (y, x, true)

test/rulesets/Base/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,14 @@ end
263263
test_rrule(merge, (; a=1.0), (; b=2.0))
264264
test_rrule(merge, (; a=1.0), (; a=2.0))
265265
end
266+
267+
@testset "hypot(x, y, z, xs...)" begin
268+
for n in (3, 4)
269+
for sig in Iterators.product(ntuple(_->(Float64, ComplexF64), n)...)
270+
args = randn.(sig)
271+
test_frule(hypot, args...)
272+
test_rrule(hypot, args...)
273+
end
274+
end
275+
end
266276
end

0 commit comments

Comments
 (0)