-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: more coverage for common NN operations (#55)
* feat: more coverage for common NN activations * feat: support `mean`. * feat: support `var`. * feat: add overload for `ifelse` * chore: relax compat * test: activation functions and their adjoints * test: `mean` and `var` * test: add BatchNorm to the lux test * fix: update `relu` and `abs2` * fix: dispatch directly on `ifelse` * test: skip Lux tests pre-1.9 * fix: overload scalar ops * refactor: move statistics into extension * refactor: remove more elem_apply * fix: ambiguity error
- Loading branch information
Showing
10 changed files
with
185 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,6 @@ | ||
name = "Reactant" | ||
uuid = "3c362404-f566-11ee-1572-e11a4b42c853" | ||
authors = [ | ||
"William Moses <[email protected]>", | ||
"Valentin Churavy <[email protected]>", | ||
"Sergio Sánchez Ramírez <[email protected]>", | ||
"Paul Berg <[email protected]>", | ||
] | ||
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>"] | ||
version = "0.1.8" | ||
|
||
[deps] | ||
|
@@ -19,11 +14,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" | |
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[extensions] | ||
ReactantAdaptExt = "Adapt" | ||
ReactantArrayInterfaceExt = "ArrayInterface" | ||
ReactantNNlibExt = "NNlib" | ||
ReactantStatisticsExt = "Statistics" | ||
|
||
[compat] | ||
Adapt = "4" | ||
|
@@ -33,6 +30,7 @@ Enzyme = "0.11, 0.12" | |
NNlib = "0.9" | ||
Preferences = "1.4" | ||
Reactant_jll = "0.0.14" | ||
Statistics = "1.9" | ||
julia = "1.9" | ||
|
||
[extras] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
module ReactantStatisticsExt | ||
|
||
using Reactant: TracedRArray | ||
using Statistics: Statistics | ||
|
||
function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N} | ||
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) | ||
return mapreduce(identity, +, A; dims) / denom | ||
end | ||
|
||
function Statistics.var( | ||
A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true | ||
) where {T,Shape,N} | ||
mean === nothing && (mean = Statistics.mean(A; dims)) | ||
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected | ||
return mapreduce(abs2, +, A .- mean; dims) / denom | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters