From b5d9a895dbd3bc9ec4eb60f8946e066af78370cb Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 28 Dec 2022 15:11:39 -0500
Subject: [PATCH] fix some setup bugs

---
 src/deprecations.jl | 4 ++--
 src/train.jl        | 5 +++--
 test/train.jl       | 7 +++++++
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/deprecations.jl b/src/deprecations.jl
index 8a445266a4..a763ffd905 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -130,8 +130,8 @@ end
 _old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
   const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
 _old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
-_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh) # called omega, and there are more fields
-_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh) # called delta now, and struct name differs
+_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
+_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
   const ClipGrad = Optimise.ClipValue
 _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
 
diff --git a/src/train.jl b/src/train.jl
index 63d95258b9..a18b1db59e 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -2,7 +2,7 @@ module Train
 
 using LinearAlgebra
 using Optimisers: Optimisers
-using Functors: fmap
+using Functors: fmap, fmapstructure
 
 import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
 
@@ -48,7 +48,8 @@ julia> opt_state # mutated by Flux.train!
 """
 function setup(rule::Optimisers.AbstractRule, model)
   state = Optimisers.setup(rule, model)
-  fmap(model, exclude = Optimisers.isnumeric) do x
+  # This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
+  fmapstructure(model, exclude = Optimisers.isnumeric) do x
     Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
                                        If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
   end
diff --git a/test/train.jl b/test/train.jl
index 310102331e..1d938649d0 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -139,3 +139,10 @@ end
   @test diff1 ≈ diff3
 end
 
+@testset "Flux.setup bugs" begin
+  # https://github.com/FluxML/Flux.jl/issues/2144
+  @test Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf
+  # Typo in 0.13.9's deprecation
+  @test Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad
+  @test Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm
+end
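
Note (not part of the patch above): a rough sketch of the behaviour these changes restore, assuming Flux with this patch and its Optimisers.jl dependency. `fmap` rebuilds the model from whatever the callback returns, so the `true` produced by the mutability check was fed back into strictly-typed layer constructors such as `Embedding`'s, which is roughly what https://github.com/FluxML/Flux.jl/issues/2144 reported; `fmapstructure` only walks the leaves and returns a plain nested structure, so nothing is reconstructed.

    using Flux, Optimisers

    # The case from issue 2144: builds an optimiser-state tree instead of erroring.
    opt_state = Flux.setup(Flux.Adam(), Embedding(3 => 1))
    opt_state.weight isa Optimisers.Leaf   # expected: true

    # The deprecation path now reads `rule.thresh` (previously misspelled `thesh`),
    # so the old clipping rules translate to their Optimisers.jl equivalents:
    Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad   # expected: true
    Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm    # expected: true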