From 4778b9769acad8d01217883bf6bd80e1b48944c3 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Thu, 23 Feb 2023 20:22:49 -0700 Subject: [PATCH] Transfer learning ResNet (#395) * WIP transfer learning ResNet * quick fixes * cleanup repo fix typos and add context * cleanup repo fix typos and add context * cleanup repo fix typos and add context * change data section title * typo * explicit gradients * explicit gradients script * explicit gradients tutorial --- tutorials/transfer_learning/.gitignore | 1 + tutorials/transfer_learning/Manifest.toml | 288 ++++++++++++------ tutorials/transfer_learning/Project.toml | 4 +- tutorials/transfer_learning/README.md | 252 +++++++++++++++ tutorials/transfer_learning/dataloader.jl | 52 ---- .../transfer_learning/transfer_learning.jl | 189 +++++++----- 6 files changed, 566 insertions(+), 220 deletions(-) create mode 100644 tutorials/transfer_learning/.gitignore create mode 100644 tutorials/transfer_learning/README.md delete mode 100644 tutorials/transfer_learning/dataloader.jl diff --git a/tutorials/transfer_learning/.gitignore b/tutorials/transfer_learning/.gitignore new file mode 100644 index 00000000..adbb97d2 --- /dev/null +++ b/tutorials/transfer_learning/.gitignore @@ -0,0 +1 @@ +data/ \ No newline at end of file diff --git a/tutorials/transfer_learning/Manifest.toml b/tutorials/transfer_learning/Manifest.toml index be1b5735..32b32ccc 100644 --- a/tutorials/transfer_learning/Manifest.toml +++ b/tutorials/transfer_learning/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "b0ad532d6c3d60bd11582fc3df005951e8cf6f5d" +project_hash = "1b0c9d969899e62a3b07dc1cd1fd4fb3525b10c3" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -12,9 +12,9 @@ version = "1.2.1" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] -git-tree-sha1 = "f3d4132fa63a6c62ab19b5765daf87ce2d36076c" +git-tree-sha1 = "4a98a9491dd44348664c371998a75074a6938145" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.25" +version = "0.1.27" [[deps.Adapt]] deps = ["LinearAlgebra"] @@ -37,18 +37,6 @@ git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" uuid = "ec485272-7323-5ecc-a04f-4719b315124d" version = "0.2.0" -[[deps.ArrayInterface]] -deps = ["ArrayInterfaceCore", "Compat", "IfElse", "LinearAlgebra", "SnoopPrecompile", "Static"] -git-tree-sha1 = "dedc16cbdd1d32bead4617d27572f582216ccf23" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "6.0.25" - -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -66,9 +54,9 @@ version = "0.4.6" [[deps.BFloat16s]] deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.2.0" +version = "0.4.2" [[deps.BSON]] git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2" @@ -95,10 +83,40 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", 
"Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "6717cb9a3425ebb7b31ca4f832823615d175f64a" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"] +git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.13.1" +version = "4.0.1" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.2.0+0" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.1.1" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.2.3+2" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "57011df4fce448828165e566af9befa2ea94350a" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.6.0+3" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" [[deps.CatIndices]] deps = ["CustomUnitRanges", "OffsetArrays"] @@ -108,9 +126,9 @@ version = "0.2.2" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a" +git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.46.1" +version = "1.47.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -120,9 +138,9 @@ version = "1.15.7" [[deps.ChangesOfVariables]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4" +git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.5" +version = "0.1.6" [[deps.Clustering]] deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] @@ -130,6 +148,12 @@ git-tree-sha1 = "64df3da1d2a26f4de23871cd1b6482bb68092bd5" uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" version = "0.14.3" +[[deps.ColorBlendModes]] +deps = ["ColorTypes", "FixedPointNumbers"] +git-tree-sha1 = "9ec825436862d5ab02ad8f2cde72f2d860151fa6" +uuid = "60508b50-96e1-4007-9d6c-f475c410f16b" +version = "0.2.4" + [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" @@ -156,9 +180,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a" +git-tree-sha1 = "61fdd77467a5c3ad071ef8277ac6bd6af7dd4c04" uuid = 
"34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.5.0" +version = "4.6.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] @@ -203,6 +227,12 @@ git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.14.0" +[[deps.DataAugmentation]] +deps = ["ColorBlendModes", "CoordinateTransformations", "Distributions", "ImageCore", "ImageDraw", "ImageTransformations", "IndirectArrays", "Interpolations", "LinearAlgebra", "MosaicViews", "OffsetArrays", "Parameters", "Random", "Rotations", "Setfield", "StaticArrays", "Statistics", "Test"] +git-tree-sha1 = "9073179282095c1ce9590e19525ad864a7b83211" +uuid = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" +version = "0.2.11" + [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" @@ -227,6 +257,12 @@ version = "0.1.2" deps = ["Mmap"] uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + [[deps.DiffResults]] deps = ["StaticArraysCore"] git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" @@ -249,6 +285,12 @@ version = "0.10.7" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[deps.Distributions]] +deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "74911ad88921455c6afcad1eefa12bd7b1724631" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.80" + [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -260,6 +302,12 @@ deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + [[deps.ExprTools]] git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" @@ -317,10 +365,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" [[deps.Flux]] -deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] -git-tree-sha1 = "96dc065bf4a998e8adeebc0ff1302902b6e59362" +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"] +git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.4" +version = "0.13.13" [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] @@ -340,9 +388,10 @@ uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" version = "1.1.3" [[deps.Functors]] -git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4" +deps = ["LinearAlgebra"] +git-tree-sha1 = 
"7ed0833a55979d3d2658a60b901469748a6b9a7c" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.2.8" +version = "0.4.3" [[deps.Future]] deps = ["Random"] @@ -350,21 +399,27 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "4dfaff044eb2ce11a897fecd85538310e60b91e6" +git-tree-sha1 = "a28f752ffab0ccd6660fc7af5ad1c9ad176f45f7" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.6.2" +version = "8.6.3" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "57f7cde02d7a53c9d1d28443b9f11ac5fbe7ebc9" +git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.3" +version = "0.1.4" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "48832a7cacbe56e591a7bef690c78b9d00bcc692" +git-tree-sha1 = "95185985a5d2388c6d0fedb06181ad4ddd40e0cb" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.17.1" +version = "0.17.2" + +[[deps.Ghostscript_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "43ba3d3c82c18d88471cfd2924931658838c9d8f" +uuid = "61579ee1-b43e-5ca0-a5da-69d92c66a64b" +version = "9.55.0+4" [[deps.Graphics]] deps = ["Colors", "LinearAlgebra", "NaNMath"] @@ -374,20 +429,21 @@ version = "1.1.2" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ba2d094a88b6b287bd25cfa86f301e7693ffae2f" +git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.7.4" +version = "1.8.0" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"] +git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.11" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "2e99184fca5eb6f075944b04c22edec29beb4778" +git-tree-sha1 = "2af2fe19f0d5799311a6491267a14817ad9fbd20" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.7" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" +version = "0.4.8" [[deps.ImageAxes]] deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] @@ -419,6 +475,12 @@ git-tree-sha1 = "b1798a4a6b9aafb530f8f0c4a7b2eb5501e2f2a3" uuid = "51556ac3-7006-55f5-8cb3-34580c88182d" version = "0.2.16" +[[deps.ImageDraw]] +deps = ["Distances", "ImageCore", "LinearAlgebra"] +git-tree-sha1 = "6ed6e945d909f87c3013e391dcd3b2a56e48b331" +uuid = "4381153b-2b60-58ae-a1ba-fd683676385f" +version = "0.2.5" + [[deps.ImageFiltering]] deps = ["CatIndices", "ComputationalResources", "DataStructures", "FFTViews", "FFTW", "ImageBase", "ImageCore", "LinearAlgebra", "OffsetArrays", "Reexport", "SnoopPrecompile", "SparseArrays", "StaticArrays", "Statistics", "TiledIteration"] git-tree-sha1 = "f265e53558fbbf23e0d54e4fab7106c0f2a9e576" @@ -432,16 +494,16 @@ uuid = "82e4d734-157c-48bb-816b-45c225c6df19" version = "0.6.6" [[deps.ImageMagick]] -deps = ["FileIO", "ImageCore", "ImageMagick_jll", "InteractiveUtils"] -git-tree-sha1 = "ca8d917903e7a1126b6583a097c5cb7a0bedeac1" +deps = ["FileIO", "ImageCore", 
"ImageMagick_jll", "InteractiveUtils", "Libdl", "Pkg", "Random"] +git-tree-sha1 = "5bc1cb62e0c5f1005868358db0692c994c3a13c6" uuid = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" -version = "1.2.2" +version = "1.2.1" [[deps.ImageMagick_jll]] -deps = ["JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pkg", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "1c0a2295cca535fabaf2029062912591e9b61987" +deps = ["Artifacts", "Ghostscript_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pkg", "Zlib_jll", "libpng_jll"] +git-tree-sha1 = "124626988534986113cfd876e3093e4a03890f58" uuid = "c73af94c-d91f-53ed-93a7-00f77d67a9d7" -version = "6.9.10-12+3" +version = "6.9.12+3" [[deps.ImageMetadata]] deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] @@ -523,10 +585,10 @@ deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.Interpolations]] -deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "721ec2cf720536ad005cb38f50dbba7b02419a15" +deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] +git-tree-sha1 = "b7bc05649af456efc75d178846f47006c2c4c3c7" uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.14.7" +version = "0.13.6" [[deps.IntervalSets]] deps = ["Dates", "Random", "Statistics"] @@ -569,15 +631,15 @@ version = "1.4.1" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "a77b273f1ddec645d1b7c4fd5fb98c8f90ad10a5" +git-tree-sha1 = "106b6aa272f294ba47e96bd3acbabdc0407b5c60" uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.1" +version = "0.1.2" [[deps.JpegTurbo_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "2.1.2+0" +version = "2.1.91+0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -593,15 +655,15 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "b8ae281340f0d3e973aae7b96fb7502b0119b376" +git-tree-sha1 = "df115c31f5c163697eede495918d8e85045c8f04" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.15.0" +version = "4.16.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.16+0" +version = "0.0.16+2" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -646,9 +708,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "45b288af6956e67e621c5cbb2d75a261ab58300b" +git-tree-sha1 = "071602a0be5af779066df0d7ef4e14945a010818" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.20" +version = "0.3.22" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -660,15 +722,15 @@ uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" version = "2022.2.0+0" [[deps.MLStyle]] -git-tree-sha1 = 
"060ef7956fef2dc06b0e63b294f7dbfbcbdc7ea2" +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.16" +version = "0.4.17" [[deps.MLUtils]] -deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase", "Transducers"] -git-tree-sha1 = "824e9dfc7509cab1ec73ba77b55a916bb2905e26" +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "f69cdbb5b7c630c02481d81d50eac43697084fe0" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.11" +version = "0.4.1" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -698,9 +760,9 @@ version = "0.7.2" [[deps.Metalhead]] deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "a8513152030f7210ccc0b871e03d60c9b13ed0b1" +git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b" uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" -version = "0.7.3" +version = "0.7.4" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -729,21 +791,21 @@ version = "2022.2.1" [[deps.NNlib]] deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "b488fc28dfae4c8ec3d61a34a0143a4245e7b13b" +git-tree-sha1 = "ddf38a5d9140bc8c08ea6158484a455ca3efdd2d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.16" +version = "0.8.18" [[deps.NNlibCUDA]] -deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478" +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.6" +version = "0.2.7" [[deps.NaNMath]] deps = ["OpenLibm_jll"] -git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.1" +version = "1.0.2" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -773,6 +835,12 @@ git-tree-sha1 = "82d7c9e310fe55aa54996e6f7f94674e2a38fcb4" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" version = "1.12.9" +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "f511fca956ed9e70b80cd3417bb8c2dde4b68644" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.3" + [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" @@ -803,15 +871,21 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "1ef34738708e3f31994b52693286dabcb3d29f6b" +git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.9" +version = "0.2.14" [[deps.OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.4.1" +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "cf494dca75a69712a72b80bc48f59dcf3dea63ec" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.16" + [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", 
"IndirectArrays", "OffsetArrays", "libpng_jll"] git-tree-sha1 = "f809158b27eba0c18c269cf2a2be6ed751d3e81d" @@ -874,6 +948,12 @@ git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" version = "1.0.0" +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "786efa36b7eff813723c4849c90456609cf06661" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.8.1" + [[deps.Quaternions]] deps = ["LinearAlgebra", "Random", "RealDot"] git-tree-sha1 = "da095158bdc8eaccb7890f9884048555ab771019" @@ -934,6 +1014,18 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.1" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.4.0+0" + [[deps.Rotations]] deps = ["LinearAlgebra", "Quaternions", "Random", "StaticArrays", "Statistics"] git-tree-sha1 = "9480500060044fd25a1c341da53f34df7443c2f2" @@ -970,9 +1062,9 @@ version = "0.9.4" [[deps.SimpleWeightedGraphs]] deps = ["Graphs", "LinearAlgebra", "Markdown", "SparseArrays", "Test"] -git-tree-sha1 = "a8d28ad975506694d59ac2f351e29243065c5c52" +git-tree-sha1 = "7d0b07df35fccf9b866a94bcab98822a87a3cb6f" uuid = "47aef6b3-ad0c-573a-a1e2-d07658019622" -version = "1.2.2" +version = "1.3.0" [[deps.Sixel]] deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] @@ -1017,17 +1109,11 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" version = "0.1.1" -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "c35b107b61e7f34fa3f124026f2a9be97dea9e1c" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.3" - [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "6954a456979f23d05085727adb17c4551c19ecd1" +git-tree-sha1 = "2d7d9e1ddadc8407ffd460e24218e37ef52dd9a3" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.12" +version = "1.5.16" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -1050,6 +1136,12 @@ git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.33.21" +[[deps.StatsFuns]] +deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "ab6083f09b3e617e34a956b43e9d51b824206932" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.1.1" + [[deps.StructArrays]] deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] git-tree-sha1 = "b03a3b745aa49b566f128977a7dd1be8711c5e71" @@ -1146,10 +1238,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.12+3" [[deps.Zstd_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e45044cd873ded54b6a5bac0eb5c971392cf1927" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.2+0" +version = "1.5.4+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", 
"GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1163,6 +1255,12 @@ git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.2" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.0.1" + [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" diff --git a/tutorials/transfer_learning/Project.toml b/tutorials/transfer_learning/Project.toml index 7ded1877..3892825e 100644 --- a/tutorials/transfer_learning/Project.toml +++ b/tutorials/transfer_learning/Project.toml @@ -1,10 +1,10 @@ [deps] +DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] Flux = "0.13" diff --git a/tutorials/transfer_learning/README.md b/tutorials/transfer_learning/README.md new file mode 100644 index 00000000..6bce4153 --- /dev/null +++ b/tutorials/transfer_learning/README.md @@ -0,0 +1,252 @@ +# Transfer learning of vision model with Flux + +## Context + +This tutorial shows how to perform transfer learning using a pre-trained vision model. In the process, we will also learn how to use a custom data container, a useful feature when dealing with large datasets that cannot fit into memory. + +Transfer learning is a common way in which large, compute intensive models can be used in practice. Following their training to perform well on their general trask, they can be subsequently used as a basis to fine-tune only some of their components on smaller, specialized datasets adapted to a specific task. + +Self contained Julia code presented in this tutorial is found in ["transfer_learning.jl"](transfer_learning.jl) and can be launched with: + +``` +julia project=@. --threads=8 transfer_learning.jl +``` + +## Getting started + +In this tutorial, we'll used a pre-trained ResNet18 model to solve a 3-class classification problem: 🐱, 🐶, 🐼. + +Data can be accessed from [Kaggle](https://www.kaggle.com/datasets/ashishsaxena2209/animal-image-datasetdog-cat-and-panda). + +Following download and unzip, data is expected to live under the following structure: + +``` +- data + - animals + - cats + - dogs + - panda +``` + +In Julia, the following packages are needed: + +```julia +using Random: shuffle! +import Base: length, getindex +using Images +using Flux +using Flux: update! +using DataAugmentation +using Metalhead +``` + +We also define a utility to help manage cpu-gpu conversion: + +```julia +device = Flux.CUDA.functional() ? gpu : cpu +``` + +A CUDA enabled GPU is recommended. A modest one with 6GB RAM is sufficient and should run the tutorial in just 5-6 mins. If running on CPU, it may take over 40 mins. + +## Data preparation and loader + +When dealing with large datasets, it's unrealistic to use a vanilla `DataLoader` constructor using the entire dataset as input. 
A handy approach is to rely on custom data containers, which allow pulling data into memory only as needed.

Our custom data container is very simple. It's a `struct` containing the paths to each of our images:

```julia
const CATS = readdir(abspath(joinpath("data", "animals", "cats")), join = true)
const DOGS = readdir(abspath(joinpath("data", "animals", "dogs")), join = true)
const PANDA = readdir(abspath(joinpath("data", "animals", "panda")), join = true)

struct ImageContainer{T<:Vector}
    img::T
end

imgs = [CATS..., DOGS..., PANDA...]
shuffle!(imgs)
data = ImageContainer(imgs)
```

To be compatible with `DataLoader`, two functions must be defined at a minimum:
- `Base.length`: returns the number of observations in the data container.
- `Base.getindex`: returns the observation for a specified index.

```julia
length(data::ImageContainer) = length(data.img)

const im_size = (224, 224)
tfm = DataAugmentation.compose(ScaleKeepAspect(im_size), CenterCrop(im_size))
name_to_idx = Dict{String,Int32}("cats" => 1, "dogs" => 2, "panda" => 3)

const mu = [0.485f0, 0.456f0, 0.406f0]
const sigma = [0.229f0, 0.224f0, 0.225f0]

function getindex(data::ImageContainer, idx::Int)
    path = data.img[idx]
    _img = Images.load(path)
    _img = itemdata(apply(tfm, Image(_img)))
    img = collect(channelview(float32.(RGB.(_img))))
    img = permutedims((img .- mu) ./ sigma, (3, 2, 1))
    # note: this regexp assumes Windows path separators (`\`); adapt to `/` on Linux/macOS
    name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2")
    y = name_to_idx[name]
    return img, y
end
```

In the above, each image is normalized with the standard ImageNet statistics `mu` and `sigma`, matching the preprocessing used when the ResNet was pre-trained. The class label `y` is obtained by first using a regexp to extract the parent folder name of the image, which can be one of `cats`, `dogs` or `panda`. Then, this name is mapped to an integer index using the `name_to_idx` dictionary.

Data augmentation is performed through the `tfm` pipeline powered by [DataAugmentation.jl](https://github.com/lorenzoh/DataAugmentation.jl). Random crops, flips and color augmentation techniques are also supported.

We can now define our train and eval data iterators:

```julia
const batchsize = 16

dtrain = Flux.DataLoader(
    ImageContainer(imgs[1:2700]);
    batchsize,
    collate = true,
    parallel = true
)
device == gpu ? dtrain = Flux.CuIterator(dtrain) : nothing
```

```julia
deval = Flux.DataLoader(
    ImageContainer(imgs[2701:3000]);
    batchsize,
    collate = true,
    parallel = true
)
device == gpu ? deval = Flux.CuIterator(deval) : nothing
```

The `collate` option is set to `true` so that all of the images in a batch are concatenated into a single 4D array, with the batch dimension last. If set to `false`, the loader instead returns a vector of length `batchsize`, in which each element is a single 3D array (width, height, channels).

Setting `parallel` to `true` is an important performance enhancement: the GPU would otherwise spend significant time idle, waiting for the CPU data loading to complete.
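
As a quick sanity check (a small sketch, not part of the original tutorial), we can pull a single batch and confirm that its layout matches what `collate = true` promises:

```julia
x, y = first(dtrain)
size(x)   # (224, 224, 3, 16): width × height × channels × batch
length(y) # 16 integer labels, each in 1:3
```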

## Fine-tune | 🐢 mode

Load a pre-trained model:

```julia
m = Metalhead.ResNet(18, pretrain = true).layers
```

Replace the original classification head with layers adapted to the fine-tuning task (`m[1]` is the pre-trained convolutional backbone; the new `Dense(512 => 3)` maps ResNet18's 512 features to our 3 classes):

```julia
m_tot = Chain(m[1], AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => 3)) |> device
```

Define an accuracy evaluation function:

```julia
function eval_f(m, deval)
    good = 0
    count = 0
    for (x, y) in deval
        good += sum(Flux.onecold(m(x)) .== y)
        count += length(y)
    end
    acc = round(good / count, digits = 4)
    return acc
end
```

Define a training function for one epoch:

```julia
function train_epoch!(model; opt, dtrain)
    for (x, y) in dtrain
        grads = gradient(model) do m
            Flux.Losses.logitcrossentropy(m(x), Flux.onehotbatch(y, 1:3))
        end
        update!(opt, model, grads[1])
    end
end
```

Here `gradient` is taken with respect to the model itself (Flux's explicit-gradient style): it returns a structure matching the model, which `update!` then applies through the optimiser state.

Initialise the optimiser state with `Flux.setup`, which tracks every trainable parameter of `m_tot`:

```julia
opt = Flux.setup(Flux.Optimisers.Adam(1e-5), m_tot);
```

Train for a few epochs:

```julia
for iter = 1:5
    @time train_epoch!(m_tot; opt, dtrain)
    metric_train = eval_f(m_tot, dtrain)
    metric_eval = eval_f(m_tot, deval)
    @info "train" metric = metric_train
    @info "eval" metric = metric_eval
end
 13.744040 seconds (5.42 M allocations: 10.991 GiB, 13.04% gc time)
┌ Info: train
└ metric = 0.9996
┌ Info: eval
└ metric = 0.99
```

## Fine-tune | 🐇 mode

In the previous fine-tuning, gradients were computed over the entire model, even though only the newly added final layers actually needed to be adapted.

To avoid these unnecessary computations, we can split our model in two:
- The original pre-trained core, for which we don't want to compute gradients.
- The new final layers, for which gradients are needed.

```julia
m_infer = deepcopy(m[1]) |> device
m_tune = Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => 3)) |> device
```

Only minimal adaptations are then needed to the eval and training functions:

```julia
function eval_f(m_infer, m_tune, deval)
    good = 0
    count = 0
    for (x, y) in deval
        good += sum(Flux.onecold(m_tune(m_infer(x))) .== y)
        count += length(y)
    end
    acc = round(good / count, digits = 4)
    return acc
end
```

```julia
function train_epoch!(m_infer, m_tune; opt, dtrain)
    for (x, y) in dtrain
        infer = m_infer(x)
        grads = gradient(m_tune) do m
            Flux.Losses.logitcrossentropy(m(infer), Flux.onehotbatch(y, 1:3))
        end
        update!(opt, m_tune, grads[1])
    end
end
```

```julia
opt = Flux.setup(Flux.Optimisers.Adam(1e-4), m_tune);
```

```julia
for iter = 1:5
    @time train_epoch!(m_infer, m_tune; opt, dtrain)
    metric_train = eval_f(m_infer, m_tune, dtrain)
    metric_eval = eval_f(m_infer, m_tune, deval)
    @info "train" metric = metric_train
    @info "eval" metric = metric_eval
end
  4.398730 seconds (1.23 M allocations: 10.690 GiB, 9.75% gc time)
┌ Info: train
└ metric = 0.9907
┌ Info: eval
└ metric = 0.9867
```

As we can see, a nearly 3X speedup is achieved in this situation by avoiding the unneeded gradient computations.
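
To use the fine-tuned model on a new image, the exact same preprocessing as in `getindex` must be applied. The following is a minimal inference sketch, not part of the original tutorial: the `predict` helper and the `idx_to_name` reverse mapping are illustrative additions that reuse the `tfm`, `mu`, `sigma` and `name_to_idx` objects defined earlier:

```julia
idx_to_name = Dict(v => k for (k, v) in name_to_idx)

function predict(path::String)
    _img = itemdata(apply(tfm, Image(Images.load(path)))) # resize and center-crop to 224×224
    img = collect(channelview(float32.(RGB.(_img))))      # 3×224×224 Float32 array
    img = permutedims((img .- mu) ./ sigma, (3, 2, 1))    # normalize, width × height × channels
    x = device(reshape(img, size(img)..., 1))             # add a batch dimension of 1
    return idx_to_name[Flux.onecold(cpu(m_tune(m_infer(x))))[1]]
end

predict(first(CATS)) # expected: "cats"
```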
+ +**This concludes the tutorial, happy transfer!** \ No newline at end of file diff --git a/tutorials/transfer_learning/dataloader.jl b/tutorials/transfer_learning/dataloader.jl deleted file mode 100644 index e5a92826..00000000 --- a/tutorials/transfer_learning/dataloader.jl +++ /dev/null @@ -1,52 +0,0 @@ -using Flux, Images -using StatsBase: sample, shuffle - -const PATH = joinpath(@__DIR__, "train") -const FILES = joinpath.(PATH, readdir(PATH)) -if isempty(readdir(PATH)) - error("Empty train folder - perhaps you need to download and extract the kaggle dataset.") -end - -# Get all of the files with "dog" in the name -const DOGS = filter(x -> occursin("dog", x), FILES) - -# Get all of the files with "cat" in the name -const CATS = filter(x -> occursin("cat", x), FILES) - -# Takes in the number of requested images per batch ("n") and image size -# Returns a 4D array with images and an array of labels -function load_batch(n = 10, nsize = (224,224); path = PATH) - if isodd(n) - print("Batch size must be an even number") - end - # Sample N dog images and N cat images, shuffle, and then combine them into a batch - imgs_paths = shuffle(vcat(sample(DOGS, Int(n/2)), sample(CATS, Int(n/2)))) - - # Generate the image label based on the file name - labels = map(x -> occursin("dog.",x) ? 1 : 0, imgs_paths) - # Here, dog is set to 1 and cat to 0 - - # Convert the text based names to 0 or 1 (one hot encoding) - labels = Flux.onehotbatch(labels, [0,1]) - - # Load all of the images - imgs = Images.load.(imgs_paths) - - # Re-size the images based on imagesize from above (most models use 224 x 224) - imgs = map(img -> Images.imresize(img, nsize...), imgs) - - # Change the dimensions of each image, switch to gray scale. Channel view switches to... - # a 3 channel 3rd dimension and then (3,2,1) makes those into seperate arrays. - # So we end up with [:, :, 1] being the Red values, [:, :, 2] being the Green values, etc - imgs = map(img -> permutedims(channelview(img), (3,2,1)), imgs) - # Result is two 3D arrays representing each image - - # Concatenate the two images into a single 4D array and add another extra dim at the end - # which shows how many images there are per set, in this case, it's 2 - imgs = cat(imgs..., dims = 4) - # This is requires since the model's input is a 4D array - - # Convert the images to float form and return them along with the labels - # The default is float64 but float32 is commonly used which is why we use it - Float32.(imgs), labels -end diff --git a/tutorials/transfer_learning/transfer_learning.jl b/tutorials/transfer_learning/transfer_learning.jl index a4e9d89a..dd7dc578 100644 --- a/tutorials/transfer_learning/transfer_learning.jl +++ b/tutorials/transfer_learning/transfer_learning.jl @@ -1,88 +1,135 @@ -# # Transfer Learning with Flux - -# This article is intended to be a general guide to how transfer learning works in the Flux ecosystem. -# We assume a certain familiarity of the reader with the concept of transfer learning. Having said that, -# we will start off with a basic definition of the setup and what we are trying to achieve. There are many -# resources online that go in depth as to why transfer learning is an effective tool to solve many ML -# problems, and we recommend checking some of those out. - -# Machine Learning today has evolved to use many highly trained models in a general task, -# where they are tuned to perform especially well on a subset of the problem. - -# This is one of the key ways in which larger (or smaller) models are used in practice. 
They are trained on -# a general problem, achieving good results on the test set, and then subsequently tuned on specialised datasets. - -# In this process, our model is already pretty well trained on the problem, so we don't need to train it -# all over again as if from scratch. In fact, as it so happens, we don't need to do that at all! We only -# need to tune the last couple of layers to get the most performance from our models. The exact last number of layers -# is dependant on the problem setup and the expected outcome, but a common tip is to train the last few `Dense` -# layers in a more complicated model. - -# So let's try to simulate the problem in Flux. - -# We'll tune a pretrained ResNet from Metalhead as a proxy. We will tune the `Dense` layers in there on a new set of images. - -using Flux, Metalhead -resnet = ResNet(pretrain=true).layers +# load packages +using Random: shuffle! +import Base: length, getindex +using Images +using Flux +using Flux: update! +using DataAugmentation +using Metalhead + +device = Flux.CUDA.functional() ? gpu : cpu +# device = cpu + +## Custom DataLoader +const CATS = readdir(abspath(joinpath("data", "animals", "cats")), join = true) +const DOGS = readdir(abspath(joinpath("data", "animals", "dogs")), join = true) +const PANDA = readdir(abspath(joinpath("data", "animals", "panda")), join = true) + +struct ImageContainer{T<:Vector} + img::T +end -# If we intended to add a new class of objects in there, we need only `reshape` the output from the previous layers accordingly. -# Our model would look something like so: +imgs = [CATS..., DOGS..., PANDA...] +shuffle!(imgs) +data = ImageContainer(imgs) -# ```julia -# model = Chain( -# resnet[1], # We only need to pull out the dense layer in here -# x -> reshape(x, size_we_want), # / global_avg_pooling layer -# Dense(reshaped_input_features, n_classes) -# ) -# ``` +length(data::ImageContainer) = length(data.img) -# We will use the [Dogs vs. Cats](https://www.kaggle.com/c/dogs-vs-cats/data) dataset from Kaggle for our use here. -# Make sure to extract the images in a `train` folder. +const im_size = (224, 224) +tfm = DataAugmentation.compose(ScaleKeepAspect(im_size), CenterCrop(im_size)) +name_to_idx = Dict{String,Int32}("cats" => 1, "dogs" => 2, "panda" => 3) -# The `dataloader.jl` script contains some functions that will help us load batches of images, shuffled between -# dogs and cats along with their correct labels. +const mu = [0.485f0, 0.456f0, 0.406f0] +const sigma = [0.229f0, 0.224f0, 0.225f0] -include("dataloader.jl") +function getindex(data::ImageContainer, idx::Int) + path = data.img[idx] + _img = Images.load(path) + _img = itemdata(apply(tfm, Image(_img))) + img = collect(channelview(float32.(RGB.(_img)))) + img = permutedims((img .- mu) ./ sigma, (3, 2, 1)) + name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2") + y = name_to_idx[name] + return img, y +end -# Finally, the model looks something like: +# define DataLoaders +const batchsize = 16 -model = Chain( - resnet[1], - AdaptiveMeanPool((1, 1)), - Flux.flatten, - Dense(2048, 1000, relu), - Dense(1000, 256, relu), - Dense(256, 2), # we get 2048 features out, and we have 2 classes +dtrain = Flux.DataLoader( + ImageContainer(imgs[1:2700]); + batchsize, + collate = true, + parallel = true, ) +device == gpu ? 
dtrain = Flux.CuIterator(dtrain) : nothing -# To speed up training, let's move everything over to the GPU - -model = model |> gpu -dataset = [gpu.(load_batch(10)) for i in 1:10] - -# After this, we only need to define the other parts of the training pipeline like we usually do. - -opt = ADAM() -loss(x,y) = Flux.Losses.logitcrossentropy(model(x), y) - -# Now to train -# As discussed earlier, we don't need to pass all the parameters to our training loop. Only the ones we need to -# fine-tune. Note that we could have picked and chosen the layers we want to train individually as well, but this -# is sufficient for our use as of now. +deval = Flux.DataLoader( + ImageContainer(imgs[2701:3000]); + batchsize, + collate = true, + parallel = true, +) +device == gpu ? deval = Flux.CuIterator(deval) : nothing + +# Fine-tune | 🐢 mode +# Load a pre-trained model: +m = Metalhead.ResNet(18, pretrain = true).layers +m_tot = Chain(m[1], AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => 3)) |> device + +function eval_f(m, deval) + good = 0 + count = 0 + for (x, y) in deval + good += sum(Flux.onecold(m(x)) .== y) + count += length(y) + end + acc = round(good / count, digits = 4) + return acc +end -ps = Flux.params(model[2:end]) # ignore the already trained layers of the ResNet +function train_epoch!(model; opt, dtrain) + for (x, y) in dtrain + grads = gradient(model) do m + Flux.Losses.logitcrossentropy(m(x), Flux.onehotbatch(y, 1:3)) + end + update!(opt, model, grads[1]) + end +end -# And now, let's train! +opt = Flux.setup(Flux.Optimisers.Adam(1e-5), m_tot); -for epoch in 1:2 - Flux.train!(loss, ps, dataset, opt) +for iter = 1:5 + @time train_epoch!(m_tot; opt, dtrain) + metric_train = eval_f(m_tot, dtrain) + metric_eval = eval_f(m_tot, deval) + @info "train" metric = metric_train + @info "eval" metric = metric_eval end -# And there you have it, a pretrained model, fine tuned to tell the the dogs from the cats. -# We can verify this too. +# Fine-tune | 🐇 mode +# define models +m_infer = deepcopy(m[1]) |> device +m_tune = Chain(AdaptiveMeanPool((1, 1)), Flux.flatten, Dense(512 => 3)) |> device + +function eval_f(m_infer, m_tune, deval) + good = 0 + count = 0 + for (x, y) in deval + good += sum(Flux.onecold(m_tune(m_infer(x))) .== y) + count += length(y) + end + acc = round(good / count, digits = 4) + return acc +end -imgs, labels = gpu.(load_batch(10)) -display(model(imgs)) +function train_epoch!(m_infer, m_tune; opt, dtrain) + for (x, y) in dtrain + infer = m_infer(x) + grads = gradient(m_tune) do m + Flux.Losses.logitcrossentropy(m(infer), Flux.onehotbatch(y, 1:3)) + end + update!(opt, m_tune, grads[1]) + end +end -labels +opt = Flux.setup(Flux.Optimisers.Adam(1e-3), m_tune); +# training loop +for iter = 1:5 + @time train_epoch!(m_infer, m_tune; opt, dtrain) + metric_train = eval_f(m_infer, m_tune, dtrain) + metric_eval = eval_f(m_infer, m_tune, deval) + @info "train" metric = metric_train + @info "eval" metric = metric_eval +end
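
# Optional: persist the fine-tuned model (illustrative sketch, not part of the
# original script). Assumes BSON.jl has been added to the project environment;
# weights are moved to cpu first so the file can be reloaded without a GPU.
# using BSON: @save
# m_final = Chain(m_infer, m_tune) |> cpu
# @save joinpath(@__DIR__, "resnet18_animals.bson") m_final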