
Transfer learning ResNet #395

Merged
merged 10 commits into FluxML:master on Feb 24, 2023

Conversation

jeremiedb
Contributor

@jeremiedb jeremiedb commented Feb 16, 2023

I am not too sure what the desired target format for the tutorials is, notably since I noticed that associated tutorials like the 60-Minute Blitz are no longer directly visible on https://fluxml.ai/.

However, I thought it would still be relevant to have blog-post-style tutorials, similar to the DataLoader one: https://github.com/FluxML/model-zoo/tree/master/tutorials/dataloader.

As such, I made a complete transfer learning tutorial for vision in that format, which along the way shows how to use a custom data container and data augmentation, both of which seem fairly commonly asked for (a minimal sketch of the data container pattern follows this description).

If such a format is deemed of interest, I'll add some discussion elements to enrich the explanation of each of the code blocks.

I'd also suggest replacing the existing transfer_learning.jl and dataloader.jl with a single self-contained script that allows running the tutorial presented in the README directly.
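
For context, the custom data container pattern boils down to defining `length` and `getindex` for a small wrapper type. The sketch below is illustrative only: `ImageContainer` and its `img` field follow the tutorial, but the `getindex` body is deliberately simplified (the tutorial version also applies augmentation and derives the label from the path):

```julia
import Base: length, getindex
using Images

# Wrap a vector of image paths; MLUtils treats any type
# implementing length and getindex as a data container.
struct ImageContainer{T<:Vector}
    img::T
end

length(data::ImageContainer) = length(data.img)

# Simplified: only load the image; the tutorial also applies
# augmentation and looks up the label here.
getindex(data::ImageContainer, idx::Int) = Images.load(data.img[idx])
```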

fix typos and add context
fix typos and add context
fix typos and add context
@jeremiedb changed the title from [WIP] transfer learning ResNet to Transfer learning ResNet on Feb 16, 2023
@jeremiedb
Contributor Author

This would be ready for review.

tutorials/transfer_learning/README.md (outdated)
Comment on lines 82 to 91
```julia
function getindex(data::ImageContainer, idx::Int)
    path = data.img[idx]
    img = Images.load(path)
    img = apply(tfm, Image(img))
    img = permutedims(channelview(RGB.(itemdata(img))), (3, 2, 1))
    img = Float32.(img)
    name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2")
    y = name_to_idx[name]
    return img, Flux.onehotbatch(y, 1:3)
end
```
Member

Instead of applying the transformation in getindex, it would be better to showcase MLUtils.mapobs. The advantage is that it is a pattern that can be used with any dataset. An example with MNIST is given in the README here.
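
For illustration, the mapobs pattern on a small in-memory dataset looks roughly like this (the container and transform below are made up for the example, not taken from the tutorial):

```julia
using MLUtils

# Hypothetical raw data: 100 flat feature vectors with integer labels in 1:3
raw = (features = [rand(Float32, 784) for _ in 1:100], labels = rand(1:3, 100))

# mapobs applies the transform lazily, one observation at a time,
# so it composes with any container implementing getobs/numobs.
transform(obs) = (x = obs.features, y = Float32(obs.labels))

dataset = mapobs(transform, raw)
x1, y1 = getobs(dataset, 1)  # the transform runs here, not at construction time
```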

Contributor Author

Is what you suggest having a separation between a minimal getindex and a transform function? So something like:

```julia
function getindex(data::ImageContainer, idx::Int)
    path = data.img[idx]
    x = Images.load(path)
    name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2")
    y = name_to_idx[name]
    return (x = x, y = y)
end

function img_transform(batch)
    img = apply(tfm, Image(batch[:x]))
    img = permutedims(channelview(RGB.(itemdata(img))), (3, 2, 1))
    x = Float32.(img)
    y = Flux.onehotbatch(batch[:y], 1:3)
    return (x, y)
end

dtrain = Flux.mapobs(img_transform, ImageContainer(imgs[1:2700]))
```

That's out of scope for this PR, but I benchmarked to validate there wasn't any performance difference: both approaches perform essentially the same, but performance degrades following each iteration:

```julia
using BenchmarkTools

function data_loop(data)
    count = 0
    for (x, y) in data
        count += size(y, 2)
    end
    @info count
end

@btime data_loop($dtrain)
```

```
2.778 s (505360 allocations: 6.86 GiB)
4.000 s (505459 allocations: 6.86 GiB)
5.303 s (505461 allocations: 6.86 GiB)
5.423 s (505424 allocations: 6.86 GiB)
```

That's with MLUtils 2.11, so it may be ignored until the tutorial can be updated to the latest Flux/MLUtils versions.

Member

> Is what you suggest having a separation between a minimal getindex and a transform function

yes

> but performance degrades following each iteration

This is awful; if you have some time, please open an issue on MLUtils.jl.

For this PR, do as you prefer regarding the use of mapobs or not.

Comment on lines 220 to 235
```julia
function train_epoch!(m_infer, m_tune; ps, opt, dtrain)
    for (x, y) in dtrain
        infer = m_infer(x)
        grads = gradient(ps) do
            Flux.Losses.logitcrossentropy(m_tune(infer), y)
        end
        update!(opt, ps, grads)
    end
end
```

```julia
ps = Flux.params(m_tune);
opt = Adam(3e-4)
```
Member

Let's use the new Optimisers.jl interface instead of the params-based one.
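
Roughly, the explicit-gradient / Optimisers.jl version of the quoted loop could look like the sketch below (reusing the tutorial's m_infer, m_tune and dtrain; the exact form in the updated tutorial may differ):

```julia
using Flux

# Per-model optimiser state replaces the implicit Flux.params collection
opt_state = Flux.setup(Adam(3e-4), m_tune)

function train_epoch!(m_infer, m_tune, opt_state, dtrain)
    for (x, y) in dtrain
        infer = m_infer(x)
        # Gradient taken explicitly with respect to the model being tuned
        grads = Flux.gradient(m_tune) do m
            Flux.Losses.logitcrossentropy(m(infer), y)
        end
        Flux.update!(opt_state, m_tune, grads[1])
    end
end
```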

Contributor Author

Metalhead compat forces the use of Flux v0.13.4, for which my understanding was that the new Optimisers.jl interface wasn't yet in place.
There were some pending issues with weights import in Metalhead which, once fixed, should allow bumping the Flux compat, and I'd migrate the tutorial to Optimisers/explicit gradients once that is done.

Member

A new Metalhead version just got released.

Contributor Author

Updating to the latest Flux brought some new challenges with MLUtils!

Creating the DataLoaders with parallel=true results in a stalled / hanging Julia session, at least on Windows: https://github.com/jeremiedb/model-zoo/blob/51e6a091aa9d790d3162cfe8a106ad45aa0beac1/tutorials/transfer_learning/transfer_learning.jl#L35-L56
This occurs when running manually from the REPL, but works fine when launched as a script. It looks like it might be related to JuliaML/MLUtils.jl#142.

Another "annoyance" is the need to have a collect in the getobs recipe, as mentionned in JuliaML/MLUtils.jl#139, as it seems opposite the the typical Julia pattern where explicit collect isn't typically needed.

Would you be fine moving forward with the new Flux/MLUtils/explicit gradient approach, considering the above few caveats? I'd be fine going forward with the latest versions and getting these few gotchas ironed out.
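
For reference, the DataLoader construction that triggers the hang is along these lines (the batchsize and other keyword values here are illustrative; the exact call is in the linked script):

```julia
using MLUtils

# parallel = true spawns worker tasks for data loading; this is the setting
# reported to hang in an interactive REPL session on Windows, while the same
# code runs fine when launched as a script.
dtrain = DataLoader(
    ImageContainer(imgs[1:2700]);
    batchsize = 16,
    collate = true,
    parallel = true,
)
```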

Contributor Author

The PR has been updated to use explicit gradients.

Contributor Author

@CarloLucibello Do you feel this explicit gradient refresh is robust enough to go forward, or should we first figure out JuliaML/MLUtils.jl#148 and the other caveats?

@CarloLucibello merged commit 4778b97 into FluxML:master on Feb 24, 2023
mcabbott added a commit that referenced this pull request Feb 24, 2023
v0.13 + as it does not use implicit parameters