feat: more coverage for common NN operations (#55)
* feat: more coverage for common NN activations

* feat: support `mean`.

* feat: support `var`.

* feat: add overload for `ifelse`

* chore: relax compat

* test: activation functions and their adjoints

* test: `mean` and `var`

* test: add BatchNorm to the lux test

* fix: update `relu` and `abs2`

* fix: dispatch directly on `ifelse`

* test: skip Lux tests pre-1.9

* fix: overload scalar ops

* refactor: move statistics into extension

* refactor: remove more elem_apply

* fix: ambiguity error
avik-pal authored Aug 5, 2024
1 parent 24152d4 commit 4177c8d
Showing 10 changed files with 185 additions and 39 deletions.
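
The changes summarized above extend Reactant's traced-array coverage to common NN activations, to `Statistics.mean`/`Statistics.var`, and to `ifelse`. As orientation, here is a minimal usage sketch in the style of the test files included in this diff (the exact calls are an assumption modelled on `test/basic.jl` and `test/bcast.jl` below):

using Reactant, NNlib, Statistics

x = randn(Float32, 10, 10)
x_ca = Reactant.ConcreteRArray(x)

# Broadcasting a newly covered activation traces it into stablehlo ops:
gelu_compiled = Reactant.compile(x -> gelu.(x), (x_ca,))
gelu_compiled(x_ca)

# Reductions added in this commit:
mean_compiled = Reactant.compile(x -> mean(x; dims=1), (x_ca,))
mean_compiled(x_ca)
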
10 changes: 4 additions & 6 deletions Project.toml
@@ -1,11 +1,6 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = [
"William Moses <[email protected]>",
"Valentin Churavy <[email protected]>",
"Sergio Sánchez Ramírez <[email protected]>",
"Paul Berg <[email protected]>",
]
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>"]
version = "0.1.8"

[deps]
@@ -19,11 +14,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
ReactantAdaptExt = "Adapt"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"

[compat]
Adapt = "4"
@@ -33,6 +30,7 @@ Enzyme = "0.11, 0.12"
NNlib = "0.9"
Preferences = "1.4"
Reactant_jll = "0.0.14"
Statistics = "1.9"
julia = "1.9"

[extras]
30 changes: 16 additions & 14 deletions ext/ReactantNNlibExt.jl
@@ -3,23 +3,25 @@ module ReactantNNlibExt
using NNlib
using Reactant

for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(
::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
),
)
end
end
for (jlop, hloop) in (
(:(NNlib.tanh_fast), :tanh),
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::Reactant.TracedRArray{T,(),0}) where {T}
return Reactant.TracedRArray{T,(),0}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
)
end
end

NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))

NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)

# TODO handle non finite cases
function NNlib.softmax!(
out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1
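
The `relu` and `gelu` methods above are composed from scalar operations Reactant already traces (`max`, `*`, and the `sigmoid` overload defined in the loop) rather than from dedicated stablehlo ops; `gelu` uses the sigmoid approximation `x * sigmoid(1.702x)`. A plain-Julia sanity check of those scalar identities (no tracing involved; the tolerance is loose because NNlib's default `gelu` is the tanh-based approximation):

using NNlib

x = 0.3f0
max(x, zero(x)) == relu(x)                               # exact identity
isapprox(x * sigmoid(1.702f0 * x), gelu(x); atol=1e-2)   # sigmoid approximation of GELU
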
19 changes: 19 additions & 0 deletions ext/ReactantStatisticsExt.jl
@@ -0,0 +1,19 @@
module ReactantStatisticsExt

using Reactant: TracedRArray
using Statistics: Statistics

function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
end

function Statistics.var(
A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
) where {T,Shape,N}
mean === nothing && (mean = Statistics.mean(A; dims))
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
return mapreduce(abs2, +, A .- mean; dims) / denom
end

end
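
In both methods the `denom` term counts the elements being reduced: the full `length(A)` for `dims=:`, otherwise the product of the sizes along the reduced dimensions (with `corrected` subtracted for `var`). A plain-array sketch of that bookkeeping:

using Statistics

A = randn(2, 3, 4)
dims = (1, 3)
denom = prod(Base.Fix1(size, A), dims)    # size(A, 1) * size(A, 3) == 8
sum(A; dims) ./ denom ≈ mean(A; dims)     # true: same reduction with an explicit denominator
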
4 changes: 2 additions & 2 deletions src/Reactant.jl
@@ -54,13 +54,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy
end

function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(x, to_float(y); kwargs...)
end

function Base.isapprox(
x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs...
) where {ElType,ElType2}
return Base.isapprox(to_float(x), y; kwargs...)
return Base.isapprox(to_float(x), to_float(y); kwargs...)
end

function Base.print_array(io::IO, X::ConcreteRArray)
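
The two corrected methods above make the 0-dimensional comparison symmetric: whichever argument is the `ConcreteRArray` is converted through `to_float`. That is exactly the comparison the `mean`/`var` tests later in this diff perform, where a plain `Float64` is checked against a compiled 0-d result. A minimal sketch of the pattern (assuming the `compile`/`ConcreteRArray` API used throughout this diff):

using Reactant, Statistics

x = randn(2, 3, 4)
x_ca = Reactant.ConcreteRArray(x)
mean_compiled = Reactant.compile(mean, (x_ca,))

isapprox(mean(x), mean_compiled(x_ca))   # Number vs. 0-d ConcreteRArray
isapprox(mean_compiled(x_ca), mean(x))   # same comparison with the arguments swapped
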
75 changes: 64 additions & 11 deletions src/overloads.jl
@@ -59,28 +59,49 @@ for (jlop, hloop, RT) in (
)
end

function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
function $jlop(
lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0}
) where {ElType}
return TracedRArray{$RT,(),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end

function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
for otherType in (Number, Any, TracedRArray{S,(),0} where {S})
@eval begin
function $jlop(
lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType
) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(
lhs::$otherType, rhs::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end
end

Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * conj(x)

function Base.literal_pow(
::Base.RefValue{typeof(^)}, x::Reactant.TracedRArray{T,(),0}, ::Base.RefValue{Val{P}}
) where {T,P}
@@ -137,9 +158,41 @@ for (jlop, hloop, RT) in (
),
)
end

# Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(lhs::Number, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end

function Base.ifelse(
pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0}
) where {T1,T2}
return TracedRArray{promote_type(T1, T2),(),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
),
)
end

function Base.:*(
lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2}
) where {ElType,Shape,Shape2}
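
The explicit `::Number` methods and the comment above exist because Base already defines methods such as `/(::AbstractArray, ::Number)`: a catch-all `(lhs::TracedRArray, rhs)` method is then neither more nor less specific than Base's, so calls mixing a traced scalar with a plain `Number` become ambiguous. A stripped-down illustration of that dispatch problem with a toy type (the names here are purely illustrative, not Reactant's):

struct Scalar0D <: AbstractArray{Float64,0}
    v::Float64
end
Base.size(::Scalar0D) = ()
Base.getindex(s::Scalar0D, I::Int...) = s.v

# A catch-all second argument, like the generic traced overloads:
Base.:(/)(lhs::Scalar0D, rhs) = Scalar0D(lhs.v / Float64(rhs))

# Scalar0D(2.0) / 3.0   # MethodError: ambiguous with Base's /(::AbstractArray, ::Number)

# The dedicated ::Number method resolves the ambiguity:
Base.:(/)(lhs::Scalar0D, rhs::Number) = Scalar0D(lhs.v / Float64(rhs))
Scalar0D(2.0) / 3.0      # now dispatches unambiguously
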
1 change: 1 addition & 0 deletions test/Project.toml
@@ -5,6 +5,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
36 changes: 36 additions & 0 deletions test/basic.jl
@@ -1,6 +1,7 @@
using Reactant
using Test
using Enzyme
using Statistics

# Reactant.set_default_backend("gpu")

@@ -152,3 +153,38 @@ end

@test contains(res_repr, "stablehlo.dot_general")
end

@testset "Statistics: `mean` & `var`" begin
x = randn(2, 3, 4)
x_ca = Reactant.ConcreteRArray(x)

mean_fn1(x) = mean(x)
mean_fn2(x) = mean(x; dims=1)
mean_fn3(x) = mean(x; dims=(1, 2))
mean_fn4(x) = mean(x; dims=(1, 3))

mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,))
mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,))
mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,))
mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,))

@test mean_fn1(x) ≈ mean_fn1_compiled(x_ca)
@test mean_fn2(x) ≈ mean_fn2_compiled(x_ca)
@test mean_fn3(x) ≈ mean_fn3_compiled(x_ca)
@test mean_fn4(x) ≈ mean_fn4_compiled(x_ca)

var_fn1(x) = var(x)
var_fn2(x) = var(x; dims=1)
var_fn3(x) = var(x; dims=(1, 2), corrected=false)
var_fn4(x) = var(x; dims=(1, 3), corrected=false)

var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,))
var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,))
var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,))
var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,))

@test var_fn1(x) ≈ var_fn1_compiled(x_ca)
@test var_fn2(x) ≈ var_fn2_compiled(x_ca)
@test var_fn3(x) ≈ var_fn3_compiled(x_ca)
@test var_fn4(x) ≈ var_fn4_compiled(x_ca)
end
33 changes: 32 additions & 1 deletion test/bcast.jl
@@ -1,6 +1,6 @@

using Reactant

using Enzyme, NNlib
using Reactant.MLIR

@noinline function no(@nospecialize(x))
Expand Down Expand Up @@ -56,3 +56,34 @@ function test()
end
end
test()

@testset "Activation Functions" begin
sumabs2(f, x) = sum(abs2, f.(x))

function ∇sumabs2(f, x)
dx = Enzyme.make_zero(x)
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
return dx
end

x_act = randn(Float32, 10, 10)
x_act_ca = Reactant.ConcreteRArray(x_act)

@testset "Activation: $act" for act in (
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
)
f_compile = Reactant.compile(sumabs2, (act, x_act))

y_simple = sumabs2(act, x_act)
y_compile = f_compile(act, x_act_ca)

∂x_enz = Enzyme.make_zero(x_act)
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))

∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))

∂x_compile = ∇sumabs2_compiled(act, x_act_ca)

@test y_simple ≈ y_compile
end
end
11 changes: 7 additions & 4 deletions test/nn_lux.jl
@@ -9,6 +9,7 @@ truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-ele
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
model = Lux.Chain(
Lux.Dense(2 => 3, tanh), # activation function inside layer
Lux.BatchNorm(3, gelu),
Lux.Dense(3 => 2),
softmax,
)
@@ -17,8 +18,7 @@ ps, st = Lux.setup(Xoshiro(123), model)
using BenchmarkTools

origout, _ = model(noisy, ps, st)
@show origout[3]
@btime model($noisy, $ps, $st) # 52.731 μs (10 allocations: 32.03 KiB)
@btime model($noisy, $ps, $st) # 68.444 μs (46 allocations: 45.88 KiB)

cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
@@ -31,8 +31,9 @@ f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cs
# # @show @code_typed f(cmodel,cnoisy)
# # @show @code_llvm f(cmodel,cnoisy)
comp = f(cmodel, cnoisy, cps, cst)
@show comp[3]
@btime f($cmodel, $cnoisy, $cps, $cst) # 4.430 μs (5 allocations: 160 bytes)
@btime f($cmodel, $cnoisy, $cps, $cst) # 21.790 μs (6 allocations: 224 bytes)

@test comp ≈ origout atol = 1e-5 rtol = 1e-2

# To train the model, we use batches of 64 samples, and one-hot encoding:

@@ -81,6 +82,8 @@ compiled_gradient = Reactant.compile(
gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
)

@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2

# # Training loop, using the whole data set 1000 times:
# losses = []
# for epoch in 1:1_000
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -46,4 +46,7 @@ include("nn.jl")
include("struct.jl")
include("closure.jl")
include("compile.jl")
include("nn_lux.jl")

if VERSION ≥ v"1.10-" # Lux isn't supported on 1.9
include("nn_lux.jl")
end
