Commit
fix some setup bugs (#2145)
mcabbott committed Dec 28, 2022
1 parent ba48ad0 commit 4da339e
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/deprecations.jl
@@ -130,8 +130,8 @@ end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
-_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thesh) # called omega, and there are more fields
-_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thesh) # called delta now, and struct name differs
+_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
+_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
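The two replaced lines fix a typo in these deprecation shims: the old Optimise.ClipNorm and Optimise.ClipValue rules store their threshold in a field named `thresh`, so reading `rule.thesh` raised a field-access error whenever `Flux.setup` was handed one of them. A minimal REPL sketch of the repaired path, assuming Optimisers.jl is available in the environment (it is a Flux dependency):

julia> using Flux, Optimisers

julia> st = Flux.setup(Flux.ClipValue(1), Dense(2 => 3));

julia> st.weight.rule isa Optimisers.ClipGrad  # the old ClipValue rule maps to the new ClipGrad
true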

5 changes: 3 additions & 2 deletions src/train.jl
@@ -2,7 +2,7 @@ module Train

using LinearAlgebra
using Optimisers: Optimisers
-using Functors: fmap
+using Functors: fmap, fmapstructure

import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions

@@ -48,7 +48,8 @@ julia> opt_state # mutated by Flux.train!
"""
function setup(rule::Optimisers.AbstractRule, model)
state = Optimisers.setup(rule, model)
-fmap(model, exclude = Optimisers.isnumeric) do x
+# This check only needs foreach; using fmap caused https://github.com/FluxML/Flux.jl/issues/2144
+fmapstructure(model, exclude = Optimisers.isnumeric) do x
Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""")
end
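The swap from `fmap` to `fmapstructure` matters because this mutability check only needs to visit each numeric leaf: `fmap` also rebuilds the model from the mapped results, and that reconstruction is what broke for layers such as `Embedding` in issue 2144, whereas `fmapstructure` collects the results into a plain nested NamedTuple and never calls the layer constructors. A standalone sketch of the same check, where the name `check_mutable` is illustrative and not part of Flux:

using Functors: fmapstructure
using Optimisers: Optimisers

function check_mutable(model)
    # Visit every numeric array leaf of the model and error if any is immutable.
    # fmapstructure only records the results in a NamedTuple tree, so it never
    # tries to re-construct the model's own layer types the way fmap does.
    fmapstructure(model; exclude = Optimisers.isnumeric) do x
        Optimisers.maywrite(x) ||
            error("model must be fully mutable for `train!`, got x::$(typeof(x))")
    end
    return nothing
end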
7 changes: 7 additions & 0 deletions test/train.jl
@@ -139,3 +139,10 @@ end
@test diff1 ≈ diff3
end

@testset "Flux.setup bugs" begin
# https://github.com/FluxML/Flux.jl/issues/2144
@test Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf
# Typo in 0.13.9's deprecation
@test Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad
@test Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm
end
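The first new test mirrors the original report: with the `fmapstructure` change, `Flux.setup` over an `Embedding` layer once again returns an ordinary optimiser state tree. A REPL sketch, again assuming Optimisers.jl is available in the environment:

julia> using Flux, Optimisers

julia> Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf
true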
