GPU Compatibility Issue: Compilation Error with Complex-Valued Data in LuxCUDA Broadcasting Kernel #844

Open · RomanSahakyan03 opened this issue Apr 3, 2024 · 8 comments

Bug Description

Summary

When attempting to solve a neural network optimization problem with complex-valued parameters on a GPU using the Lux and LuxCUDA packages in Julia, a GPU compilation error occurs.

Steps to Reproduce

  • Define a neural network architecture using the Lux and LuxCUDA packages.
  • Set up the optimization problem with the specified optimizer and solver.
  • Attempt to solve the optimization problem on a GPU.

Expected Behavior

The optimization problem should be solved without errors, utilizing the GPU acceleration provided by the LuxCUDA package.

Observed Behavior

GPU compilation of the broadcasting kernel (a MethodInstance for broadcast) fails with a KernelError reporting a non-bitstype argument.
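For context: the failure is not about complex numbers as such. ComplexF64 is itself an isbits type and is GPU-safe; the problem is a heap-allocated host Vector{ComplexF64} reaching the kernel. A quick check (added here for illustration):

# A value can be passed into a GPU kernel only if its type is isbits.
isbitstype(ComplexF64)           # true  -> complex scalars are GPU-safe
isbitstype(Vector{ComplexF64})   # false -> a host Vector cannot enter a kernel
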
Code Snippet

using Lux, LuxCUDA, ComponentArrays, Random
using NeuralPDE, OptimizationOptimisers, SciMLBase  # for NNODE, StochasticTraining, Adam, allowscomplex

# Select the GPU device and seed the RNG
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

# Define the neural network architecture with complex-valued (ComplexF64) weights
inner = 16
chain = Chain(Dense(1, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)),
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, 9; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)))
ps = Lux.setup(rng, chain)[1]
ps = ps |> ComponentArray |> gpud .|> ComplexF64
# REPL output (numeric entries elided): ps is now a
# ComponentVector{ComplexF64, CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, ...}
opt = Adam(0.01)
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(300, 30000))
SciMLBase.allowscomplex(::NNODE) = true  # opt in to complex-valued problems (type-piracy workaround)

# Attempt to solve the problem; the ODEProblem bound to `problem` was omitted from the report
sol = solve(problem, alg, verbose = true, maxiters = 1000, saveat = 0.001)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{ComplexF64} which is not isbits.
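
For context: reading the Broadcasted type above, the failing expression has the shape v .+ (a - b) .* w, where v is a host Vector{ComplexF64} and w is a device vector; mixing host and device arrays in one broadcast is what triggers the non-isbits KernelError. A minimal sketch reproducing and avoiding it outside of Lux/NeuralPDE (the names v, w, and z are illustrative):

using CUDA

v = rand(ComplexF64, 4)            # host vector, stays in CPU memory
w = CuArray(rand(ComplexF64, 4))   # device vector

# v .+ (1.0 - 0.5) .* w            # KernelError: non-bitstype argument, because
                                   # the host Vector would be captured by the kernel

v_dev = CuArray(v)                 # move the host operand to the device first
z = v_dev .+ (1.0 - 0.5) .* w      # compiles and runs on the GPU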

Additional Information

  • Environment: Julia 1.10.2, Lux v0.5.19, LuxCUDA v0.3.2, ComponentArrays v0.15.10
  • The error message specifically points to a non-bitstype argument passed to the broadcasting kernel.
  • This issue prevents the neural network optimization problem from running on the GPU at all, blocking the expected acceleration.
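
While a fix is pending, one hedged workaround to try, assuming the host vector in the trace is the problem's initial condition u0: remake the ODE problem with a device-resident u0 so every broadcast operand already lives on the GPU. This is a sketch, not a verified fix:

using SciMLBase, CUDA

# Assumption: the offending host Vector{ComplexF64} is problem.u0.
gpu_problem = remake(problem; u0 = CuArray(problem.u0))
sol = solve(gpu_problem, alg, verbose = true, maxiters = 1000, saveat = 0.001)
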
@IromainI commented Apr 5, 2024

I am also trying to optimize a neural network on a GPU (LuxCUDA in Julia), and I get the same GPU compilation error.

@DrEmilGazazyan commented

I have the same bug with LuxCUDA in Julia and encounter the same GPU compilation error.

@RomanSahakyan03 (Author) commented

@sathvikbhagavan, how can I assist you?

@RomanSahakyan03 (Author) commented

@sathvikbhagavan ?

@sathvikbhagavan (Member) commented

@RomanSahakyan03, apologies for the late reply. I will try to finish it up by this weekend.

@RomanSahakyan03 (Author) commented

@sathvikbhagavan, it's OK. Thanks for your efforts! If you need assistance, I can help.

@RomanSahakyan03 (Author) commented

@sathvikbhagavan, what about now? Did you finish it?

@sathvikbhagavan (Member) commented

Hi @RomanSahakyan03, I have a draft PR #866 to fix this, but I am currently running into some issues. Hopefully they will get resolved.
