Skip to content

Commit

Permalink
bump Enzyme version in v0.2 (#132)
Browse files Browse the repository at this point in the history
* bump Enzyme version, update Enzyme interface

* drop testing on Julia 1.6 for v0.2

* fix disable testing on Enzyme
  • Loading branch information
Red-Portal authored Oct 24, 2024
1 parent 4af5f82 commit e5e26d0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- os: macOS-latest
arch: x86
include:
- version: '1.6'
- version: '1.10'
os: ubuntu-latest
arch: x64
- os: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ DiffResults = "1"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.11"
Enzyme = "0.13"
LinearAlgebra = "1.6"
ForwardDiff = "0.10.3"
Flux = "0.14"
Expand Down
21 changes: 12 additions & 9 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,22 @@ function AdvancedVI.grad!(
out::DiffResults.MutableDiffResult,
args...
)
f(θ) =
if (q isa Distributions.Distribution)
-vo(alg, AdvancedVI.update(q, θ), model, args...)
else
-vo(alg, q(θ), model, args...)
end
# Use `Enzyme.ReverseWithPrimal` once it is released:
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
f(θ) = if (q isa Distributions.Distribution)
-vo(alg, AdvancedVI.update(q, θ), model, args...)
else
-vo(alg, q(θ), model, args...)
end

y = f(θ)
DiffResults.value!(out, y)
dy = DiffResults.gradient(out)
fill!(dy, 0)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(θ, dy)
)
return out
end

Expand Down
5 changes: 1 addition & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ using ReverseDiff: ReverseDiff
using Tracker: Tracker
using Zygote: Zygote
using Enzyme: Enzyme
Enzyme.API.runtimeActivity!(true);
Enzyme.API.typeWarning!(false);

using AdvancedVI

Expand All @@ -22,7 +20,7 @@ include("optimisers.jl")
AutoReverseDiff(),
AutoTracker(),
AutoZygote(),
# AutoEnzyme() # results in incorrect result
# AutoEnzyme()
]
target = MvNormal(ones(2))
logπ(z) = logpdf(target, z)
Expand All @@ -42,5 +40,4 @@ include("optimisers.jl")

xs = rand(target, 10)
@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) 0.05

end

0 comments on commit e5e26d0

Please sign in to comment.