diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 996d5640e9..bd39b54280 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -15,7 +15,16 @@ steps: - "test/" - "Project.toml" - ".buildkite/testing.yml" - - "lib/" + - "lib/LuxTestUtils/src/" + - "lib/LuxTestUtils/ext/" + - "lib/LuxCore/src/" + - "lib/LuxCore/ext/" + - "lib/MLDataDevices/src/" + - "lib/MLDataDevices/ext/" + - "lib/WeightInitializers/src/" + - "lib/WeightInitializers/ext/" + - "lib/LuxLib/src/" + - "lib/LuxLib/ext/" config: command: "buildkite-agent pipeline upload .buildkite/testing.yml" agents: @@ -52,9 +61,12 @@ steps: path: - "lib/LuxLib/" - ".buildkite/testing_luxlib.yml" - - "lib/LuxTestUtils/" - - "lib/LuxCore/" - - "lib/MLDataDevices/" + - "lib/LuxTestUtils/src/" + - "lib/LuxTestUtils/ext/" + - "lib/LuxCore/src/" + - "lib/LuxCore/ext/" + - "lib/MLDataDevices/src/" + - "lib/MLDataDevices/ext/" config: command: "buildkite-agent pipeline upload .buildkite/testing_luxlib.yml" agents: diff --git a/.buildkite/testing_mldatadevices.yml b/.buildkite/testing_mldatadevices.yml index 555e304233..df2ce766b0 100644 --- a/.buildkite/testing_mldatadevices.yml +++ b/.buildkite/testing_mldatadevices.yml @@ -7,7 +7,7 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: project: "lib/MLDataDevices" - test_args: "BACKEND_GROUP={{matrix.group}}" + test_args: "--BACKEND_GROUP={{matrix.group}}" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -34,7 +34,7 @@ steps: # version: "{{matrix.julia}}" # - JuliaCI/julia-test#v1: # project: "lib/MLDataDevices" - # test_args: "BACKEND_GROUP=AMDGPU" + # test_args: "--BACKEND_GROUP=AMDGPU" # - JuliaCI/julia-coverage#v1: # codecov: true # dirs: @@ -59,7 +59,7 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: project: "lib/MLDataDevices" - test_args: "BACKEND_GROUP=Metal" + test_args: "--BACKEND_GROUP=Metal" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -85,7 +85,7 @@ steps: version: "{{matrix.julia}}" - 
JuliaCI/julia-test#v1: project: "lib/MLDataDevices" - test_args: "BACKEND_GROUP=oneAPI" + test_args: "--BACKEND_GROUP=oneAPI" - JuliaCI/julia-coverage#v1: codecov: true dirs: diff --git a/.buildkite/testing_weightinitializers.yml b/.buildkite/testing_weightinitializers.yml index b2269d8a4d..bf2120e837 100644 --- a/.buildkite/testing_weightinitializers.yml +++ b/.buildkite/testing_weightinitializers.yml @@ -7,7 +7,7 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: project: "lib/WeightInitializers" - test_args: "BACKEND_GROUP=CUDA" + test_args: "--BACKEND_GROUP=CUDA" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -31,7 +31,7 @@ steps: # version: "{{matrix.julia}}" # - JuliaCI/julia-test#v1: # project: "lib/WeightInitializers" - # test_args: "BACKEND_GROUP=AMDGPU" + # test_args: "--BACKEND_GROUP=AMDGPU" # - JuliaCI/julia-coverage#v1: # codecov: true # dirs: @@ -56,7 +56,7 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: project: "lib/WeightInitializers" - test_args: "BACKEND_GROUP=Metal" + test_args: "--BACKEND_GROUP=Metal" - JuliaCI/julia-coverage#v1: codecov: true dirs: @@ -82,7 +82,7 @@ steps: version: "{{matrix.julia}}" - JuliaCI/julia-test#v1: project: "lib/WeightInitializers" - test_args: "BACKEND_GROUP=oneAPI" + test_args: "--BACKEND_GROUP=oneAPI" - JuliaCI/julia-coverage#v1: codecov: true dirs: diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 73c5db566f..d41bbb1371 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -10,11 +10,16 @@ on: - "test/**" - "Project.toml" - ".github/workflows/CI.yml" - - "lib/LuxTestUtils/**" - - "lib/LuxCore/**" - - "lib/MLDataDevices/**" - - "lib/WeightInitializers/**" - - "lib/LuxLib/**" + - "lib/LuxTestUtils/src/**" + - "lib/LuxTestUtils/ext/**" + - "lib/LuxCore/src/**" + - "lib/LuxCore/ext/**" + - "lib/MLDataDevices/src/**" + - "lib/MLDataDevices/ext/**" + - "lib/WeightInitializers/src/**" + - "lib/WeightInitializers/ext/**" + - "lib/LuxLib/src/**" + - 
"lib/LuxLib/ext/**" push: branches: - main diff --git a/.github/workflows/CIPreRelease.yml b/.github/workflows/CIPreRelease.yml index e4240171f1..d0aa922f6d 100644 --- a/.github/workflows/CIPreRelease.yml +++ b/.github/workflows/CIPreRelease.yml @@ -10,11 +10,16 @@ # - "test/**" # - "Project.toml" # - ".github/workflows/CI.yml" -# - "lib/LuxTestUtils/**" -# - "lib/LuxCore/**" -# - "lib/MLDataDevices/**" -# - "lib/WeightInitializers/**" -# - "lib/LuxLib/**" +# - "lib/LuxTestUtils/src/**" +# - "lib/LuxTestUtils/ext/**" +# - "lib/LuxCore/src/**" +# - "lib/LuxCore/ext/**" +# - "lib/MLDataDevices/src/**" +# - "lib/MLDataDevices/ext/**" +# - "lib/WeightInitializers/src/**" +# - "lib/WeightInitializers/ext/**" +# - "lib/LuxLib/src/**" +# - "lib/LuxLib/ext/**" # push: # branches: # - main diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml index 35d2eec413..cf574603d5 100644 --- a/.github/workflows/CI_LuxCore.yml +++ b/.github/workflows/CI_LuxCore.yml @@ -7,7 +7,8 @@ on: - ".github/workflows/CommonCI.yml" - "lib/LuxCore/**" - ".github/workflows/CI_LuxCore.yml" - - "lib/MLDataDevices/**" + - "lib/MLDataDevices/src/**" + - "lib/MLDataDevices/ext/**" push: branches: - main diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml index 700e6c3d53..9d81c94f16 100644 --- a/.github/workflows/CI_LuxLib.yml +++ b/.github/workflows/CI_LuxLib.yml @@ -7,9 +7,12 @@ on: - ".github/workflows/CommonCI.yml" - "lib/LuxLib/**" - ".github/workflows/CI_LuxLib.yml" - - "lib/LuxTestUtils/**" - - "lib/LuxCore/**" - - "lib/MLDataDevices/**" + - "lib/LuxTestUtils/src/**" + - "lib/LuxTestUtils/ext/**" + - "lib/LuxCore/src/**" + - "lib/LuxCore/ext/**" + - "lib/MLDataDevices/src/**" + - "lib/MLDataDevices/ext/**" push: branches: - main diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml index c003a2dd35..7318b69be6 100644 --- a/.github/workflows/CI_MLDataDevices.yml +++ b/.github/workflows/CI_MLDataDevices.yml 
@@ -28,11 +28,16 @@ jobs: - cpu - opencl - reactant + exclude: + - os: windows-latest + group: opencl + - os: macos-latest + group: opencl uses: ./.github/workflows/CommonCI.yml with: julia_version: "1.12" project: "lib/MLDataDevices" - test_args: "BACKEND_GROUP=${{ matrix.group }}" + test_args: "--BACKEND_GROUP=${{ matrix.group }}" os: ${{ matrix.os }} downgrade: @@ -48,4 +53,4 @@ jobs: julia_version: "1.11" project: "lib/MLDataDevices" downgrade_testing: true - test_args: "BACKEND_GROUP=${{ matrix.group }}" + test_args: "--BACKEND_GROUP=${{ matrix.group }}" diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml index ce3b0f91b3..d291553310 100644 --- a/.github/workflows/CI_WeightInitializers.yml +++ b/.github/workflows/CI_WeightInitializers.yml @@ -21,7 +21,7 @@ jobs: with: julia_version: "1.12" project: "lib/WeightInitializers" - test_args: "BACKEND_GROUP=cpu" + test_args: "--BACKEND_GROUP=cpu" downgrade: uses: ./.github/workflows/CommonCI.yml @@ -29,4 +29,4 @@ jobs: julia_version: "1.11" project: "lib/WeightInitializers" downgrade_testing: true - test_args: "BACKEND_GROUP=cpu" + test_args: "--BACKEND_GROUP=cpu" diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml index 2355f8d198..b9b3c85981 100644 --- a/lib/LuxCore/test/Project.toml +++ b/lib/LuxCore/test/Project.toml @@ -6,10 +6,15 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +LuxCore = {path = ".."} +MLDataDevices = {path = "../../MLDataDevices"} + [compat] Aqua = "0.8.7" EnzymeCore = "0.8.14" @@ -17,10 +22,7 @@ ExplicitImports = "1.9.0" Functors = "0.5" 
MLDataDevices = "1.17" Optimisers = "0.3.4, 0.4" +ParallelTestRunner = "2.1" Random = "1.10" Setfield = "1.1" Test = "1.10" - -[sources] -LuxCore = {path = ".."} -MLDataDevices = {path = "../../MLDataDevices"} diff --git a/lib/LuxCore/test/abstractluxcontainerlayer.jl b/lib/LuxCore/test/abstractluxcontainerlayer.jl new file mode 100644 index 0000000000..ce905d149f --- /dev/null +++ b/lib/LuxCore/test/abstractluxcontainerlayer.jl @@ -0,0 +1,44 @@ +using LuxCore, Test, Random + +rng = LuxCore.Internal.default_rng() + +include("common.jl") + +@testset "AbstractLuxContainerLayer Interface" begin + model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test fieldnames(typeof(ps)) == (:layers,) + @test fieldnames(typeof(st)) == (:layers,) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers[1]) + LuxCore.parameterlength(model.layers[2]) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) + + model = Chain2(Dense(5, 5), Dense(5, 6)) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) +end diff --git a/lib/LuxCore/test/abstractluxlayer.jl b/lib/LuxCore/test/abstractluxlayer.jl new 
file mode 100644 index 0000000000..77f95881c7 --- /dev/null +++ b/lib/LuxCore/test/abstractluxlayer.jl @@ -0,0 +1,61 @@ +using LuxCore, Test, Random, Functors + +rng = LuxCore.Internal.default_rng() + +include("common.jl") + +@testset "AbstractLuxLayer Interface" begin + @testset "Custom Layer" begin + model = Dense(5, 6) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) + @test LuxCore.statelength(st) == LuxCore.statelength(model) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, NamedTuple())) + + @test_nowarn println(model) + + @testset for wrapper in (DenseWrapper, DenseWrapper2) + model2 = wrapper(model) + ps, st = LuxCore.setup(rng, model2) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2) + @test LuxCore.statelength(st) == LuxCore.statelength(model2) + + @test model2(x, ps, st)[1] == model(x, ps, st)[1] + + @test_nowarn println(model2) + end + end + + @testset "Default Fallbacks" begin + struct NoParamStateLayer <: AbstractLuxLayer end + + layer = NoParamStateLayer() + @test LuxCore.initialparameters(rng, layer) == NamedTuple() + @test LuxCore.initialstates(rng, layer) == NamedTuple() + + @test LuxCore.parameterlength(zeros(10, 2)) == 20 + @test LuxCore.statelength(zeros(10, 2)) == 20 + @test LuxCore.statelength(Val(true)) == 1 + @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 + @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + + @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialparameters(rng, ()) + @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + @test LuxCore.initialparameters(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) + + @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError 
LuxCore.initialstates(rng, ()) + @test LuxCore.initialstates(rng, nothing) == NamedTuple() + @test LuxCore.initialstates(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) + end +end diff --git a/lib/LuxCore/test/abstractluxwrapperlayer.jl b/lib/LuxCore/test/abstractluxwrapperlayer.jl new file mode 100644 index 0000000000..0771e5040c --- /dev/null +++ b/lib/LuxCore/test/abstractluxwrapperlayer.jl @@ -0,0 +1,29 @@ +using LuxCore, Test, Random + +rng = LuxCore.Internal.default_rng() + +include("common.jl") + +@testset "AbstractLuxWrapperLayer Interface" begin + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test fieldnames(typeof(ps)) == (:layer_1, :layer_2) + @test fieldnames(typeof(st)) == (:layer_1, :layer_2) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers.layer_1) + + LuxCore.parameterlength(model.layers.layer_2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers.layer_1) + + LuxCore.statelength(model.layers.layer_2) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) +end diff --git a/lib/LuxCore/test/common.jl b/lib/LuxCore/test/common.jl new file mode 100644 index 0000000000..11edc16b57 --- /dev/null +++ b/lib/LuxCore/test/common.jl @@ -0,0 +1,55 @@ +using LuxCore, Random + +# Define some custom layers +struct Dense <: AbstractLuxLayer + in::Int + out::Int +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) + return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) +end + +(::Dense)(x, _, st) = x, st # Dummy Forward Pass + +struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +# For checking ambiguities in the dispatch +struct DenseWrapper2{L} <: 
AbstractLuxWrapperLayer{:layer} + layer::L +end + +(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st) + +struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} + layers::L +end + +function (c::Chain)(x, ps, st) + y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1) + y, st2 = c.layers[2](y, ps.layers.layer_2, st.layers.layer_2) + return y, (; layers=(; layer_1=st1, layer_2=st2)) +end + +struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers} + layers::L +end + +function (c::ChainWrapper)(x, ps, st) + y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) + y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) + return y, (; layer_1=st1, layer_2=st2) +end + +struct Chain2{L1,L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)} + layer1::L1 + layer2::L2 +end + +function (c::Chain2)(x, ps, st) + y, st1 = c.layer1(x, ps.layer1, st.layer1) + y, st2 = c.layer2(y, ps.layer2, st.layer2) + return y, (; layer1=st1, layer2=st2) +end diff --git a/lib/LuxCore/test/misc.jl b/lib/LuxCore/test/misc.jl new file mode 100644 index 0000000000..f883ec8d1c --- /dev/null +++ b/lib/LuxCore/test/misc.jl @@ -0,0 +1,239 @@ +using Optimisers, Random, EnzymeCore, MLDataDevices, LuxCore, Test, Functors, Setfield + +rng = LuxCore.Internal.default_rng() + +include("common.jl") + +@testset "update_state API" begin + st = ( + layer_1=(training=Val(true), val=1), + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),)), + ) + + st_ = LuxCore.testmode(st) + + @test st_.layer_1.training == Val(false) && + st_.layer_2.layer_2.training == Val(false) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st = st_ + st_ = LuxCore.trainmode(st) + + @test st_.layer_1.training == Val(true) && + st_.layer_2.layer_2.training == Val(true) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st_ = LuxCore.update_state(st, :val, -1) + @test st_.layer_1.training == st.layer_1.training && + 
st_.layer_2.layer_2.training == st.layer_2.layer_2.training && + st_.layer_1.val == -1 && + st_.layer_2.layer_1.val == -1 +end + +@testset "Functor Compatibility" begin + @testset "Basic Usage" begin + model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)), + )) + + @test new_model isa Chain + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)), + )) + + @test new_model isa ChainWrapper + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + end + + @testset "Method Ambiguity" begin + # Needed if defining a layer that 
works with both Flux and Lux -- See DiffEqFlux.jl + # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 + struct CustomLayer{M,P} <: AbstractLuxContainerLayer{(:model,)} + model::M + p::P + end + + @functor CustomLayer (p,) + + l = CustomLayer(x -> x, nothing) # Dummy Struct + + @test_nowarn Optimisers.trainable(l) + end +end + +@testset "Display Name" begin + struct StructWithoutName <: AbstractLuxLayer end + + model = StructWithoutName() + + @test LuxCore.display_name(model) == "StructWithoutName" + + struct StructWithName{N} <: AbstractLuxLayer + name::N + end + + model = StructWithName("Test") + + @test LuxCore.display_name(model) == "Test" + + model = StructWithName(nothing) + + @test LuxCore.display_name(model) == "StructWithName" + + @test LuxCore.display_name(rand(20)) == "Array" +end + +@testset "initialparameter/initialstate for Default Containers" begin + models1 = [ + Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), + [Dense(5, 10), Dense(10, 5)], + ] + models2 = [ + Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), + (Dense(5, 10), Dense(10, 5)), + ] + + for models in (models1, models2) + ps, st = LuxCore.setup(rng, models) + @test length(ps) == length(models) + @test length(st) == length(models) + @test typeof(ps[1]) == typeof(LuxCore.initialparameters(rng, models[1])) + @test typeof(ps[2]) == typeof(LuxCore.initialparameters(rng, models[2])) + @test typeof(ps[3][1]) == typeof(LuxCore.initialparameters(rng, models[3][1])) + @test typeof(ps[3][2]) == typeof(LuxCore.initialparameters(rng, models[3][2])) + @test typeof(st[1]) == typeof(LuxCore.initialstates(rng, models[1])) + @test typeof(st[2]) == typeof(LuxCore.initialstates(rng, models[2])) + @test typeof(st[3][1]) == typeof(LuxCore.initialstates(rng, models[3][1])) + @test typeof(st[3][2]) == typeof(LuxCore.initialstates(rng, models[3][2])) + end +end + +@testset "Convenience Checks" begin 
+ models1 = [ + Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), + [Dense(5, 10), Dense(10, 5)], + ] + + @test LuxCore.contains_lux_layer(models1) + + models2 = [1, 2, 3, 4] + + @test !LuxCore.contains_lux_layer(models2) + + models3 = [1, 2, 3, (; a=Dense(5, 10), b=Dense(10, 5))] + + @test LuxCore.contains_lux_layer(models3) +end + +@testset "replicate" begin + rng = Random.default_rng() + @test LuxCore.replicate(rng) === rng + @test LuxCore.replicate(rng) == rng + + rng = Xoshiro(1234) + @test LuxCore.replicate(rng) !== rng + @test LuxCore.replicate(rng) == rng +end + +@testset "empty fleaves" begin + @test length(fleaves(NamedTuple())) == 0 + @test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) +end + +@testset "Common Lux + Enzyme Mistakes" begin + d = Dense(2, 2) + + @test_throws ArgumentError Active(d) + @test_throws ArgumentError Duplicated(d, d) + @test_throws ArgumentError DuplicatedNoNeed(d, d) + @test_throws ArgumentError BatchDuplicated(d, (d, d)) + @test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d)) + @test Const(d) isa Const +end + +@testset "Device Transfer Warnings" begin + my_layer = Dense(2, 2) + + dev = cpu_device() + @test_logs ( + :warn, + "Lux layers are stateless and hence don't participate in device \ + transfers. 
Apply this function on the parameters and states generated \ + using `LuxCore.setup`.", + ) dev(my_layer) +end + +@testset "nested `training` key: Issue Lux.jl#849" begin + st = ( + encoder=(layer_1=NamedTuple(), layer_2=(; training=Val{true}())), + μ=NamedTuple(), + logσ=NamedTuple(), + decoder=( + layer_1=NamedTuple(), + layer_2=NamedTuple(), + layer_3=NamedTuple(), + layer_4=(running_mean=Float32[0.0, 0.0], training=Val{true}()), + ), + rng=Xoshiro(), + training=Val{true}(), + ) + + @test st.encoder.layer_2.training isa Val{true} + @test st.decoder.layer_4.training isa Val{true} + @test st.training isa Val{true} + + st_test = LuxCore.testmode(st) + + @test st_test.encoder.layer_2.training isa Val{false} + @test st_test.decoder.layer_4.training isa Val{false} + @test st_test.training isa Val{false} + + st_train = LuxCore.trainmode(st_test) + + @test st_train.encoder.layer_2.training isa Val{true} + @test st_train.decoder.layer_4.training isa Val{true} + @test st_train.training isa Val{true} +end diff --git a/lib/LuxCore/test/qa.jl b/lib/LuxCore/test/qa.jl new file mode 100644 index 0000000000..a5230c593c --- /dev/null +++ b/lib/LuxCore/test/qa.jl @@ -0,0 +1,6 @@ +using LuxCore, Test, ExplicitImports, Aqua + +@testset "Quality Assurance" begin + Aqua.test_all(LuxCore) + ExplicitImports.test_explicit_imports(LuxCore) +end diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl index d0c48051de..667b85c320 100644 --- a/lib/LuxCore/test/runtests.jl +++ b/lib/LuxCore/test/runtests.jl @@ -1,4 +1,4 @@ -using LuxCore, Test +using LuxCore, Test, ParallelTestRunner @testset "Extension Loading Checks (Fail)" begin @test !LuxCore.Internal.is_extension_loaded(Val(:Setfield)) @@ -15,430 +15,6 @@ using Functors, Setfield @test LuxCore.Internal.is_extension_loaded(Val(:Functors)) end -using Aqua, ExplicitImports, Optimisers, Random, EnzymeCore, MLDataDevices - -rng = LuxCore.Internal.default_rng() - -# Define some custom layers -struct Dense <: AbstractLuxLayer 
- in::Int - out::Int -end - -function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) - return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) -end - -(::Dense)(x, _, st) = x, st # Dummy Forward Pass - -struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer} - layer::L -end - -# For checking ambiguities in the dispatch -struct DenseWrapper2{L} <: AbstractLuxWrapperLayer{:layer} - layer::L -end - -(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st) - -struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} - layers::L -end - -function (c::Chain)(x, ps, st) - y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1) - y, st2 = c.layers[2](y, ps.layers.layer_2, st.layers.layer_2) - return y, (; layers=(; layer_1=st1, layer_2=st2)) -end - -struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers} - layers::L -end - -function (c::ChainWrapper)(x, ps, st) - y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) - y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) - return y, (; layer_1=st1, layer_2=st2) -end - -struct Chain2{L1,L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)} - layer1::L1 - layer2::L2 -end - -function (c::Chain2)(x, ps, st) - y, st1 = c.layer1(x, ps.layer1, st.layer1) - y, st2 = c.layer1(y, ps.layer2, st.layer2) - return y, (; layer1=st1, layer2=st2) -end - -@testset "LuxCore.jl Tests" begin - @testset "AbstractLuxLayer Interface" begin - @testset "Custom Layer" begin - model = Dense(5, 6) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) - - @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) - @test LuxCore.statelength(st) == LuxCore.statelength(model) - - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - - @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, NamedTuple())) - - @test_nowarn println(model) - - @testset for wrapper in (DenseWrapper, DenseWrapper2) - model2 = DenseWrapper(model) - ps, st = LuxCore.setup(rng, model2) - - @test 
LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2) - @test LuxCore.statelength(st) == LuxCore.statelength(model2) - - @test model2(x, ps, st)[1] == model(x, ps, st)[1] - - @test_nowarn println(model2) - end - end - - @testset "Default Fallbacks" begin - struct NoParamStateLayer <: AbstractLuxLayer end - - layer = NoParamStateLayer() - @test LuxCore.initialparameters(rng, layer) == NamedTuple() - @test LuxCore.initialstates(rng, layer) == NamedTuple() - - @test LuxCore.parameterlength(zeros(10, 2)) == 20 - @test LuxCore.statelength(zeros(10, 2)) == 20 - @test LuxCore.statelength(Val(true)) == 1 - @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 - @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 - - @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() - @test_throws MethodError LuxCore.initialparameters(rng, ()) - @test LuxCore.initialparameters(rng, nothing) == NamedTuple() - @test LuxCore.initialparameters(rng, (nothing, layer)) == - (NamedTuple(), NamedTuple()) - - @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() - @test_throws MethodError LuxCore.initialstates(rng, ()) - @test LuxCore.initialstates(rng, nothing) == NamedTuple() - @test LuxCore.initialparameters(rng, (nothing, layer)) == - (NamedTuple(), NamedTuple()) - end - end - - @testset "AbstractLuxContainerLayer Interface" begin - model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) - - @test fieldnames(typeof(ps)) == (:layers,) - @test fieldnames(typeof(st)) == (:layers,) - - @test LuxCore.parameterlength(ps) == - LuxCore.parameterlength(model) == - LuxCore.parameterlength(model.layers[1]) + - LuxCore.parameterlength(model.layers[2]) - @test LuxCore.statelength(st) == - LuxCore.statelength(model) == - LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) - - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - - @test 
LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, st)) - - @test_nowarn println(model) - - model = Chain2(Dense(5, 5), Dense(5, 6)) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) - - @test LuxCore.parameterlength(ps) == - LuxCore.parameterlength(model) == - LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) - @test LuxCore.statelength(st) == - LuxCore.statelength(model) == - LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) - - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - - @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, st)) - - @test_nowarn println(model) - end - - @testset "AbstractLuxWrapperLayer Interface" begin - model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) - x = randn(rng, Float32, 5) - ps, st = LuxCore.setup(rng, model) - - @test fieldnames(typeof(ps)) == (:layer_1, :layer_2) - @test fieldnames(typeof(st)) == (:layer_1, :layer_2) - - @test LuxCore.parameterlength(ps) == - LuxCore.parameterlength(model) == - LuxCore.parameterlength(model.layers.layer_1) + - LuxCore.parameterlength(model.layers.layer_2) - @test LuxCore.statelength(st) == - LuxCore.statelength(model) == - LuxCore.statelength(model.layers.layer_1) + - LuxCore.statelength(model.layers.layer_2) - - @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) - - @test LuxCore.stateless_apply(model, x, ps) == - first(LuxCore.apply(model, x, ps, st)) - - @test_nowarn println(model) - end - - @testset "update_state API" begin - st = ( - layer_1=(training=Val(true), val=1), - layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),)), - ) - - st_ = LuxCore.testmode(st) - - @test st_.layer_1.training == Val(false) && - st_.layer_2.layer_2.training == Val(false) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val - - st = st_ - st_ = LuxCore.trainmode(st) - - @test st_.layer_1.training == Val(true) && 
- st_.layer_2.layer_2.training == Val(true) && - st_.layer_1.val == st.layer_1.val && - st_.layer_2.layer_1.val == st.layer_2.layer_1.val - - st_ = LuxCore.update_state(st, :val, -1) - @test st_.layer_1.training == st.layer_1.training && - st_.layer_2.layer_2.training == st.layer_2.layer_2.training && - st_.layer_1.val == -1 && - st_.layer_2.layer_1.val == -1 - end - - @testset "Functor Compatibility" begin - @testset "Basic Usage" begin - model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) - - children, reconstructor = Functors.functor(model) - - @test children isa NamedTuple - @test fieldnames(typeof(children)) == (:layers,) - @test children.layers isa NamedTuple - @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) - @test children.layers.layer_1 isa Dense - @test children.layers.layer_2 isa Dense - @test children.layers.layer_1.in == 5 - @test children.layers.layer_1.out == 10 - @test children.layers.layer_2.in == 10 - @test children.layers.layer_2.out == 5 - - new_model = reconstructor((; - layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)), - )) - - @test new_model isa Chain - @test new_model.layers.layer_1.in == 10 - @test new_model.layers.layer_1.out == 5 - @test new_model.layers.layer_2.in == 5 - @test new_model.layers.layer_2.out == 10 - - model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) - - children, reconstructor = Functors.functor(model) - - @test children isa NamedTuple - @test fieldnames(typeof(children)) == (:layers,) - @test children.layers isa NamedTuple - @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) - @test children.layers.layer_1 isa Dense - @test children.layers.layer_2 isa Dense - @test children.layers.layer_1.in == 5 - @test children.layers.layer_1.out == 10 - @test children.layers.layer_2.in == 10 - @test children.layers.layer_2.out == 5 - - new_model = reconstructor((; - layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)), - )) - - @test new_model isa ChainWrapper - @test 
new_model.layers.layer_1.in == 10 - @test new_model.layers.layer_1.out == 5 - @test new_model.layers.layer_2.in == 5 - @test new_model.layers.layer_2.out == 10 - end - - @testset "Method Ambiguity" begin - # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl - # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 - struct CustomLayer{M,P} <: AbstractLuxContainerLayer{(:model,)} - model::M - p::P - end - - @functor CustomLayer (p,) - - l = CustomLayer(x -> x, nothing) # Dummy Struct - - @test_nowarn Optimisers.trainable(l) - end - end - - @testset "Display Name" begin - struct StructWithoutName <: AbstractLuxLayer end - - model = StructWithoutName() - - @test LuxCore.display_name(model) == "StructWithoutName" - - struct StructWithName{N} <: AbstractLuxLayer - name::N - end - - model = StructWithName("Test") - - @test LuxCore.display_name(model) == "Test" - - model = StructWithName(nothing) - - @test LuxCore.display_name(model) == "StructWithName" - - @test LuxCore.display_name(rand(20)) == "Array" - end - - @testset "initialparameter/initialstate for Default Containers" begin - models1 = [ - Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), - Chain2(Dense(5, 10), Dense(10, 5)), - [Dense(5, 10), Dense(10, 5)], - ] - models2 = [ - Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), - Chain2(Dense(5, 10), Dense(10, 5)), - (Dense(5, 10), Dense(10, 5)), - ] - - for models in (models1, models2) - ps, st = LuxCore.setup(rng, models) - @test length(ps) == length(models) - @test length(st) == length(models) - @test typeof(ps[1]) == typeof(LuxCore.initialparameters(rng, models[1])) - @test typeof(ps[2]) == typeof(LuxCore.initialparameters(rng, models[2])) - @test typeof(ps[3][1]) == typeof(LuxCore.initialparameters(rng, models[3][1])) - @test typeof(ps[3][2]) == typeof(LuxCore.initialparameters(rng, models[3][2])) - @test typeof(st[1]) == typeof(LuxCore.initialstates(rng, models[1])) - @test typeof(st[2]) == 
typeof(LuxCore.initialstates(rng, models[2])) - @test typeof(st[3][1]) == typeof(LuxCore.initialstates(rng, models[3][1])) - @test typeof(st[3][2]) == typeof(LuxCore.initialstates(rng, models[3][2])) - end - end - - @testset "Convenience Checks" begin - models1 = [ - Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), - Chain2(Dense(5, 10), Dense(10, 5)), - [Dense(5, 10), Dense(10, 5)], - ] - - @test LuxCore.contains_lux_layer(models1) - - models2 = [1, 2, 3, 4] - - @test !LuxCore.contains_lux_layer(models2) - - models3 = [1, 2, 3, (; a=Dense(5, 10), b=Dense(10, 5))] - - @test LuxCore.contains_lux_layer(models3) - end - - @testset "Quality Assurance" begin - Aqua.test_all(LuxCore) - - @test check_no_implicit_imports(LuxCore) === nothing - @test check_no_stale_explicit_imports(LuxCore) === nothing - @test check_no_self_qualified_accesses(LuxCore) === nothing - @test check_all_explicit_imports_via_owners(LuxCore) === nothing - @test check_all_qualified_accesses_via_owners(LuxCore) === nothing - @test_broken check_all_explicit_imports_are_public(LuxCore) === nothing - end - - @testset "replicate" begin - rng = Random.default_rng() - @test LuxCore.replicate(rng) === rng - @test LuxCore.replicate(rng) == rng - - rng = Xoshiro(1234) - @test LuxCore.replicate(rng) !== rng - @test LuxCore.replicate(rng) == rng - end - - @testset "empty fleaves" begin - @test length(fleaves(NamedTuple())) == 0 - @test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) - end - - @testset "Common Lux + Enzyme Mistakes" begin - d = Dense(2, 2) - - @test_throws ArgumentError Active(d) - @test_throws ArgumentError Duplicated(d, d) - @test_throws ArgumentError DuplicatedNoNeed(d, d) - @test_throws ArgumentError BatchDuplicated(d, (d, d)) - @test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d)) - @test Const(d) isa Const - end - - @testset "Device Transfer Warnings" begin - my_layer = Dense(2, 2) - - dev = cpu_device() - @test_logs ( - :warn, - "Lux layers are stateless and hence 
don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `LuxCore.setup`.", - ) dev(my_layer) - end - - @testset "nested `training` key: Issue Lux.jl#849" begin - st = ( - encoder=(layer_1=NamedTuple(), layer_2=(; training=Val{true}())), - μ=NamedTuple(), - logσ=NamedTuple(), - decoder=( - layer_1=NamedTuple(), - layer_2=NamedTuple(), - layer_3=NamedTuple(), - layer_4=(running_mean=Float32[0.0, 0.0], training=Val{true}()), - ), - rng=Xoshiro(), - training=Val{true}(), - ) - - @test st.encoder.layer_2.training isa Val{true} - @test st.decoder.layer_4.training isa Val{true} - @test st.training isa Val{true} - - st_test = LuxCore.testmode(st) - - @test st_test.encoder.layer_2.training isa Val{false} - @test st_test.decoder.layer_4.training isa Val{false} - @test st_test.training isa Val{false} - - st_train = LuxCore.trainmode(st_test) - - @test st_train.encoder.layer_2.training isa Val{true} - @test st_train.decoder.layer_4.training isa Val{true} - @test st_train.training isa Val{true} - end -end +testsuite = find_tests(@__DIR__) +delete!(testsuite, "common") +runtests(LuxCore, ARGS; testsuite) diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 0e2966e765..945ed0d09d 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.17.3" +version = "1.17.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/lib/MLDataDevices/ext/AMDGPUExt.jl b/lib/MLDataDevices/ext/AMDGPUExt.jl index d125f0a572..8e5a96c5e3 100644 --- a/lib/MLDataDevices/ext/AMDGPUExt.jl +++ b/lib/MLDataDevices/ext/AMDGPUExt.jl @@ -82,37 +82,24 @@ function amdgpu_array_adapt(::Type{T}, x) where {T} return Internal.array_adapt(AMDGPU.roc, ROCArray, T, x) end -function Adapt.adapt_storage(::AMDGPUDevice{D,Missing}, x::AbstractArray) where {D} - 
MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x - return amdgpu_array_adapt(Missing, x) -end - -function Adapt.adapt_storage(::AMDGPUDevice{D,Nothing}, x::AbstractArray) where {D} - MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x - return amdgpu_array_adapt(Nothing, x) -end - -function Adapt.adapt_storage( - ::AMDGPUDevice{D,T}, x::AbstractArray{ET} -) where {D,T<:AbstractFloat,ET<:Number} - MLDataDevices.get_device_type(x) <: AMDGPUDevice && ET == T && return x - return amdgpu_array_adapt(T, x) -end - function Adapt.adapt_storage(to::AMDGPUDevice{D,E}, x::AbstractArray) where {D,E} old_dev = AMDGPU.device() # remember the current device dev = MLDataDevices.get_device(x) if !(dev isa AMDGPUDevice) - AMDGPU.device!(to.device) + to.device !== nothing && AMDGPU.device!(to.device) x_new = amdgpu_array_adapt(to, x) - AMDGPU.device!(old_dev) + to.device !== nothing && AMDGPU.device!(old_dev) return x_new - elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) + elseif ( + dev.device === nothing || + to.device === nothing || + AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) + ) return x else - AMDGPU.device!(to.device) + to.device !== nothing && AMDGPU.device!(to.device) x_new = copy(x) - AMDGPU.device!(old_dev) + to.device !== nothing && AMDGPU.device!(old_dev) return x_new end end diff --git a/lib/MLDataDevices/ext/CUDAExt.jl b/lib/MLDataDevices/ext/CUDAExt.jl index 5e6b23440c..40499e28dc 100644 --- a/lib/MLDataDevices/ext/CUDAExt.jl +++ b/lib/MLDataDevices/ext/CUDAExt.jl @@ -56,37 +56,24 @@ function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray) end # Device Transfer -cuda_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CUDA.cu, CuArray, T, x) - -function Adapt.adapt_storage(::CUDADevice{D,Missing}, x::AbstractArray) where {D} - MLDataDevices.get_device_type(x) <: CUDADevice && return x - return cuda_array_adapt(Missing, x) -end - -function Adapt.adapt_storage(::CUDADevice{D,Nothing}, 
x::AbstractArray) where {D} - MLDataDevices.get_device_type(x) <: CUDADevice && return x - return cuda_array_adapt(Nothing, x) -end - -function Adapt.adapt_storage(::CUDADevice{D,T}, x::AbstractArray) where {D,T<:AbstractFloat} - MLDataDevices.get_device_type(x) <: CUDADevice && eltype(x) == T && return x - return cuda_array_adapt(T, x) +function cuda_array_adapt(::CUDADevice{D,E}, x) where {D,E} + return Internal.array_adapt(CUDA.cu, CuArray, E, x) end function Adapt.adapt_storage(to::CUDADevice{D,E}, x::AbstractArray) where {D,E} old_dev = CUDA.device() # remember the current device dev = MLDataDevices.get_device(x) if !(dev isa CUDADevice) - CUDA.device!(to.device) + to.device !== nothing && CUDA.device!(to.device) x_new = cuda_array_adapt(to, x) - CUDA.device!(old_dev) + to.device !== nothing && CUDA.device!(old_dev) return x_new - elseif dev.device == to.device + elseif dev.device === nothing || to.device === nothing || dev.device == to.device return x else - CUDA.device!(to.device) + to.device !== nothing && CUDA.device!(to.device) x_new = copy(x) - CUDA.device!(old_dev) + to.device !== nothing && CUDA.device!(old_dev) return x_new end end diff --git a/lib/MLDataDevices/ext/GPUArraysSparseArraysExt.jl b/lib/MLDataDevices/ext/GPUArraysSparseArraysExt.jl index b602983e32..fa5df44bcf 100644 --- a/lib/MLDataDevices/ext/GPUArraysSparseArraysExt.jl +++ b/lib/MLDataDevices/ext/GPUArraysSparseArraysExt.jl @@ -11,14 +11,16 @@ Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state) Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state) -for (T1, T2) in ( - (:AbstractGPUSparseMatrixCSC, :SparseMatrixCSC), - (:AbstractGPUSparseVector, :SparseVector), -) +for (T1, T2) in + ((AbstractGPUSparseMatrixCSC, SparseMatrixCSC), (AbstractGPUSparseVector, SparseVector)) @eval begin - Adapt.adapt_storage(::CPUDevice{Missing}, x::$(T1)) = $(T2)(x) - 
Adapt.adapt_storage(::CPUDevice{Nothing}, x::$(T1)) = $(T2)(x) - Adapt.adapt_storage(::CPUDevice{T}, x::$(T1)) where {T<:AbstractFloat} = $(T2){T}(x) + function Adapt.adapt_storage(::CPUDevice{T}, x::$(T1)) where {T} + if T <: AbstractFloat + eltype(x) <: Complex && return $(T2){Complex{T}}(x) + return $(T2){T}(x) + end + return $(T2)(x) + end end end diff --git a/lib/MLDataDevices/ext/MetalExt.jl b/lib/MLDataDevices/ext/MetalExt.jl index 29a9fd3e8e..37bc1086fe 100644 --- a/lib/MLDataDevices/ext/MetalExt.jl +++ b/lib/MLDataDevices/ext/MetalExt.jl @@ -2,7 +2,8 @@ module MetalExt using Adapt: Adapt using GPUArrays: GPUArrays -using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device! +using MLDataDevices: + MLDataDevices, Internal, MetalDevice, reset_gpu_device!, get_device_type using Metal: Metal, MtlArray __init__() = reset_gpu_device!() @@ -29,21 +30,12 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) end # Device Transfer -metal_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(Metal.mtl, MtlArray, T, x) - -function Adapt.adapt_storage(::MetalDevice{Missing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: MetalDevice && return x - return metal_array_adapt(Missing, x) -end - -function Adapt.adapt_storage(::MetalDevice{Nothing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: MetalDevice && return x - return metal_array_adapt(Nothing, x) -end - -function Adapt.adapt_storage(::MetalDevice{T}, x::AbstractArray) where {T<:AbstractFloat} - MLDataDevices.get_device_type(x) <: MetalDevice && eltype(x) == T && return x - return metal_array_adapt(T, x) +function Adapt.adapt_storage(::MetalDevice{T}, x::AbstractArray) where {T} + # Metal is single-device, so we only need to check the device type + if get_device_type(x) <: MetalDevice + Internal.return_without_conversion(T, x) && return x + end + return Internal.array_adapt(Metal.mtl, MtlArray, T, x) end end diff --git 
a/lib/MLDataDevices/ext/OpenCLExt.jl b/lib/MLDataDevices/ext/OpenCLExt.jl index 6c5273ec93..c5e1894000 100644 --- a/lib/MLDataDevices/ext/OpenCLExt.jl +++ b/lib/MLDataDevices/ext/OpenCLExt.jl @@ -76,22 +76,12 @@ end opencl_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CLArray, CLArray, T, x) -function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: OpenCLDevice && return x - return opencl_array_adapt(Missing, x) -end - -function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: OpenCLDevice && return x - return opencl_array_adapt(Nothing, x) -end - -function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T<:AbstractFloat} - MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x - if T === Float64 && !SUPPORTS_FP64[cl.device()] - throw(ArgumentError("FP64 is not supported on this device")) +function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T} + if MLDataDevices.get_device_type(x) <: OpenCLDevice + Internal.return_without_conversion(T, x) && return x end - return opencl_array_adapt(T, x) + + return Internal.array_adapt(CLArray, CLArray, T, x) end end diff --git a/lib/MLDataDevices/ext/ReactantExt.jl b/lib/MLDataDevices/ext/ReactantExt.jl index 984fa1d488..aa1f81675d 100644 --- a/lib/MLDataDevices/ext/ReactantExt.jl +++ b/lib/MLDataDevices/ext/ReactantExt.jl @@ -103,20 +103,14 @@ Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothi # Device Transfer Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage( - dev::ReactantDevice{C,D,S,Missing}, x::AbstractArray -) where {C,D,S} - return ConcreteRArray(x; device_to_kwargs(dev, x)...) # Preserves eltype -end + dev::ReactantDevice{C,D,S,T}, x::AbstractArray{ET} +) where {C,D,S,T,ET} + if T === Nothing || T === Missing + return ConcreteRArray(x; device_to_kwargs(dev, x)...) 
# Preserves eltype + end -Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage( - dev::ReactantDevice{C,D,S,Nothing}, x::AbstractArray -) where {C,D,S} - return ConcreteRArray(x; device_to_kwargs(dev, x)...) # Preserves eltype -end + @assert T <: AbstractFloat -Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage( - dev::ReactantDevice{C,D,S,T}, x::AbstractArray{ET} -) where {C,D,S,T<:AbstractFloat,ET} # Convert eltype first, then move to device if ET <: AbstractFloat x_converted = convert(AbstractArray{T}, x) diff --git a/lib/MLDataDevices/ext/SparseArraysExt.jl b/lib/MLDataDevices/ext/SparseArraysExt.jl index ef9190f678..f4e6098e76 100644 --- a/lib/MLDataDevices/ext/SparseArraysExt.jl +++ b/lib/MLDataDevices/ext/SparseArraysExt.jl @@ -7,12 +7,12 @@ using SparseArrays: AbstractSparseArray, nonzeros Internal.get_device(x::AbstractSparseArray) = Internal.get_device(nonzeros(x)) Internal.get_device_type(x::AbstractSparseArray) = Internal.get_device_type(nonzeros(x)) -Adapt.adapt_storage(::CPUDevice{Missing}, x::AbstractSparseArray) = x -Adapt.adapt_storage(::CPUDevice{Nothing}, x::AbstractSparseArray) = x -function Adapt.adapt_storage( - ::CPUDevice{T}, x::AbstractSparseArray -) where {T<:AbstractFloat} - return convert(AbstractSparseArray{T}, x) +function Adapt.adapt_storage(::CPUDevice{T}, x::AbstractSparseArray) where {T} + if T <: AbstractFloat + eltype(x) <: Complex && return convert(AbstractSparseArray{Complex{T}}, x) + return convert(AbstractSparseArray{T}, x) + end + return x end end diff --git a/lib/MLDataDevices/ext/oneAPIExt.jl b/lib/MLDataDevices/ext/oneAPIExt.jl index fcda118c68..f372e0caa7 100644 --- a/lib/MLDataDevices/ext/oneAPIExt.jl +++ b/lib/MLDataDevices/ext/oneAPIExt.jl @@ -76,24 +76,11 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) end end -oneapi_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(oneArray, oneArray, T, x) - -function 
Adapt.adapt_storage(::oneAPIDevice{Missing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: oneAPIDevice && return x - return oneapi_array_adapt(Missing, x) -end - -function Adapt.adapt_storage(::oneAPIDevice{Nothing}, x::AbstractArray) - MLDataDevices.get_device_type(x) <: oneAPIDevice && return x - return oneapi_array_adapt(Nothing, x) -end - -function Adapt.adapt_storage(::oneAPIDevice{T}, x::AbstractArray) where {T<:AbstractFloat} - MLDataDevices.get_device_type(x) <: oneAPIDevice && eltype(x) == T && return x - if T === Float64 && !SUPPORTS_FP64[oneAPI.device()] - throw(ArgumentError("FP64 is not supported on this device")) +function Adapt.adapt_storage(::oneAPIDevice{T}, x::AbstractArray) where {T} + if MLDataDevices.get_device_type(x) <: oneAPIDevice + Internal.return_without_conversion(T, x) && return x end - return oneapi_array_adapt(T, x) + return Internal.array_adapt(oneArray, oneArray, T, x) end end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl index bc9f5971d1..5e5e0c9de0 100644 --- a/lib/MLDataDevices/src/internal.jl +++ b/lib/MLDataDevices/src/internal.jl @@ -289,8 +289,9 @@ end function to_rarray_internal end # Utility function to facilitate data transfer -# For AbstractFloat and Complex{<:AbstractFloat} arrays, we provide specialized methods to avoid -# ambiguity with the general fallback and to enable efficient type conversion when needed. +# For AbstractFloat and Complex{<:AbstractFloat} arrays, we provide specialized methods to +# avoid ambiguity with the general fallback and to enable efficient type conversion when +# needed. 
function array_adapt( f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:AbstractFloat} ) where {F,aType} @@ -357,4 +358,14 @@ function array_adapt(::F, ::Type{aType}, ::Type{E}, x::AbstractArray{T}) where { return aType(x) end +return_without_conversion(::Type{Nothing}, ::AbstractArray) = true +return_without_conversion(::Type{Missing}, ::AbstractArray) = true +return_without_conversion(::Type{T}, ::AbstractArray{T}) where {T<:AbstractFloat} = true +function return_without_conversion( + ::Type{T}, ::AbstractArray{Complex{T}} +) where {T<:AbstractFloat} + return true +end +return_without_conversion(::Type{T}, ::AbstractArray) where {T} = false + end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 2da973f03e..221633f43d 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -528,19 +528,13 @@ for op in (:get_device, :get_device_type) end # Adapt Interface -function Adapt.adapt_storage(dev::CPUDevice{Missing}, x::AbstractArray) - get_device_type(x) <: CPUDevice && return x - return Array(x) -end - -function Adapt.adapt_storage(dev::CPUDevice{Nothing}, x::AbstractArray) - get_device_type(x) <: CPUDevice && return x - return Array(x) # Preserve eltype -end +function Adapt.adapt_storage(::CPUDevice{T}, x::AbstractArray) where {T} + if get_device_type(x) <: CPUDevice + Internal.return_without_conversion(T, x) && return x + end -function Adapt.adapt_storage(dev::CPUDevice{T}, x::AbstractArray) where {T<:AbstractFloat} - get_device_type(x) <: CPUDevice && eltype(x) == T && return x x_cpu = Array(x) + Internal.return_without_conversion(T, x_cpu) && return x_cpu # Only convert floating-point and complex floating-point types ET = eltype(x_cpu) diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index 89d0a0a8fb..3e0c1e2add 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -1,17 +1,23 @@ [deps] +AMDGPU = 
"21141c5a-9bdb-4563-92ae-f87d6854732e" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2" +ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" @@ -21,6 +27,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" +pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd" [sources] LuxTestUtils = {path = "../../LuxTestUtils"} @@ -38,6 +47,7 @@ ForwardDiff = "0.10.36, 1" Functors = "0.5" MLUtils = "0.4.4" OneHotArrays = "0.2.5" +ParallelTestRunner = "2.1" Pkg = "1.10" Random = "1.10" Reactant = "0.2.170" diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index 9d6058290c..6a52201495 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -1,7 +1,11 @@ using MLDataDevices, Random, Test 
using ArrayInterface: parameterless_type -@testset "CPU Fallback" begin +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + @test !MLDataDevices.functional(AMDGPUDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice diff --git a/lib/MLDataDevices/test/common.jl b/lib/MLDataDevices/test/common.jl new file mode 100644 index 0000000000..c67e810d51 --- /dev/null +++ b/lib/MLDataDevices/test/common.jl @@ -0,0 +1,24 @@ +using Pkg: Pkg +using Test: @test, @testset + +macro test_in_separate_process(testname, expr) + tmpfile = tempname() * ".jl" + open(tmpfile, "w") do io + println(io, "using Pkg, MLDataDevices, Test") + println(io, expr) + end + project_path = dirname(Pkg.project().path) + + run_cmd = `$(Base.julia_cmd()) --color=yes --project=$(project_path) --startup-file=no --code-coverage=user $(tmpfile)` + + return quote + @testset $(testname) begin + try + run($run_cmd) + @test true + catch + @test false + end + end + end +end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 639cd317f8..192c6f0074 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -1,7 +1,11 @@ using MLDataDevices, Random, Functors, Test using ArrayInterface: parameterless_type -@testset "CPU Fallback" begin +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + @test !MLDataDevices.functional(CUDADevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice diff --git a/lib/MLDataDevices/test/eltype_tests.jl b/lib/MLDataDevices/test/eltype_tests.jl index f0ec6c8967..24d66263ef 100644 --- a/lib/MLDataDevices/test/eltype_tests.jl +++ b/lib/MLDataDevices/test/eltype_tests.jl @@ -1,25 +1,25 @@ -@testitem "Device Eltype Functionality" setup = [SharedTestSetup] tags = [:misc] begin - using MLDataDevices, Random, Test +using MLDataDevices, Random, Test +@testset "Device Eltype 
Functionality" begin @testset "CPU Device with Eltype" begin # Test default behavior (missing eltype) cpu_default = cpu_device() @test cpu_default isa CPUDevice{Missing} # Test eltype=nothing (preserve type) - cpu_preserve = cpu_device(; eltype=nothing) + cpu_preserve = cpu_device(nothing) @test cpu_preserve isa CPUDevice{Nothing} # Test specific eltype - cpu_f32 = cpu_device(; eltype=Float32) + cpu_f32 = cpu_device(Float32) @test cpu_f32 isa CPUDevice{Float32} - cpu_f64 = cpu_device(; eltype=Float64) + cpu_f64 = cpu_device(Float64) @test cpu_f64 isa CPUDevice{Float64} # Test invalid eltype - @test_throws ArgumentError cpu_device(eltype=Int) - @test_throws ArgumentError cpu_device(eltype=String) + @test_throws MethodError cpu_device(Int) + @test_throws MethodError cpu_device(String) end @testset "CPU Device Array Conversion" begin @@ -34,13 +34,13 @@ @test y_f64 == x_f64 # Test nothing eltype (preserve) - cpu_preserve = cpu_device(; eltype=nothing) + cpu_preserve = cpu_device(nothing) y_f64_preserve = cpu_preserve(x_f64) @test eltype(y_f64_preserve) === Float64 @test y_f64_preserve == x_f64 # Test specific eltype conversion - cpu_f32 = cpu_device(; eltype=Float32) + cpu_f32 = cpu_device(Float32) y_f32 = cpu_f32(x_f64) @test eltype(y_f32) === Float32 @test y_f32 ≈ Float32.(x_f64) @@ -57,55 +57,6 @@ @test y_complex ≈ ComplexF32.(x_complex) end - @testset "GPU Device Creation with Eltype" begin - # Test default behavior - try - gpu_default = gpu_device(; eltype=nothing) - @test MLDataDevices.get_eltype(gpu_default) === Nothing - catch e - if e isa MLDataDevices.Internal.DeviceSelectionException - @test_skip "No functional GPU available" - else - rethrow() - end - end - - try - gpu_f32 = gpu_device(; eltype=Float32) - @test MLDataDevices.get_eltype(gpu_f32) === Float32 - catch e - if e isa MLDataDevices.Internal.DeviceSelectionException - @test_skip "No functional GPU available" - else - rethrow() - end - end - end - - @testset "Reactant Device with Eltype" begin - # 
Test eltype parameter - reactant_default = reactant_device(; eltype=nothing) - @test reactant_default isa CPUDevice{Nothing} # Falls back to CPU since Reactant not loaded - - reactant_f32 = reactant_device(; eltype=Float32) - @test reactant_f32 isa CPUDevice{Float32} # Falls back to CPU since Reactant not loaded - end - - @testset "Helper Functions" begin - cpu_f32 = cpu_device(; eltype=Float32) - cpu_f64 = cpu_device(; eltype=Float64) - cpu_nothing = cpu_device(; eltype=nothing) - - # Test get_eltype - @test MLDataDevices.get_eltype(cpu_f32) === Float32 - @test MLDataDevices.get_eltype(cpu_f64) === Float64 - @test MLDataDevices.get_eltype(cpu_nothing) === Nothing - - # Test with_eltype - cpu_new = MLDataDevices.with_eltype(cpu_f32, Float64) - @test MLDataDevices.get_eltype(cpu_new) === Float64 - end - @testset "Device Constructor Backward Compatibility" begin # Test that old constructors still work cpu_old = CPUDevice() diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index c674e983e9..9b65f6cbc9 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,7 +1,11 @@ using MLDataDevices, Random, Test using ArrayInterface: parameterless_type -@testset "CPU Fallback" begin +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + @test !MLDataDevices.functional(MetalDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 1e352c4ac0..79af3f9593 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -56,16 +56,18 @@ end gdev = gpu_device() if !(gdev isa MetalDevice) # On intel devices causes problems - x = randn(10) - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x) - @test ∂dev === nothing - @test ∂x ≈ ones(10) - - x = gdev(randn(10)) - ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, 
cpu_device(), x) - @test ∂dev === nothing - @test ∂x ≈ gdev(ones(10)) - @test get_device(∂x) isa parameterless_type(typeof(gdev)) + if VERSION < v"1.12-" + x = randn(10) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x) + @test ∂dev === nothing + @test ∂x ≈ ones(10) + + x = gdev(randn(10)) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x) + @test ∂dev === nothing + @test ∂x ≈ gdev(ones(10)) + @test get_device(∂x) isa parameterless_type(typeof(gdev)) + end end end @@ -75,7 +77,10 @@ end @test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC @test cdev(1:10) isa AbstractRange - @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement + + if VERSION < v"1.12-" + @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement + end end @testset "RecursiveArrayTools" begin @@ -132,7 +137,7 @@ end "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.", ) gpu_backend!() - for backend in ( + @testset for backend in ( :CUDA, :AMDGPU, :oneAPI, @@ -230,31 +235,37 @@ end end @testset "Zygote.gradient(wrapped arrays)" begin - using Zygote + if VERSION < v"1.12-" + using Zygote - x = rand(4, 4) - cdev = cpu_device() + x = rand(4, 4) + cdev = cpu_device() - @test get_device(only(Zygote.gradient(x -> sum(abs2, cdev(x)), x'))) isa CPUDevice + @test get_device(only(Zygote.gradient(x -> sum(abs2, cdev(x)), x'))) isa CPUDevice - gdev = gpu_device() + gdev = gpu_device() - @test get_device(only(Zygote.gradient(x -> sum(abs2, gdev(x)), x'))) isa CPUDevice + @test get_device(only(Zygote.gradient(x -> sum(abs2, gdev(x)), x'))) isa CPUDevice + end end @testset "Zygote and ChainRules OneElement #1016" begin - using Zygote + if VERSION < v"1.12-" + using Zygote - cdev = cpu_device() - gdev = gpu_device() + cdev = cpu_device() + gdev = gpu_device() - g = only(Zygote.gradient(x -> cdev(2 .* gdev(x))[1], Float32[1, 2, 3])) - @test g isa Vector{Float32} + g = only(Zygote.gradient(x -> cdev(2 .* gdev(x))[1], 
Float32[1, 2, 3])) + @test g isa Vector{Float32} - g = only( - Zygote.gradient(x -> cdev(gdev(x) * gdev(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9]) - ) - @test g isa Matrix{Float32} + g = only( + Zygote.gradient( + x -> cdev(gdev(x) * gdev(x))[1, 2], Float32[1 2 3; 4 5 6; 7 8 9] + ), + ) + @test g isa Matrix{Float32} + end end @testset "OneHotArrays" begin diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 29603e15cd..dad8dcf6d2 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,7 +1,11 @@ using MLDataDevices, Random, Test using ArrayInterface: parameterless_type -@testset "CPU Fallback" begin +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + @test !MLDataDevices.functional(oneAPIDevice) @test cpu_device() isa CPUDevice @test gpu_device() isa CPUDevice diff --git a/lib/MLDataDevices/test/opencl_tests.jl b/lib/MLDataDevices/test/opencl_tests.jl index d8f5cf0d76..f59fbae940 100644 --- a/lib/MLDataDevices/test/opencl_tests.jl +++ b/lib/MLDataDevices/test/opencl_tests.jl @@ -1,8 +1,19 @@ using OpenCL, pocl_jll - using MLDataDevices, Random, Test using ArrayInterface: parameterless_type +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + + @test !MLDataDevices.functional(OpenCLDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) + @test_throws Exception default_device_rng(OpenCLDevice()) +end + if !MLDataDevices.functional(OpenCLDevice) @warn "OpenCL.jl is not functional. Skipping OpenCL tests." 
exit() diff --git a/lib/MLDataDevices/test/openclcpu_tests.jl b/lib/MLDataDevices/test/openclcpu_tests.jl deleted file mode 100644 index e1313c3d59..0000000000 --- a/lib/MLDataDevices/test/openclcpu_tests.jl +++ /dev/null @@ -1,10 +0,0 @@ -using MLDataDevices, Random, Test -using ArrayInterface: parameterless_type - -@testset "CPU Fallback" begin - @test !MLDataDevices.functional(OpenCLDevice) - @test cpu_device() isa CPUDevice - @test gpu_device() isa CPUDevice - @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) - @test_throws Exception default_device_rng(OpenCLDevice()) -end diff --git a/lib/MLDataDevices/test/reactant_tests.jl b/lib/MLDataDevices/test/reactant_tests.jl index 6b3dc286ba..71992eac27 100644 --- a/lib/MLDataDevices/test/reactant_tests.jl +++ b/lib/MLDataDevices/test/reactant_tests.jl @@ -1,7 +1,11 @@ using MLDataDevices, Random, Test using ArrayInterface: parameterless_type -@testset "CPU Fallback" begin +include("common.jl") + +@test_in_separate_process "CPU Fallback" begin + using MLDataDevices, Test + @test !MLDataDevices.functional(ReactantDevice) @test cpu_device() isa CPUDevice @test reactant_device() isa CPUDevice diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl index 0e0ac01ea6..856f9852bf 100644 --- a/lib/MLDataDevices/test/runtests.jl +++ b/lib/MLDataDevices/test/runtests.jl @@ -1,56 +1,58 @@ -using Pkg: Pkg, PackageSpec -using Test -using LuxTestUtils - -function parse_test_args() - test_args_from_env = @isdefined(TEST_ARGS) ? TEST_ARGS : ARGS - test_args = Dict{String,String}() - for arg in test_args_from_env - if contains(arg, "=") - key, value = split(arg, "="; limit=2) - test_args[key] = value +using Pkg, MLDataDevices, Test, ParallelTestRunner, LuxTestUtils + +parsed_args = parse_args(@isdefined(TEST_ARGS) ? 
TEST_ARGS : ARGS; custom=["BACKEND_GROUP"]) + +const BACKEND_GROUP = lowercase( + something(get(parsed_args.custom, "BACKEND_GROUP", nothing), "all") +) + +testsuite = find_tests(@__DIR__) + +# Filter testsuite based on BACKEND_GROUP +backend_test_files = Set([ + "reactant_tests", + "cuda_tests", + "amdgpu_tests", + "metal_tests", + "oneapi_tests", + "opencl_tests", + "openclcpu_tests", +]) + +for file in keys(testsuite) + if file ∈ backend_test_files + # Remove backend-specific tests unless that backend is being tested + if BACKEND_GROUP == "all" + # Keep all + elseif BACKEND_GROUP == "cpu" + delete!(testsuite, file) + elseif file == "$(BACKEND_GROUP)_tests" + # Keep this backend's tests + else + delete!(testsuite, file) end end - @info "Parsed test args" test_args - return test_args end -const PARSED_TEST_ARGS = parse_test_args() +delete!(testsuite, "common") +delete!(testsuite, "iterator_tests") +delete!(testsuite, "misc_tests") -const BACKEND_GROUP = lowercase(get(PARSED_TEST_ARGS, "BACKEND_GROUP", "none")) +total_jobs = min( + something(parsed_args.jobs, ParallelTestRunner.default_njobs()), length(keys(testsuite)) +) -const EXTRA_PKGS = LuxTestUtils.packages_to_install(BACKEND_GROUP) - -if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS - isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) - Base.retry_load_extensions() - Pkg.instantiate() +additional_testsuite = Dict{String,Expr}() +for file in ("iterator_tests.jl", "misc_tests.jl") + testfile = joinpath(@__DIR__, file) + additional_testsuite[file] = :(include($testfile)) end -@testset "MLDataDevices Tests" begin - all_files = map( - Base.Fix2(*, "_tests.jl"), - ["reactant", "cuda", "amdgpu", "metal", "oneapi", "opencl"], - ) - file_names = if BACKEND_GROUP == "all" - all_files - elseif BACKEND_GROUP ∈ ("cpu", "none") - [] - elseif BACKEND_GROUP == "opencl" - ["opencl_tests.jl", "openclcpu_tests.jl"] - else - [BACKEND_GROUP * "_tests.jl"] - end - - append!(file_names, 
["iterator_tests.jl", "misc_tests.jl", "qa_tests.jl"]) - - @testset "$(file_name)" for file_name in file_names - @info "Running $(file_name)" - withenv("BACKEND_GROUP" => BACKEND_GROUP) do - run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) - --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) - Test.@test true - end - end +withenv( + "XLA_REACTANT_GPU_MEM_FRACTION" => 1 / (total_jobs + 0.1), + "XLA_REACTANT_GPU_PREALLOCATE" => false, + "BACKEND_GROUP" => BACKEND_GROUP, +) do + runtests(MLDataDevices, parsed_args; testsuite) + runtests(MLDataDevices, parsed_args; testsuite=additional_testsuite) end diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml index accd3148ad..8b0500e759 100644 --- a/lib/WeightInitializers/test/Project.toml +++ b/lib/WeightInitializers/test/Project.toml @@ -1,21 +1,28 @@ [deps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2" +ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = 
"8dfed614-e22c-5e08-85e1-65c5234f0b40" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" +pocl_jll = "627d6b7a-bbe6-5189-83e7-98cc0a5aeadd" [sources] LuxTestUtils = {path = "../../LuxTestUtils"} @@ -23,16 +30,15 @@ WeightInitializers = {path = ".."} [compat] Aqua = "0.8.7" -CPUSummary = "0.2.6" Documenter = "1.5.0" ExplicitImports = "1.9.0" GPUArrays = "10.2, 11" GPUArraysCore = "0.1.6, 0.2" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" +ParallelTestRunner = "2.1" Pkg = "1.10" Random = "1.10" -ReTestItems = "1.24.0" Reactant = "0.2.170" StableRNGs = "1" Statistics = "1.10" diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/basic.jl similarity index 54% rename from lib/WeightInitializers/test/initializers_tests.jl rename to lib/WeightInitializers/test/basic.jl index eb285d892a..72cd711882 100644 --- a/lib/WeightInitializers/test/initializers_tests.jl +++ b/lib/WeightInitializers/test/basic.jl @@ -1,189 +1,13 @@ -@testitem "Warning: truncated_normal" begin - @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ - the distribution of values may be inaccurate." 
truncated_normal(2; mean=-5.0f0) -end - -@testitem "Identity Initialization" begin - using LinearAlgebra - - @testset "2D identity matrices" begin - # Square matrix should be identity - mat = identity_init(5, 5) - @test mat ≈ Matrix{Float32}(I, 5, 5) - @test diag(mat) == ones(Float32, 5) - # Check off-diagonal elements are zero - for i in 1:5, j in 1:5 - if i != j - @test mat[i, j] == 0.0f0 - end - end - - # Test with gain parameter - mat_gain = identity_init(4, 4; gain=2.5) - @test mat_gain ≈ 2.5f0 * Matrix{Float32}(I, 4, 4) - @test diag(mat_gain) == fill(2.5f0, 4) - - # Non-square matrices - mat_rect1 = identity_init(3, 5) - @test size(mat_rect1) == (3, 5) - @test diag(mat_rect1) == ones(Float32, 3) - @test mat_rect1[:, 4:5] == zeros(Float32, 3, 2) - - mat_rect2 = identity_init(5, 3) - @test size(mat_rect2) == (5, 3) - @test diag(mat_rect2) == ones(Float32, 3) - @test mat_rect2[4:5, :] == zeros(Float32, 2, 3) - end - - @testset "Non-identity sizes" begin - @test identity_init(2, 3)[:, end] == zeros(Float32, 2) - @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) - @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) - @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) - @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) - end -end - -@testitem "Orthogonal Initialization" setup = [SharedTestSetup] begin - using GPUArraysCore, LinearAlgebra - - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for ( - rng, arrtype, supports_fp64, backend - ) in RNGS_ARRTYPES - # A matrix of dim = (m,n) with m > n should produce a QR decomposition. - # In the other case, the transpose should be taken to compute the QR decomposition. 
- if backend == "oneapi" || backend == "metal" # `qr` not implemented - @test_broken orthogonal(rng, 3, 5) isa arrtype{Float32,2} - continue - end - - for (rows, cols) in [(5, 3), (3, 5)] - v = orthogonal(rng, rows, cols) - GPUArraysCore.@allowscalar if rows < cols - (@test v * v' ≈ I(rows)) - else - (@test v' * v ≈ I(cols)) - end - end - - for mat in [(3, 4, 5), (2, 2, 5)] - v = orthogonal(rng, mat...) - cols = mat[end] - rows = div(prod(mat), cols) - v = reshape(v, (rows, cols)) - GPUArraysCore.@allowscalar if rows < cols - (@test v * v' ≈ I(rows)) - else - (@test v' * v ≈ I(cols)) - end - end +using LinearAlgebra, Statistics, WeightInitializers, Test - @testset "Orthogonal Types $T" for T in (Float32, Float64) - !supports_fp64 && T == Float64 && continue +include("common.jl") - @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T - @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T - end - - @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) - !supports_fp64 && T == Float64 && continue - - @test orthogonal(rng, T, 3, 5) isa AbstractArray{T,2} - @test orthogonal(rng, T, 3, 5) isa arrtype{T,2} - - cl = orthogonal(rng) - display(cl) - @test cl(T, 3, 5) isa arrtype{T,2} - - cl = orthogonal(rng, T) - display(cl) - @test cl(3, 5) isa arrtype{T,2} - end - - @testset "Orthogonal Closure" begin - cl = orthogonal() - display(cl) - - # Sizes - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - @test size(cl(3, 4, 5)) == (3, 4, 5) - @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) - - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end - end -end - -@testitem "Sparse Initialization" setup = [SharedTestSetup] begin - using Statistics - - @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for ( - rng, arrtype, supports_fp64, backend - ) in RNGS_ARRTYPES - # sparse_init should yield an error for non 2-d dimensions - # sparse_init should yield no zero elements if sparsity < 0 - # sparse_init should 
yield all zero elements if sparsity > 1 - # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for - # other sparsity values - # sparse_init should yield a kernel in its non-zero elements consistent with the std - # parameter - - @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) - @test_throws ArgumentError sparse_init(3, sparsity=0.1) - v = sparse_init(100, 100; sparsity=-0.1) - @test sum(v .== 0) == 0 - v = sparse_init(100, 100; sparsity=1.1) - @test sum(v .== 0) == length(v) - - for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] - expected_zeros = ceil(Integer, n_in * sparsity) - v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) - @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) - @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ - end - - @testset "sparse_init Type $T" for T in (Float16, Float32, Float64) - !supports_fp64 && T == Float64 && continue - - @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T - end - - @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) - !supports_fp64 && T == Float64 && continue - - @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T,2} - @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T,2} - - cl = sparse_init(rng; sparsity=0.5) - display(cl) - @test cl(T, 3, 5) isa arrtype{T,2} - - cl = sparse_init(rng, T; sparsity=0.5) - display(cl) - @test cl(3, 5) isa arrtype{T,2} - end - - @testset "sparse_init Closure" begin - cl = sparse_init(; sparsity=0.5) - display(cl) - - # Sizes - @test size(cl(3, 4)) == (3, 4) - @test size(cl(rng, 3, 4)) == (3, 4) - - # Type - @test eltype(cl(4, 2)) == Float32 - @test eltype(cl(rng, 4, 2)) == Float32 - end - end +@testset "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." 
truncated_normal(2; mean=-5.0f0) end -@testitem "Basic Initializations" setup = [SharedTestSetup] begin - using LinearAlgebra, Statistics - +@testset "Basic Initializations" begin @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for ( rng, arrtype, supports_fp64, backend ) in RNGS_ARRTYPES @@ -426,32 +250,3 @@ end end end end - -@testitem "Kaiming Uniform: Complex" begin - using WeightInitializers, Test - - x = kaiming_uniform(ComplexF32, 1024, 1024) - @test eltype(x) == ComplexF32 - @test size(x) == (1024, 1024) - @test minimum(imag.(x)) < 0.0 -end - -@testitem "Initialization inside compile" begin - using Reactant, WeightInitializers, Test - - rrng = Reactant.ReactantRNG() - - @testset "Concrete: $(op)" for op in (zeros32, ones32) - gen_arr = op(rrng, 3, 4) - @test eltype(gen_arr) == Float32 - @test size(gen_arr) == (3, 4) - @test gen_arr isa Reactant.ConcreteRArray{Float32,2} - end - - @testset "Traced: $(op)" for op in (zeros32, ones32, rand32, randn32) - gen_arr = @jit op(rrng, 3, 4) - @test eltype(gen_arr) == Float32 - @test size(gen_arr) == (3, 4) - @test gen_arr isa Reactant.ConcreteRArray{Float32,2} - end -end diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/common.jl similarity index 92% rename from lib/WeightInitializers/test/shared_testsetup.jl rename to lib/WeightInitializers/test/common.jl index d509fe6511..4b48419532 100644 --- a/lib/WeightInitializers/test/shared_testsetup.jl +++ b/lib/WeightInitializers/test/common.jl @@ -1,10 +1,8 @@ -@testsetup module SharedTestSetup - using GPUArrays, GPUArraysCore, Random, StableRNGs, LuxTestUtils GPUArraysCore.allowscalar(false) -const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) RNGS_ARRTYPES = [] if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" @@ -66,7 +64,3 @@ if LuxTestUtils.test_oneapi(BACKEND_GROUP) @assert BACKEND_GROUP == "all" "Expected oneAPI.functional() to be true" end 
end - -export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP, GPUArrays - -end diff --git a/lib/WeightInitializers/test/complex.jl b/lib/WeightInitializers/test/complex.jl new file mode 100644 index 0000000000..286c60bb61 --- /dev/null +++ b/lib/WeightInitializers/test/complex.jl @@ -0,0 +1,8 @@ +using WeightInitializers, Test + +@testset "Kaiming Uniform: Complex" begin + x = kaiming_uniform(ComplexF32, 1024, 1024) + @test eltype(x) == ComplexF32 + @test size(x) == (1024, 1024) + @test minimum(imag.(x)) < 0.0 +end diff --git a/lib/WeightInitializers/test/identity.jl b/lib/WeightInitializers/test/identity.jl new file mode 100644 index 0000000000..48ccca0fd2 --- /dev/null +++ b/lib/WeightInitializers/test/identity.jl @@ -0,0 +1,40 @@ +using LinearAlgebra, WeightInitializers, Test + +@testset "Identity Initialization" begin + @testset "2D identity matrices" begin + # Square matrix should be identity + mat = identity_init(5, 5) + @test mat ≈ Matrix{Float32}(I, 5, 5) + @test diag(mat) == ones(Float32, 5) + # Check off-diagonal elements are zero + for i in 1:5, j in 1:5 + if i != j + @test mat[i, j] == 0.0f0 + end + end + + # Test with gain parameter + mat_gain = identity_init(4, 4; gain=2.5) + @test mat_gain ≈ 2.5f0 * Matrix{Float32}(I, 4, 4) + @test diag(mat_gain) == fill(2.5f0, 4) + + # Non-square matrices + mat_rect1 = identity_init(3, 5) + @test size(mat_rect1) == (3, 5) + @test diag(mat_rect1) == ones(Float32, 3) + @test mat_rect1[:, 4:5] == zeros(Float32, 3, 2) + + mat_rect2 = identity_init(5, 3) + @test size(mat_rect2) == (5, 3) + @test diag(mat_rect2) == ones(Float32, 3) + @test mat_rect2[4:5, :] == zeros(Float32, 2, 3) + end + + @testset "Non-identity sizes" begin + @test identity_init(2, 3)[:, end] == zeros(Float32, 2) + @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) + @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) + @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) + @test identity_init(1, 2, 3, 3)[:, 
end, :, :] == zeros(Float32, 1, 3, 3) + end +end diff --git a/lib/WeightInitializers/test/orthogonal.jl b/lib/WeightInitializers/test/orthogonal.jl new file mode 100644 index 0000000000..96562f8dcf --- /dev/null +++ b/lib/WeightInitializers/test/orthogonal.jl @@ -0,0 +1,75 @@ +using LinearAlgebra, WeightInitializers, Test +using GPUArraysCore + +include("common.jl") + +@testset "Orthogonal Initialization" begin + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for ( + rng, arrtype, supports_fp64, backend + ) in RNGS_ARRTYPES + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. + # In the other case, the transpose should be taken to compute the QR decomposition. + if backend == "oneapi" || backend == "metal" # `qr` not implemented + @test_broken orthogonal(rng, 3, 5) isa arrtype{Float32,2} + continue + end + + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rng, rows, cols) + GPUArraysCore.@allowscalar if rows < cols + (@test v * v' ≈ I(rows)) + else + (@test v' * v ≈ I(cols)) + end + end + + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(rng, mat...) 
+ cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + GPUArraysCore.@allowscalar if rows < cols + (@test v * v' ≈ I(rows)) + else + (@test v' * v ≈ I(cols)) + end + end + + @testset "Orthogonal Types $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T + @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T + end + + @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test orthogonal(rng, T, 3, 5) isa AbstractArray{T,2} + @test orthogonal(rng, T, 3, 5) isa arrtype{T,2} + + cl = orthogonal(rng) + display(cl) + @test cl(T, 3, 5) isa arrtype{T,2} + + cl = orthogonal(rng, T) + display(cl) + @test cl(3, 5) isa arrtype{T,2} + end + + @testset "Orthogonal Closure" begin + cl = orthogonal() + display(cl) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end diff --git a/lib/WeightInitializers/test/qa_tests.jl b/lib/WeightInitializers/test/qa.jl similarity index 85% rename from lib/WeightInitializers/test/qa_tests.jl rename to lib/WeightInitializers/test/qa.jl index 98c292fb38..93762d78e7 100644 --- a/lib/WeightInitializers/test/qa_tests.jl +++ b/lib/WeightInitializers/test/qa.jl @@ -1,11 +1,15 @@ -@testitem "Aqua: Quality Assurance" begin +using WeightInitializers, Test + +include("common.jl") + +@testset "Aqua: Quality Assurance" begin using Aqua Aqua.test_all(WeightInitializers; ambiguities=false) Aqua.test_ambiguities(WeightInitializers; recursive=false) end -@testitem "Explicit Imports: Quality Assurance" setup = [SharedTestSetup] begin +@testset "Explicit Imports: Quality Assurance" begin using ExplicitImports @test check_no_implicit_imports(WeightInitializers) === nothing @@ 
-24,7 +28,7 @@ end end end -@testitem "doctests: Quality Assurance" begin +@testset "doctests: Quality Assurance" begin using Documenter doctestexpr = :(using Random, WeightInitializers) diff --git a/lib/WeightInitializers/test/reactant.jl b/lib/WeightInitializers/test/reactant.jl new file mode 100644 index 0000000000..871da90868 --- /dev/null +++ b/lib/WeightInitializers/test/reactant.jl @@ -0,0 +1,19 @@ +using WeightInitializers, Test, Reactant + +@testset "Initialization inside compile" begin + rrng = Reactant.ReactantRNG() + + @testset "Concrete: $(op)" for op in (zeros32, ones32) + gen_arr = op(rrng, 3, 4) + @test eltype(gen_arr) == Float32 + @test size(gen_arr) == (3, 4) + @test gen_arr isa Reactant.ConcreteRArray{Float32,2} + end + + @testset "Traced: $(op)" for op in (zeros32, ones32, rand32, randn32) + gen_arr = @jit op(rrng, 3, 4) + @test eltype(gen_arr) == Float32 + @test size(gen_arr) == (3, 4) + @test gen_arr isa Reactant.ConcreteRArray{Float32,2} + end +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl index f12ca93540..bb21d271ac 100644 --- a/lib/WeightInitializers/test/runtests.jl +++ b/lib/WeightInitializers/test/runtests.jl @@ -1,50 +1,22 @@ -using Pkg, ReTestItems, WeightInitializers -using InteractiveUtils, CPUSummary, LuxTestUtils +using Pkg, WeightInitializers, Test, ParallelTestRunner, LuxTestUtils -@info sprint(versioninfo) +parsed_args = parse_args(@isdefined(TEST_ARGS) ? TEST_ARGS : ARGS; custom=["BACKEND_GROUP"]) -function parse_test_args() - test_args_from_env = @isdefined(TEST_ARGS) ? 
TEST_ARGS : ARGS - test_args = Dict{String,String}() - for arg in test_args_from_env - if contains(arg, "=") - key, value = split(arg, "="; limit=2) - test_args[key] = value - end - end - @info "Parsed test args" test_args - return test_args -end - -const PARSED_TEST_ARGS = parse_test_args() - -const BACKEND_GROUP = lowercase(get(PARSED_TEST_ARGS, "BACKEND_GROUP", "All")) - -const EXTRA_PKGS = LuxTestUtils.packages_to_install(BACKEND_GROUP) +const BACKEND_GROUP = lowercase( + something(get(parsed_args.custom, "BACKEND_GROUP", nothing), "all") +) -if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS = EXTRA_PKGS - Pkg.add(EXTRA_PKGS) - Base.retry_load_extensions() - Pkg.instantiate() -end +testsuite = find_tests(@__DIR__) +delete!(testsuite, "common") -const RETESTITEMS_NWORKERS = parse( - Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Int(CPUSummary.num_cores()), 4))) -) -const RETESTITEMS_NWORKER_THREADS = parse( - Int, - get( - ENV, - "RETESTITEMS_NWORKER_THREADS", - string(max(Int(CPUSummary.sys_threads()) ÷ RETESTITEMS_NWORKERS, 1)), - ), +total_jobs = min( + something(parsed_args.jobs, ParallelTestRunner.default_njobs()), length(keys(testsuite)) ) -withenv("BACKEND_GROUP" => BACKEND_GROUP) do - ReTestItems.runtests( - WeightInitializers; - nworkers=RETESTITEMS_NWORKERS, - nworker_threads=RETESTITEMS_NWORKER_THREADS, - ) +withenv( + "XLA_REACTANT_GPU_MEM_FRACTION" => 1 / (total_jobs + 0.1), + "XLA_REACTANT_GPU_PREALLOCATE" => false, + "BACKEND_GROUP" => BACKEND_GROUP, +) do + runtests(WeightInitializers, parsed_args; testsuite) end diff --git a/lib/WeightInitializers/test/sparse.jl b/lib/WeightInitializers/test/sparse.jl new file mode 100644 index 0000000000..bb2106b226 --- /dev/null +++ b/lib/WeightInitializers/test/sparse.jl @@ -0,0 +1,65 @@ +using Statistics, WeightInitializers, Test + +include("common.jl") + +@testset "Sparse Initialization" begin + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for ( + rng, arrtype, 
supports_fp64, backend + ) in RNGS_ARRTYPES + # sparse_init should yield an error for non 2-d dimensions + # sparse_init should yield no zero elements if sparsity < 0 + # sparse_init should yield all zero elements if sparsity > 1 + # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for + # other sparsity values + # sparse_init should yield a kernel in its non-zero elements consistent with the std + # parameter + + @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) + @test_throws ArgumentError sparse_init(3, sparsity=0.1) + v = sparse_init(100, 100; sparsity=-0.1) + @test sum(v .== 0) == 0 + v = sparse_init(100, 100; sparsity=1.1) + @test sum(v .== 0) == length(v) + + for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] + expected_zeros = ceil(Integer, n_in * sparsity) + v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) + @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) + @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ + end + + @testset "sparse_init Type $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T + end + + @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T,2} + @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T,2} + + cl = sparse_init(rng; sparsity=0.5) + display(cl) + @test cl(T, 3, 5) isa arrtype{T,2} + + cl = sparse_init(rng, T; sparsity=0.5) + display(cl) + @test cl(3, 5) isa arrtype{T,2} + end + + @testset "sparse_init Closure" begin + cl = sparse_init(; sparsity=0.5) + display(cl) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end diff --git 
a/lib/WeightInitializers/test/utils_tests.jl b/lib/WeightInitializers/test/utils.jl similarity index 83% rename from lib/WeightInitializers/test/utils_tests.jl rename to lib/WeightInitializers/test/utils.jl index 027fd6d217..1cbb8a2a83 100644 --- a/lib/WeightInitializers/test/utils_tests.jl +++ b/lib/WeightInitializers/test/utils.jl @@ -1,4 +1,6 @@ -@testitem "Utils.nfan" begin +using WeightInitializers, Test + +@testset "Utils.nfan" begin using WeightInitializers: Utils @test Utils.nfan() == (1, 1) # Fallback