-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Basic rewrite of the package 2023 edition Part I: ADVI #49
Merged
Red-Portal
merged 213 commits into
TuringLang:master
from
Red-Portal:rewriting_advancedvi_optimize
Dec 8, 2023
Merged
Changes from 175 commits
Commits
Show all changes
213 commits
Select commit
Hold shift + click to select a range
b49cf3e
refactor ADVI, change gradient operation interface
Red-Portal 88e0b79
remove unused file, remove unused dependency
Red-Portal c2fb3f8
fix ADVI elbo computation more efficiently
Red-Portal 83161fd
fix missing entropy regularization term
Red-Portal efa8106
add LogDensityProblem interface
Red-Portal 4ae2fbf
refactor use bijectors directly instead of transformed distributions
Red-Portal 2bf2a42
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal 1cadb51
fix type restrictions
Red-Portal 3474e8d
remove unused file
Red-Portal 03a2767
fix use of with_logabsdet_jacobian
Red-Portal 09c44fb
restructure project; move the main VI routine to its own file
Red-Portal b7407ce
remove redundant import
Red-Portal 4040149
restructure project into more modular objective estimators
Red-Portal 2a4514e
migrate to AbstractDifferentiation
Red-Portal 93a16d8
add location scale pre-packaged variational family, add functors
Red-Portal 2b6e9eb
Revert "migrate to AbstractDifferentiation"
Red-Portal 1bfec36
fix use optimized MvNormal specialization, add logpdf for Loc.Scale.
Red-Portal 1003606
remove dead code
Red-Portal 60a9987
fix location-scale logpdf
Red-Portal cd84f02
add sticking-the-landing (STL) estimator
Red-Portal 768641b
migrate to Optimisers.jl
Red-Portal ca02fa3
remove execution time measurement (replace later with somethin else)
Red-Portal a48377f
fix use multiple dispatch for deciding whether to stop entropy grad.
Red-Portal 0b40ccf
add termination decision, callback arguments
Red-Portal 21db3fb
add Base.show to modules
Red-Portal 25c51b4
add interface calling `restructure`, rename rebuild -> restructure
Red-Portal fc20046
add estimator state interface, add control variate interface to ADVI
Red-Portal 6faa807
fix `show(advi)` to show control variate
Red-Portal 7095d27
fix simplify `show(advi.control_variate)`
Red-Portal 9169ae2
fix type piracy by wrapping location-scale bijected distribution
Red-Portal 3db7301
remove old AdvancedVI custom optimizers
Red-Portal e6a082a
fix Location Scale to not depend on Bijectors
Red-Portal a034ebd
fix RNG namespace
Red-Portal e19abd3
fix location scale logpdf bug
Red-Portal 680c186
add Accessors dependency
Red-Portal 6c3efa8
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal 4c6cabf
add location scale, autodiff tests
Red-Portal 06db2f0
add Accessors import statement
Red-Portal 12de2bd
remove optimiser tests
Red-Portal bbb2cc6
refactor slightly generalize the distribution tests for the future
Red-Portal 1974846
migrate to SimpleUnPack, migrate to ADTypes
Red-Portal 19c62c8
rename vi.jl to optimize.jl
Red-Portal 63da51d
fix estimate_gradient to use adtypes
Red-Portal 65ab473
add exact inference tests
Red-Portal 3e5a452
remove Turing dependency in tests
Red-Portal 3117cec
remove unused projection
Red-Portal b1ca9cf
remove redundant `ADVIEnergy` object (now baked into `ADVI`)
Red-Portal fcbb729
add more tests, fix rng seed for tests
Red-Portal 0f6f6a4
add more tests, fix seed for tests
Red-Portal f5f5863
fix non-determinism bug
Red-Portal ade0d10
fix test hyperparameters so that tests pass, minor cleanups
Red-Portal 0caf7a9
fix minor reorganization
Red-Portal 5658cbf
add missing files
Red-Portal c712a97
fix add missing file, rename adbackend argument
Red-Portal bee839d
fix errors
Red-Portal 913911e
rename test suite
Red-Portal d50cabb
refactor renamed arguments for ADVI to be shorter
Red-Portal b134f70
fix compile error in advi test
Red-Portal a6ba379
add initial doc
Red-Portal 619b1c0
remove unused epsilon argument in location scale
Red-Portal f1c02f0
add project file for documenter
Red-Portal b0f259a
refactor STL gradient calculation to use multiple dispatch
Red-Portal b72c258
fix type bugs, relax test threshold for the exact inference tests
Red-Portal a8df9eb
refactor derivative utils to match NormalizingFlows.jl with extras
Red-Portal e8db6a7
add documentation, refactor optimize
Red-Portal 65a2b37
fix bug missing extension
Red-Portal 1a02051
remove tracker from tests
Red-Portal d8b5ea5
remove export for internal derivative utils
Red-Portal 818bc2c
fix test errors, old interface
Red-Portal 215abf3
fix wrong derivative interface, add documentation
Red-Portal 88ad768
update documentation
Red-Portal e66935b
add doc build CI
Red-Portal 9f1c647
remove convergence criterion for now
Red-Portal c8b3ee3
remove outdated export
Red-Portal afda1a1
update documentation
Red-Portal 0d37ace
update documentation
Red-Portal b8b113d
update documentation
Red-Portal b78e713
fix type error in test
Red-Portal a0564b5
remove default ADType argument
Red-Portal 3795d1e
update README
Red-Portal 28a35bc
update make getting started example actually run Julia
Red-Portal 620b38e
fix remove Float32 tests for inference tests
Red-Portal fa53398
update version
Red-Portal e909f41
add documentation publishing url
Red-Portal 43f5b75
fix wrong uuid for ForwardDiff
Red-Portal 468d5ca
Update CI.yml
yebai c07a511
refactor use `sum` and `mean` instead of abusing `mapreduce`
Red-Portal 8256df1
Merge branch 'rewriting_advancedvi' of github.com:Red-Portal/Advanced…
Red-Portal 13a8a44
remove tests for `FullMonteCarlo`
Red-Portal aadf8d3
add tests for the `optimize` interface
Red-Portal 8c4e13d
fix turn off Zygote tests for now
Red-Portal 0b708e6
remove unused function
Red-Portal be61acd
refactor change bijector field name, simplify STL estimator
Red-Portal fb519a5
update documentation
Red-Portal 8682fd9
update STL documentation
Red-Portal 9a16ee1
update STL documentation
Red-Portal fc74afa
update location scale documentation
Red-Portal 4be30a1
fix README
Red-Portal c58309d
fix math in README
Red-Portal 5b5bd3e
add gradient to arguments of callback!, remove `gradient_norm` info
Red-Portal 967021d
fix math in README.md
Red-Portal 4dab522
fix type constraint in `ZygoteExt`
Red-Portal 8ab2f19
fix import of `Random`
Red-Portal 83dec9f
refactor `__init__()`
Red-Portal a3e563c
fix type constraint in definition of `value_and_gradient!`
Red-Portal 5553bb9
refactor `ZygoteExt`; use `only` instead of `first`
Red-Portal 79b4557
refactor type constraint in `ReverseDiffExt`
Red-Portal 656b44b
refactor remove outdated debug mode macro
Red-Portal c794063
fix remove outdated DEBUG mechanism
Red-Portal 0c5cc1c
fix LaTeX in README: `operatorname` is currently broken
Red-Portal 29d7d27
remove `SimpleUnPack` dependency
Red-Portal 75eef44
fix LaTeX in docs and README
Red-Portal 40574f4
add warning about forward-mode AD when using `LocationScale`
Red-Portal 8738256
fix documentation
Red-Portal 8173744
fix remove reamining use of `@unpack`
Red-Portal e0548ae
Revert "remove `SimpleUnPack` dependency"
Red-Portal 6ab95a0
Revert "fix remove reamining use of `@unpack`"
Red-Portal f0ec242
fix documentation for `optimize`
Red-Portal 1d4c1b6
add specializations of `Optimise.destructure` for mean-field
Red-Portal 231835f
add test for `Optimisers.destructure` specializations
Red-Portal ea2d426
add specialization of `rand` for meanfield resulting in faster AD
Red-Portal 3033d75
add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian`
Red-Portal 0cc36c0
update documentation
Red-Portal b7d3471
fix type instability, bug in argument check in `LocationScale`
Red-Portal df50e83
add missing import bug
Red-Portal ae3e9b0
refactor test, fix type bug in tests for `LocationScale`
Red-Portal e4002cf
add missing compat entries
Red-Portal 8c82569
fix missing package import in test
Red-Portal c2e7517
add additional tests for sampling `LocationScale`
Red-Portal 3a6f8bf
fix bug in batch in-place `rand!` for `LocationScale`
Red-Portal b78ef4b
fix bug in inference test initialization
Red-Portal a1f7e98
add missing file
Red-Portal 8b783ec
fix remove use of for 1.6
Red-Portal 12cd9f2
refactor adjust inference test hyperparameters to be more robust
Red-Portal 837c729
refactor `optimize` to return `obj_state`, add warm start kwargs
Red-Portal 95629a5
refactor make tests more robust, reduce amount of tests
Red-Portal 0b4b865
fix remove a cholesky in test model
Red-Portal b49f4eb
fix compat bounds, remove unused package
Red-Portal 947a070
bump compat for ADTypes 0.2
Red-Portal a9b3f48
fix broken LaTeX in README
Red-Portal 54826eb
remove redundant use of PDMats in docs
Red-Portal 1d1c8ff
fix use `Cholesky` signature supported in 1.6
Red-Portal 7bac95b
revert custom variational families and docs
Red-Portal d2ae29f
remove doc action for now
Red-Portal fb84e3d
revert README for now
Red-Portal 0575b23
refactor remove redundant `rng` argument to `ADVI`, improve docs
Red-Portal ecc5242
fix wrong whitespace in tests
Red-Portal 1cff3df
refactor `estimate_gradient` to `estimate_gradient!`, add docs
Red-Portal 54acd8a
refactor add default `init` impl, update docs
Red-Portal 61a2272
merge (manually) commit ff32ac642d6aa3a08d371ed895aa6b4026b06b92
Red-Portal c56d29e
fix test for new interface, change interface for `optimize`, `advi`
Red-Portal 913b469
fix integer subtype error in documentation of advi
Red-Portal 385a653
fix remove redundant argument for `advi`
Red-Portal 4716b62
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal c9df90e
remove manifest
Red-Portal 19d11d1
refactor remove imports and use fully qualified names
Red-Portal 59bd4f8
update documentation for `AbstractVariationalObjective`
Red-Portal dedc5cf
refactor use StableRNG instead of Random123
Red-Portal e35dc67
refactor migrate to Test, re-enable x86 tests
Red-Portal 6413183
refactor remove inner constructor for `ADVI`
Red-Portal 1668bae
fix swap `export`s and `include`s
Red-Portal a8f1254
fix doscs for `ADVI`
Red-Portal 7b368c1
fix use `FillArrays` in the test problems
Red-Portal f216b37
fix `optimize` docs
Red-Portal 9e0338d
fix improve argument names and docs for `optimize`
Red-Portal d6fcaf6
fix tests to match new interface of `optimize`
Red-Portal 5799f1e
refactor move utility functions to new file
Red-Portal 2229d61
fix docs for `optimize`
Red-Portal bc48e14
refactor advi internal objective
Red-Portal 9949a04
refactor move `rng` to be an optional first argument
Red-Portal 81010cd
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal 92cf354
fix docs for optimize
Red-Portal d75fd3c
add compat bounds to test dependencies
Red-Portal faa91ce
update compat bound for `Optimisers`
Red-Portal 6dc0bb7
fix test compat
Red-Portal e941ad4
fix remove `!` in callback
Red-Portal 15e0553
fix rng argument position in `advi`
Red-Portal a643cf2
fix callback signature in `optimize`
Red-Portal ffa69a3
refactor reorganize test files and naming
Red-Portal d5026e1
fix simplify description for `optimize`
Red-Portal 764406b
fix remove redundant `Nothing` type signature for `maybe_init`
Red-Portal 65006cb
fix remove "internal use" warning in documentation
Red-Portal b23a610
refactor change `estimate_gradient!` signature to be type stable
Red-Portal 6c6634f
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal 9c242a5
add signature for computing `advi` over a fixed set of samples
Red-Portal e014863
fix change test tolerance
Red-Portal 71184fa
fix update documentation for `estimate_gradient!`
Red-Portal 9f6d663
refactor remove type constraint for variational parameters
Red-Portal a673520
fix remove dead code
Red-Portal a3f9886
add compat entry for stdlib
Red-Portal 7a92708
add compat entry for stdlib in `test/`
Red-Portal 5dd434d
fix rng argument position in tests
Red-Portal a764d9b
refactor change name of inference test
Red-Portal 8af8a5f
fix documentation for `optimize`
Red-Portal 5f1fb52
refactor rewrite the documentation for the global interfaces
Red-Portal 2491c64
fix compat error
Red-Portal 92d1489
fix documentation for `optimize` to be single line
Red-Portal a03e955
refactor remove begin end for one-liner
Red-Portal ff83c03
refactor create unified interface for estimating objectives
Red-Portal aecc655
refactor unify interface for entropy estimator, fix advi docs
Red-Portal a8d532a
fix STL estimator to use manually stopped gradients instead
Red-Portal 65e9b12
add inference test for a non-bijector model
Red-Portal 3691f16
refactor add indirections to handle STL and bijectors in ADVI
Red-Portal a063583
refactor split inference tests for advi+distributionsad
Red-Portal 316b629
refactor rename advi to repgradelbo and not use bijectors directly
Red-Portal 13b2088
fix documentation for estimate_objective
Red-Portal b0e1be1
refactor add indirection in repgradelbo for interacting with `q`
Red-Portal 7361ed4
add TransformedDistribution support as extension
Red-Portal d2e7614
Update src/objectives/elbo/repgradelbo.jl
Red-Portal 77686b5
fix docstring for entropy estimator
Red-Portal 8461b43
fix `reparam_with_entropy` specialization for bijectors
Red-Portal 8c559e3
Merge branch 'rewriting_advancedvi_optimize' of github.com:Red-Portal…
Red-Portal bd925cc
enable Zygote for non-bijector tests
Red-Portal File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,67 @@ | ||
name = "AdvancedVI" | ||
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" | ||
version = "0.2.4" | ||
version = "0.3.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" | ||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
|
||
[weakdeps] | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[extensions] | ||
AdvancedVIEnzymeExt = "Enzyme" | ||
AdvancedVIForwardDiffExt = "ForwardDiff" | ||
AdvancedVIReverseDiffExt = "ReverseDiff" | ||
AdvancedVIZygoteExt = "Zygote" | ||
|
||
[compat] | ||
Bijectors = "0.11, 0.12, 0.13" | ||
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" | ||
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" | ||
ADTypes = "0.1, 0.2" | ||
Accessors = "0.1" | ||
Bijectors = "0.12, 0.13" | ||
ChainRulesCore = "1.16" | ||
DiffResults = "1" | ||
Distributions = "0.25.87" | ||
DocStringExtensions = "0.8, 0.9" | ||
ForwardDiff = "0.10.3" | ||
ProgressMeter = "1.0.0" | ||
Requires = "0.5, 1.0" | ||
Enzyme = "0.11.7" | ||
FillArrays = "1.3" | ||
ForwardDiff = "0.10.36" | ||
Functors = "0.4" | ||
LogDensityProblems = "2" | ||
Optimisers = "0.2.16, 0.3" | ||
ProgressMeter = "1.6" | ||
Requires = "1.0" | ||
ReverseDiff = "1.15.1" | ||
SimpleUnPack = "1.1.0" | ||
StatsBase = "0.32, 0.33, 0.34" | ||
StatsFuns = "0.8, 0.9, 1" | ||
Tracker = "0.2.3" | ||
Zygote = "0.6.63" | ||
julia = "1.6" | ||
|
||
[extras] | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
|
||
[targets] | ||
test = ["Pkg", "Test"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
module AdvancedVIEnzymeExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using Enzyme | ||
using AdvancedVI | ||
using AdvancedVI: ADTypes, DiffResults | ||
else | ||
using ..Enzyme | ||
using ..AdvancedVI | ||
using ..AdvancedVI: ADTypes, DiffResults | ||
end | ||
|
||
# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) | ||
function AdvancedVI.value_and_gradient!( | ||
ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult | ||
) where {T<:Real} | ||
y = f(θ) | ||
DiffResults.value!(out, y) | ||
∇θ = DiffResults.gradient(out) | ||
fill!(∇θ, zero(T)) | ||
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) | ||
return out | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
|
||
module AdvancedVIForwardDiffExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using ForwardDiff | ||
using AdvancedVI | ||
using AdvancedVI: ADTypes, DiffResults | ||
else | ||
using ..ForwardDiff | ||
using ..AdvancedVI | ||
using ..AdvancedVI: ADTypes, DiffResults | ||
end | ||
|
||
getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize | ||
|
||
function AdvancedVI.value_and_gradient!( | ||
ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult | ||
) where {T<:Real} | ||
chunk_size = getchunksize(ad) | ||
config = if isnothing(chunk_size) | ||
ForwardDiff.GradientConfig(f, θ) | ||
else | ||
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) | ||
end | ||
ForwardDiff.gradient!(out, f, θ, config) | ||
return out | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
|
||
module AdvancedVIReverseDiffExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using AdvancedVI | ||
using AdvancedVI: ADTypes, DiffResults | ||
using ReverseDiff | ||
else | ||
using ..AdvancedVI | ||
using ..AdvancedVI: ADTypes, DiffResults | ||
using ..ReverseDiff | ||
end | ||
|
||
# ReverseDiff without compiled tape | ||
function AdvancedVI.value_and_gradient!( | ||
ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult | ||
) | ||
tp = ReverseDiff.GradientTape(f, θ) | ||
ReverseDiff.gradient!(out, tp, θ) | ||
return out | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
module AdvancedVIZygoteExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using AdvancedVI | ||
using AdvancedVI: ADTypes, DiffResults | ||
using Zygote | ||
else | ||
using ..AdvancedVI | ||
using ..AdvancedVI: ADTypes, DiffResults | ||
using ..Zygote | ||
end | ||
|
||
function AdvancedVI.value_and_gradient!( | ||
ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult | ||
) | ||
y, back = Zygote.pullback(f, θ) | ||
∇θ = back(one(y)) | ||
DiffResults.value!(out, y) | ||
DiffResults.gradient!(out, only(∇θ)) | ||
return out | ||
end | ||
|
||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we not handle compiled tape?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll look into it.