Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
version = "1.18.0"
version = "1.19.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 6 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ AutoChainRules
AutoDiffractor
```

### Forward, reverse, or sparse mode

```@docs
AutoReactant{<:AutoEnzyme}
```

### Symbolic mode

```@docs
Expand Down
3 changes: 2 additions & 1 deletion src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ export AutoChainRules,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError
NoAutoDiffSelectedError,
AutoReactant
@public AbstractMode
@public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode
@public mode
Expand Down
39 changes: 39 additions & 0 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,45 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A}
print(io, ")")
end


"""
AutoReactant{M<:AutoEnzyme}

Struct used to select the [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) compilation atop Enzyme for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoReactant(; mode::Union{AutoEnzyme,Nothing}=nothing)

# Fields

- `mode::M` specifies the Enzyme mode of differentiation

+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
+ `nothing` to choose the best mode automatically
"""
struct AutoReactant{M<:AutoEnzyme} <: AbstractADType
mode::M
end

function AutoReactant(;
mode::Union{AutoEnzyme,Nothing} = nothing)
if mode === nothing
mode = AutoEnzyme()
end
return AutoReactant(mode)
end

mode(r::AutoReactant) = mode(r.mode)

function Base.show(io::IO, backend::AutoReactant)
print(io, AutoReactant, "(")
print(io, "mode=", repr(backend.mode; context = io))
print(io, ")")
end

"""
AutoFastDifferentiation

Expand Down
2 changes: 1 addition & 1 deletion src/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ADTypes.AutoZygote()
"""
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)

for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :Mooncake, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
Expand Down
28 changes: 28 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ end
@test ad.mode == EnzymeCore.Reverse
end

@testset "AutoReactant" begin
ad = AutoReactant()
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme}
@test ad.mode isa AutoEnzyme
@test ad.mode.mode === nothing
@test mode(ad) isa ForwardOrReverseMode

ad = AutoReactant(; mode=AutoEnzyme(; mode = EnzymeCore.Forward))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Forward), Nothing}}
@test mode(ad) isa ForwardMode
@test ad.mode.mode == EnzymeCore.Forward

ad = AutoReactant(; mode=AutoEnzyme(; function_annotation = EnzymeCore.Const))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{Nothing, EnzymeCore.Const}}
@test mode(ad) isa ForwardOrReverseMode
@test ad.mode.mode === nothing

ad = AutoReactant(; mode=AutoEnzyme(;
mode = EnzymeCore.Reverse, function_annotation = EnzymeCore.Duplicated))
@test ad isa AbstractADType
@test ad isa AutoReactant{<:AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated}}
@test mode(ad) isa ReverseMode
@test ad.mode.mode == EnzymeCore.Reverse
end

@testset "AutoFastDifferentiation" begin
ad = AutoFastDifferentiation()
@test ad isa AbstractADType
Expand Down
1 change: 1 addition & 0 deletions test/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Test
@test ADTypes.Auto(:Mooncake) isa AutoMooncake
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
@test ADTypes.Auto(:Reactant) isa AutoReactant
@test ADTypes.Auto(:Symbolics) isa AutoSymbolics
@test ADTypes.Auto(:Tapir) isa AutoTapir
@test ADTypes.Auto(:Tracker) isa AutoTracker
Expand Down
Loading