From c5f0a50853d4ebd91c8b352f634d0dda5c5ad5b9 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Sun, 4 Jan 2026 21:40:24 +0100 Subject: [PATCH 1/4] Migrate to `usize` indexing --- Cargo.lock | 192 +++++++----------- Cargo.toml | 12 +- .../tests/common/autodiff.rs | 2 +- .../burn-backend-tests/tests/common/tensor.rs | 2 +- .../tests/cubecl/mask_fill.rs | 2 +- crates/burn-cubecl-fusion/src/base.rs | 13 +- .../src/engine/codegen/io.rs | 189 +++++++++-------- .../src/engine/codegen/ir.rs | 114 ++++++----- .../src/engine/codegen/kernel.rs | 115 +++++------ .../src/engine/codegen/view.rs | 54 ++--- crates/burn-cubecl-fusion/src/engine/fuser.rs | 10 +- .../src/engine/launch/executor.rs | 23 +-- .../src/engine/launch/input.rs | 10 +- .../src/engine/launch/output.rs | 16 +- .../src/engine/launch/plan.rs | 16 +- .../src/engine/launch/runner.rs | 6 +- .../src/engine/launch/vectorization/base.rs | 81 ++++---- .../engine/launch/vectorization/planner.rs | 35 ++-- .../src/engine/trace/base.rs | 24 +-- .../src/engine/trace/block.rs | 21 +- .../src/engine/trace/fuser.rs | 4 +- .../src/optim/elemwise/optimization.rs | 10 +- .../src/optim/matmul/args.rs | 80 ++++---- .../src/optim/matmul/optimization.rs | 6 +- .../src/optim/reduce/args.rs | 39 ++-- .../src/optim/reduce/optimization.rs | 20 +- crates/burn-cubecl/src/kernel/binary.rs | 4 +- crates/burn-cubecl/src/kernel/binary_int.rs | 6 +- crates/burn-cubecl/src/kernel/clamp.rs | 2 +- crates/burn-cubecl/src/kernel/comparison.rs | 6 +- crates/burn-cubecl/src/kernel/contiguous.rs | 10 +- .../kernel/conv/conv_transpose2d/col2im.rs | 48 ++--- .../conv/conv_transpose2d/transpose_direct.rs | 36 ++-- .../src/kernel/conv/conv_transpose3d.rs | 52 ++--- .../src/kernel/conv/deform_conv2d.rs | 68 +++---- .../kernel/conv/deform_conv_transpose2d.rs | 91 ++++----- crates/burn-cubecl/src/kernel/conv/direct.rs | 62 +++--- crates/burn-cubecl/src/kernel/conv/im2col.rs | 6 +- crates/burn-cubecl/src/kernel/index/flip.rs | 5 +- crates/burn-cubecl/src/kernel/index/gather.rs | 14 +- .../src/kernel/index/repeat_dim.rs | 4 +- .../burn-cubecl/src/kernel/index/scatter.rs | 12 +- crates/burn-cubecl/src/kernel/index/select.rs | 6 +- .../src/kernel/index/select_assign.rs | 13 +- crates/burn-cubecl/src/kernel/index/slice.rs | 46 ++--- .../src/kernel/index/slice_assign.rs | 59 +++--- .../src/kernel/interpolate/bicubic.rs | 12 +- .../src/kernel/interpolate/bilinear.rs | 16 +- .../src/kernel/interpolate/nearest.rs | 12 +- .../kernel/interpolate/nearest_backward.rs | 16 +- crates/burn-cubecl/src/kernel/mask/base.rs | 2 +- .../burn-cubecl/src/kernel/mask/mask_fill.rs | 6 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 4 +- .../pool/adaptive_avg_pool2d_backward.rs | 4 +- .../burn-cubecl/src/kernel/pool/avg_pool2d.rs | 6 +- .../src/kernel/pool/avg_pool2d_backward.rs | 15 +- .../burn-cubecl/src/kernel/pool/max_pool2d.rs | 12 +- .../src/kernel/pool/max_pool2d_backward.rs | 6 +- crates/burn-cubecl/src/kernel/pool/pool2d.rs | 18 +- crates/burn-cubecl/src/kernel/utils.rs | 17 +- crates/burn-cubecl/src/ops/base.rs | 9 +- crates/burn-cubecl/src/ops/bool_ops.rs | 2 +- crates/burn-cubecl/src/ops/float_ops.rs | 1 - crates/burn-cubecl/src/ops/int_ops.rs | 2 +- crates/burn-cubecl/src/ops/numeric.rs | 8 +- crates/burn-cubecl/src/template/base.rs | 5 + crates/burn-cubecl/src/tensor/base.rs | 6 +- .../hardware_accelerated.rs | 149 +++++++------- .../cube/connected_components/prefix_sum.rs | 73 +++---- examples/custom-cubecl-kernel/src/kernel.rs | 6 +- 70 files 
changed, 996 insertions(+), 1057 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f5b148212..45ccf870d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -337,9 +337,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", "base64 0.22.1", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", @@ -1415,9 +1415,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.50" +version = "1.2.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f50d563227a1c37cc0a263f64eca3334388c01c5e4c4861a9def205c614383c" +checksum = "7a0aeaff4ff1a90589618835a598e545176939b97874f7abc7851caa0618f203" dependencies = [ "find-msvc-tools", "jobserver", @@ -1928,7 +1928,7 @@ dependencies = [ "crossterm_winapi", "document-features", "parking_lot", - "rustix 1.1.2", + "rustix 1.1.3", "winapi", ] @@ -1992,8 +1992,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e30a2bee1f72e79761c60dfe3397ed19113ab87ea693aa653a1efb4c35ca4e9" dependencies = [ "cubecl-core", "cubecl-cpu", @@ -2008,8 +2006,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e729945e49d84f1690c89d60be85116f8481cba80a131ed7e1dee62cc874dbf" dependencies = [ "backtrace", "bytemuck", @@ -2045,8 +2041,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2e640727aea13ac0bcbc679ad42ca4db04ca6bd2f3b6007c72540e2cada3e1" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2071,8 +2065,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610349ded854522339df2f3a0135d9a2dc1a953fde4356247b562df566a0dff9" dependencies = [ "bytemuck", "cubecl-common", @@ -2088,8 +2080,6 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514943c585e7e5b0515ac033498d1015fc67317baf592980ca04faf970a414f4" dependencies = [ "bytemuck", "cubecl-common", @@ -2110,8 +2100,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab39a1cb9ade4d960418200810f272ee24aebd5c089232e0ae1e77b9738a1c87" dependencies = [ "bytemuck", "cubecl-common", @@ -2129,8 +2117,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc6b3c6740d63118ea5ef89dea287f976d8949595e10bdf03875a266ec54e088" dependencies = [ "bytemuck", "cubecl-common", @@ -2143,6 +2129,7 @@ dependencies = [ "log", "paste", "serde", + "tracing", ] [[package]] @@ -2158,8 +2145,6 @@ dependencies = [ 
[[package]] name = "cubecl-ir" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d1214a2446701f2f4b7ce8330c9527fc5c1c844c1b62df673fd49dd498840d2" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2179,8 +2164,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e04d494dcf9649d68a27b33a698cf27fb6bc5fb40916272af02f562a3deeaca" dependencies = [ "cubecl-common", "darling 0.21.3", @@ -2195,8 +2178,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b95de65bd26c40a3e6016d6d1278e826fab960c61930959fe7ef7dec863ec" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -2207,8 +2188,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1230723615eb57d80f20d5bf1829d24d8fe697ad8df837d129b1dbe004970e93" dependencies = [ "cubecl-common", "cubecl-core", @@ -2225,8 +2204,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aef53aa0a6dadf74534af9ca154e1bda2f5fdeb4ea56270e80266a1409a9a1ac" dependencies = [ "async-channel", "bytemuck", @@ -2255,8 +2232,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7584eb46fbce3edff10e65bb98ff761dcc8aeac1d79595b92bfdab55fabc53b9" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2271,8 +2246,6 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "465f7752da8594cdd22c82edfcf33041937b3c8e0d6fad9483753088bf1b6a39" dependencies = [ "cubecl-common", "cubecl-core", @@ -2289,8 +2262,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.9.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bb0d584467a0e8aa4d983cae126748241aaeb3aa72a23a3739cf25886767999" dependencies = [ "ash", "async-channel", @@ -2309,14 +2280,13 @@ dependencies = [ "log", "sanitize-filename", "tracel-ash", + "tracing", "wgpu", ] [[package]] name = "cubek" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa51b739e092523e8e32692f627cddc17a852746f184b494654d0b50adc1c19" dependencies = [ "cubecl", "cubek-attention", @@ -2330,8 +2300,6 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac0855d864e2fb16e7da8b7221d8a39f122aa80f7bd53ccf5bd49b2d65a8aa05" dependencies = [ "bytemuck", "cubecl", @@ -2345,8 +2313,6 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "111b5d1af7c5d703dc96712436ef3210766768cfc1524dd5c56a844fc506630c" dependencies = [ "bytemuck", "cubecl", @@ -2361,8 +2327,6 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c688e4171bf324cb8b0aa776304043ccb664be700b8e5dbddb61e5e6c573b18" dependencies = [ "bytemuck", "cubecl", @@ -2374,8 +2338,6 @@ dependencies = [ [[package]] 
name = "cubek-quant" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9f495bfe2b26f37893e4fa95ca703b087dcd134add8348635a59cc530475e86" dependencies = [ "cubecl", "cubecl-common", @@ -2386,8 +2348,6 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df73f4d82dbd40588643d36363f47855c585842c93fe394105fc7554b04e08f7" dependencies = [ "cubecl", "cubecl-common", @@ -2400,8 +2360,6 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.1.0-pre.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdb50d7aeb02090f6019913fb63c759de93f1b1d30b5ed096373dd9c68551e40" dependencies = [ "cubecl", "half", @@ -2692,18 +2650,18 @@ dependencies = [ [[package]] name = "derive_more" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" dependencies = [ "convert_case 0.10.0", "proc-macro2", @@ -3221,9 +3179,9 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" +checksum = "645cbb3a84e60b7531617d5ae4e57f7e27308f6445f5abf653209ea76dec8dff" [[package]] name = "fixedbitset" @@ -3347,7 +3305,7 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" dependencies = [ - "rustix 1.1.2", + "rustix 1.1.3", "windows-sys 0.59.0", ] @@ -3803,9 +3761,9 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d3f59a8de2934f6391b6b3a1a7654eae18961fcb9f9c843533fed34ad0f3457" +checksum = "edd971cd6961fb1ebb29a0052a4ab04d8498dbf363c122e137b04753a3bbb5c3" [[package]] name = "gix-utils" @@ -4399,7 +4357,7 @@ dependencies = [ "rgb", "tiff", "zune-core 0.5.0", - "zune-jpeg 0.5.7", + "zune-jpeg 0.5.8", ] [[package]] @@ -4594,15 +4552,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jiff" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49cce2b81f2098e7e3efc35bc2e0a6b7abec9d34128283d7a26fa8f32a6dbb35" +checksum = "a87d9b8105c23642f50cbbae03d1f75d8422c5cb98ce7ee9271f7ff7505be6b8" dependencies = [ "jiff-static", "log", @@ -4613,9 +4571,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"980af8b43c3ad5d8d349ace167ec8170839f753a42d233ba19e08afe1850fa69" +checksum = "b787bebb543f8969132630c51fd0afab173a86c6abae56ff3b9e5e3e3f9f6e58" dependencies = [ "proc-macro2", "quote", @@ -4738,13 +4696,13 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df15f6eac291ed1cf25865b1ee60399f57e7c227e7f51bdbd4c5270396a9ed50" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" dependencies = [ "bitflags 2.10.0", "libc", - "redox_syscall 0.6.0", + "redox_syscall 0.7.0", ] [[package]] @@ -4760,9 +4718,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15413ef615ad868d4d65dce091cb233b229419c7c0c4bcaa746c0901c49ff39c" +checksum = "c10501e7805cee23da17c7790e59df2870c0d4043ec6d03f67d31e2b53e77415" dependencies = [ "zlib-rs", ] @@ -5309,9 +5267,9 @@ dependencies = [ [[package]] name = "ntapi" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +checksum = "c70f219e21142367c70c0b30c6a9e3a14d55b4d12a204d897fbec83a0363f081" dependencies = [ "winapi", ] @@ -6546,9 +6504,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" dependencies = [ "serde", ] @@ -6639,9 +6597,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.103" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "9695f8df41bb4f3d222c95a67532365f569318332d03d5f3f67f37b20e6ebdf0" dependencies = [ "unicode-ident", ] @@ -7165,9 +7123,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96166dafa0886eb81fe1c0a388bece180fbef2135f97c1e2cf8302e74b43b5" +checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" dependencies = [ "bitflags 2.10.0", ] @@ -7232,9 +7190,9 @@ checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" [[package]] name = "reqwest" -version = "0.12.26" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", @@ -7299,22 +7257,19 @@ dependencies = [ [[package]] name = "rmp" -version = "0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" dependencies = [ - "byteorder", "num-traits", - "paste", ] [[package]] name = "rmp-serde" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" dependencies = [ - "byteorder", "rmp", "serde", ] @@ -7415,9 +7370,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" dependencies = [ "bitflags 2.10.0", "errno", @@ -7482,9 +7437,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "safetensors" @@ -7687,15 +7642,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.145" +version = "1.0.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" dependencies = [ "itoa", "memchr", - "ryu", "serde", "serde_core", + "zmij", ] [[package]] @@ -7845,10 +7800,11 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.7" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] @@ -8250,14 +8206,14 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.23.0" +version = "3.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", "getrandom 0.3.4", "once_cell", - "rustix 1.1.2", + "rustix 1.1.3", "windows-sys 0.61.2", ] @@ -8276,7 +8232,7 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0" dependencies = [ - "rustix 1.1.2", + "rustix 1.1.3", "windows-sys 0.60.2", ] @@ -8595,9 +8551,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.9.9+spec-1.0.0" +version = "0.9.10+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb5238e643fc34a1d5d7e753e1532a91912d74b63b92b3ea51fde8d1b7bc79dd" +checksum = "0825052159284a1a8b4d6c0c86cbc801f2da5afd2b225fa548c72f2e74002f48" dependencies = [ "indexmap", "serde_core", @@ -8610,9 +8566,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.4+spec-1.0.0" +version = "0.7.5+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe3cea6b2aa3b910092f6abd4053ea464fab5f9c170ba5e9a6aead16ec4af2b6" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" dependencies = [ "serde_core", ] @@ -8631,18 +8587,18 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.5+spec-1.0.0" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4c03bee5ce3696f31250db0bbaff18bc43301ce0e8db2ed1f07cbb2acf89984c" +checksum = "a3198b4b0a8e11f09dd03e133c0280504d0801269e9afa46362ffde1cbeebf44" dependencies = [ "winnow", ] [[package]] name = "toml_writer" -version = "1.0.5+spec-1.0.0" +version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9cd6190959dce0994aa8970cd32ab116d1851ead27e866039acaf2524ce44fa" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" [[package]] name = "tonic" @@ -10107,7 +10063,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" dependencies = [ "libc", - "rustix 1.1.2", + "rustix 1.1.3", ] [[package]] @@ -10350,9 +10306,15 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.4" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" + +[[package]] +name = "zmij" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51f936044d677be1a1168fae1d03b583a285a5dd9d8cbf7b24c23aa1fc775235" +checksum = "e6d6085d62852e35540689d1f97ad663e3971fc19cf5eceab364d62c646ea167" [[package]] name = "zopfli" @@ -10445,9 +10407,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d915729b0e7d5fe35c2f294c5dc10b30207cc637920e5b59077bfa3da63f28" +checksum = "e35aee689668bf9bd6f6f3a6c60bb29ba1244b3b43adfd50edd554a371da37d5" dependencies = [ "zune-core 0.5.0", ] diff --git a/Cargo.toml b/Cargo.toml index 079235390a..001d54cd74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -184,13 +184,13 @@ portable-atomic-util = { version = "0.2.4", features = ["alloc"] } # cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b14b1a30ea45bd7357c55be1811f9e435224cb71" } # cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "a08eb35be38bb56f174408371d200ede3137847a" } ### For local development. ### -# cubecl = { path = "../../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../../cubecl/crates/cubecl-common", default-features = false } -# cubek = { path = "../../cubek/crates/cubek", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. 
### -cubecl = { version = "=0.9.0-pre.6", default-features = false } -cubecl-common = { version = "=0.9.0-pre.6", default-features = false } -cubek = { version = "=0.1.0-pre.1", default-features = false } +# cubecl = { version = "=0.9.0-pre.6", default-features = false } +# cubecl-common = { version = "=0.9.0-pre.6", default-features = false } +# cubek = { version = "=0.1.0-pre.1", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=2.2.1" } diff --git a/crates/burn-backend-tests/tests/common/autodiff.rs b/crates/burn-backend-tests/tests/common/autodiff.rs index 21b6b858f7..1e0ea8d54a 100644 --- a/crates/burn-backend-tests/tests/common/autodiff.rs +++ b/crates/burn-backend-tests/tests/common/autodiff.rs @@ -30,5 +30,5 @@ test_float_elem_variant!( bf16, burn_tensor::bf16, "../autodiff/mod.rs", - ["vulkan", "metal"] // ["cuda", "rocm"] TODO + ["metal"] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul ); diff --git a/crates/burn-backend-tests/tests/common/tensor.rs b/crates/burn-backend-tests/tests/common/tensor.rs index f9bf3b64d2..5f0c1a2fe2 100644 --- a/crates/burn-backend-tests/tests/common/tensor.rs +++ b/crates/burn-backend-tests/tests/common/tensor.rs @@ -34,5 +34,5 @@ test_float_elem_variant!( bf16, burn_tensor::bf16, "../tensor/float/mod.rs", - ["vulkan", "metal"] // ["cuda", "rocm"] TODO + ["metal"] // ["cuda", "rocm"] TODO, ["vulkan"] only supports bf16 for matmul ); diff --git a/crates/burn-backend-tests/tests/cubecl/mask_fill.rs b/crates/burn-backend-tests/tests/cubecl/mask_fill.rs index 730bc10777..d36e0c23ea 100644 --- a/crates/burn-backend-tests/tests/cubecl/mask_fill.rs +++ b/crates/burn-backend-tests/tests/cubecl/mask_fill.rs @@ -2,7 +2,7 @@ use super::*; use burn_cubecl::kernel::{MaskFillStrategy, mask_fill}; use burn_tensor::Tolerance; use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend}; -use cubecl::std::scalar::InputScalar; +use cubecl::prelude::InputScalar; #[test] fn mask_fill_should_match_reference_backend() { diff --git a/crates/burn-cubecl-fusion/src/base.rs b/crates/burn-cubecl-fusion/src/base.rs index 92ada55fdf..376d622e71 100644 --- a/crates/burn-cubecl-fusion/src/base.rs +++ b/crates/burn-cubecl-fusion/src/base.rs @@ -1,12 +1,15 @@ use burn_fusion::stream::Context; use burn_std::{DType, quantization::QParamTensor}; -use cubecl::quant::scheme::{QuantParam, QuantScheme}; use cubecl::{ CubeElement, Runtime, client::ComputeClient, ir::ElemType, prelude::{TensorArg, TensorHandleRef}, }; +use cubecl::{ + ir::LineSize, + quant::scheme::{QuantParam, QuantScheme}, +}; use std::marker::PhantomData; /// Defines a fallback operation when fusion isn't possible. @@ -73,7 +76,11 @@ impl CubeFusionHandle { } } /// Return the reference to a tensor argument. 
- pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { + pub fn as_tensor_arg<'a>( + &'a self, + shape: &'a [usize], + line_size: LineSize, + ) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); unsafe { @@ -81,7 +88,7 @@ impl CubeFusionHandle { handle.handle, handle.strides, handle.shape, - vectorisation, + line_size, self.dtype.size(), ) } diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/io.rs b/crates/burn-cubecl-fusion/src/engine/codegen/io.rs index 4988b5f88a..66901430dd 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/io.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/io.rs @@ -18,11 +18,11 @@ pub enum Transform { /// /// This enum entry contains a sequence of [arguments](FuseArg) that points to global scalars representing the /// new shape for the current tensor. - Reshape(Sequence), + Reshape(Vec), /// Two axes have been swapped on a tensor. /// /// The enum entry contains those two axes. - SwapDims(u32, u32), + SwapDims(usize, usize), } /// Reads the value from the [arg](FuseArg) and cast it to the generic cube primitive. @@ -39,7 +39,7 @@ pub fn read( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - ref_pos: u32, + ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, ) -> Line { @@ -48,7 +48,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!global.broadcasted && line_size != config.width as u32] { + if comptime![!global.broadcasted && line_size != config.width] { read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None) } else { read_input(inputs, locals, pos, ref_pos, layout, config, None) @@ -90,7 +90,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != config.width] { read_input_aligned( inputs, locals, @@ -123,7 +123,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != config.width] { read_input_aligned( inputs, locals, @@ -158,15 +158,15 @@ pub fn read( fn index_offset_with_quant_layout( tensor: &GlobalTensor, locals: &LocalArgs, - index: u32, - #[comptime] rank: u32, + index: usize, + #[comptime] rank: usize, #[comptime] scheme: QuantScheme, -) -> u32 { - let (start, end) = comptime![(0u32, rank - 1)]; - let num_quants = comptime!(scheme.num_quants() as u32); +) -> usize { + let (start, end) = (0, rank - 1); + let num_quants = scheme.num_quants(); let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { @@ -192,7 +192,7 @@ fn index_offset_with_quant_layout( pub fn read_quantized( inputs: &GlobalArgs, locals: &LocalArgs, - ref_pos: u32, + ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, #[comptime] scheme: QuantScheme, @@ -224,9 +224,9 @@ pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: FuseA /// Reads a global scalar that is used as a reshape position. 
#[cube] -pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> u32 { +pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize { match arg { - FuseArg::ScalarShape(pos) => *inputs.reshapes.index(pos), + FuseArg::ScalarShape(pos) => inputs.reshapes[pos], _ => comptime![panic!("Not a scalar shape")], } } @@ -236,8 +236,8 @@ pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> u32 { pub fn read_input( inputs: &GlobalArgs, locals: &LocalArgs, - #[comptime] pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, @@ -255,9 +255,9 @@ pub fn read_input( #[cube] pub fn read_input_window( inputs: &GlobalArgs, - #[comptime] pos: u32, - start: u32, - end: u32, + #[comptime] pos: usize, + start: usize, + end: usize, ) -> Slice { let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.slice(start, end); @@ -266,7 +266,7 @@ pub fn read_input_window( /// Returns the input as a slice. #[cube] -pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: u32) -> Slice { +pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice { let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.to_slice(); slice.try_cast_unchecked() @@ -276,11 +276,11 @@ pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: u3 #[cube] pub fn input_as_scales_view( inputs: &GlobalArgs, - #[comptime] pos: u32, - #[comptime] tensor_pos: u32, + #[comptime] pos: usize, + #[comptime] tensor_pos: usize, #[comptime] level: QuantLevel, #[comptime] config: &FuseBlockConfig, -) -> View { +) -> View { set_polyfill_typed::>(); let tensor = inputs.tensors.index(tensor_pos); let scales = inputs.tensors.index(pos); @@ -289,7 +289,7 @@ pub fn input_as_scales_view( let layout = match level { QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)), QuantLevel::Block(block_size) => { - let block_size = comptime![block_size.to_dim_vec(rank as usize)]; + let block_size = comptime![block_size.to_dim_vec(rank)]; let mut tensor_shape = Sequence::new(); let mut scales_strides = Sequence::new(); #[unroll] @@ -308,7 +308,7 @@ pub fn input_as_scales_view( ScalesLayout::new_BlockScaled(layout) } }; - View::new::, u32>(&scales.tensor.to_slice().try_cast_unchecked(), layout) + View::new::, usize>(&scales.tensor.to_slice().try_cast_unchecked(), layout) } /// Reads the input tensor aligned. @@ -316,22 +316,22 @@ pub fn input_as_scales_view( pub fn read_input_aligned( inputs: &GlobalArgs, locals: &LocalArgs, - #[comptime] pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> Line { - let mut result: Line = Line::::empty(comptime![config.width as u32]); + let mut result: Line = Line::::empty(config.width); let tensor = inputs.tensors.index(pos); - match comptime![transform.clone()] { + match transform.clone() { Some(Transform::Reshape(shape)) => { // Very brute force, not really efficient, but not easy to optimize and not a very // frequent workflow. 
- let ref_pos = ref_pos * comptime![config.width as u32]; + let ref_pos = ref_pos * config.width; #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = reshaped_index( inputs, locals, @@ -346,11 +346,11 @@ pub fn read_input_aligned( Some(Transform::SwapDims(dim1, dim2)) => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); - let i = comptime![swap_dims_transform(&(config.rank - 1), (dim1, dim2))]; - let stride = tensor.tensor.stride(comptime![i]); + let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))]; + let stride = tensor.tensor.stride(i); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -358,9 +358,9 @@ pub fn read_input_aligned( None => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); - let stride = tensor.tensor.stride(comptime![config.rank - 1]); + let stride = tensor.tensor.stride(config.rank - 1); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -377,11 +377,11 @@ pub fn get_offset_aligned( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, - ref_pos: u32, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, -) -> u32 { +) -> usize { match layout { LayoutInfo::SameAsRef | LayoutInfo::IsRef => { (ref_pos * locals.ref_line_size) / tensor.tensor.line_size() @@ -404,8 +404,8 @@ pub fn read_output( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, ) -> Line { @@ -424,7 +424,7 @@ pub fn write( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - ref_pos: u32, + ref_pos: usize, value: Line, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, @@ -468,11 +468,11 @@ pub(crate) fn global_offset( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - index: u32, + index: usize, #[comptime] arg: FuseArg, - #[comptime] range: Option<(u32, u32)>, + #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, -) -> u32 { +) -> usize { match arg { FuseArg::Input(pos, _precision, _layout) => { let tensor = inputs.tensors.index(pos); @@ -491,11 +491,11 @@ fn get_offset( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, - ref_pos: u32, - #[comptime] range: Option<(u32, u32)>, + ref_pos: usize, + #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, -) -> u32 { +) -> usize { index_offset_with_layout( inputs, tensor, @@ -509,28 +509,28 @@ fn get_offset( #[cube] /// Gets the line size for a global tensor. -pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: u32) -> comptime_type!(u32) { +pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: usize) -> comptime_type!(LineSize) { let tensor = global.tensors.index(pos); tensor.tensor.line_size() } #[cube] /// Gets the rank for a global tensor. 
-pub fn global_rank(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.rank() } #[cube] /// Gets the length for a global tensor. -pub fn global_len(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.len() } #[cube] /// Gets the buffer length for a global tensor. -pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.buffer_len() } @@ -542,8 +542,8 @@ pub fn ref_len( outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, -) -> u32 { - match comptime![config.ref_layout.clone()] { +) -> usize { + match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_len(inputs, index), FuseArg::Output(index, _, _) => global_len(outputs, index), @@ -560,8 +560,8 @@ pub fn ref_buffer_len( outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, -) -> u32 { - match comptime![config.ref_layout.clone()] { +) -> usize { + match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_buffer_len(inputs, index), FuseArg::Output(index, _, _) => global_buffer_len(outputs, index), @@ -579,8 +579,8 @@ pub fn ref_buffer_len( #[cube] /// Gets the reference number of elements. -pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> u32 { - let mut length = 1u32; +pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize { + let mut length = 1; for i in 0..config.rank { length *= locals.ref_shape[i]; @@ -591,32 +591,32 @@ pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> #[cube] /// Gets the reference axis shape. -pub fn ref_shape(locals: &LocalArgs, axis: u32) -> u32 { +pub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize { locals.ref_shape[axis] } #[cube] /// Gets the reference axis stride. -pub fn ref_stride(locals: &LocalArgs, axis: u32) -> u32 { +pub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize { locals.ref_strides[axis] } #[cube] /// Gets the reference line size. -pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(u32) { +pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(LineSize) { comptime![locals.ref_line_size] } #[cube] /// Gets the given tensor axis shape. -pub fn global_shape(global: &GlobalArgs, axis: u32, #[comptime] pos: u32) -> u32 { +pub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.shape(axis) } #[cube] /// Gets the given tensor axis stride. 
-pub fn global_stride(global: &GlobalArgs, dim: u32, #[comptime] pos: u32) -> u32 { +pub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.stride(dim) } @@ -626,11 +626,11 @@ fn index_offset_with_layout( inputs: &GlobalArgs, tensor: &GlobalTensor, locals: &LocalArgs, - index: u32, - #[comptime] range: Option<(u32, u32)>, - #[comptime] rank: u32, + index: usize, + #[comptime] range: Option<(usize, usize)>, + #[comptime] rank: usize, #[comptime] transform: Option, -) -> u32 { +) -> usize { match comptime![transform.clone()] { Some(Transform::Reshape(shape)) => { comptime![assert!( @@ -645,15 +645,15 @@ fn index_offset_with_layout( Some(Transform::SwapDims(dim1, dim2)) => { let (start, end) = comptime! {match range { Some(range) => range, - None => (0u32, rank), + None => (0, rank), }}; let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { - let index = comptime![swap_dims_transform(&i, (dim1, dim2))]; + let index = comptime![swap_dims_transform(i, (dim1, dim2))]; let ogwl = offset_ref / locals.ref_strides[i]; offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index); } @@ -663,11 +663,11 @@ fn index_offset_with_layout( None => { let (start, end) = comptime! {match range { Some(range) => range, - None => (0u32, rank), + None => (0, rank), }}; let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { @@ -680,10 +680,7 @@ fn index_offset_with_layout( } } -pub(crate) fn swap_dims_transform(i: &I, dims: (u32, u32)) -> u32 { - let i_cloned: I = i.clone(); - let i = i_cloned.value().as_const().unwrap().as_u32(); - +pub(crate) fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize { if i == dims.0 { dims.1 } else if i == dims.1 { @@ -699,18 +696,18 @@ pub(crate) fn swap_dims_transform(i: &I, dims: (u32, u32)) -> fn reshaped_index( inputs: &GlobalArgs, locals: &LocalArgs, - index: u32, - #[comptime] rank: u32, - #[comptime] shape: Sequence, -) -> u32 { - let mut offset = 0u32; - let mut stride_curr = 1u32; + index: usize, + #[comptime] rank: usize, + #[comptime] shape: Vec, +) -> usize { + let mut offset = 0; + let mut stride_curr = 1; #[unroll] for r in 0..rank { - let i = reverse_index(rank, r); - let arg = comptime![shape.index(i)]; - let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); + let i = reverse_index(rank, r).comptime(); + let arg = shape[i].clone(); + let shape_i = read_scalar_shape(inputs, arg); let ogwl = index / locals.ref_strides[i]; offset += ogwl % shape_i * stride_curr; @@ -726,9 +723,9 @@ fn reshaped_index( #[allow(clippy::clone_on_copy)] fn reshaped_index_to_original_index( original: &Tensor>, - index_reshaped: u32, - #[comptime] rank: u32, -) -> u32 { + index_reshaped: usize, + #[comptime] rank: usize, +) -> usize { let mut remaining = index_reshaped; let mut offset = 0; @@ -749,21 +746,19 @@ fn reshaped_index_to_original_index( #[cube] #[allow(unused_variables)] -pub(crate) fn reverse_index(#[comptime] rank: u32, iter: u32) -> comptime_type!(u32) { - intrinsic!(|_| { - let elem = iter.constant().map(|cons| cons.as_u32()).unwrap(); - rank - elem - 1 - }) +pub(crate) fn reverse_index( + #[comptime] rank: usize, + #[comptime] iter: usize, +) -> comptime_type!(usize) { + rank - iter - 1 } /// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. 
#[allow(unused_variables)] #[cube] -fn from_const_int(#[comptime] value: u32) -> C { +fn from_const_int(#[comptime] value: usize) -> C { intrinsic!(|scope| { - let constant: ExpandElement = value.into(); - let constant_c = constant.as_const().unwrap().cast_to(C::as_type(scope)); - ExpandElement::Plain(Variable::constant(constant_c)).into() + ExpandElement::Plain(Variable::constant(value.into(), C::as_type(scope))).into() }) } diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs index 318d01f93c..d2be1e33da 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs @@ -3,7 +3,6 @@ use burn_std::quantization::{QuantScheme, QuantStore, QuantValue}; use burn_std::{bf16, f16}; use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind}; use cubecl::prelude::*; -use cubecl::std::scalar::InputScalar; use serde::{Deserialize, Serialize}; use super::tensor::GlobalTensor; @@ -12,30 +11,30 @@ use super::tensor::GlobalTensor; /// Argument to a [fuse operation](FuseOp). pub enum FuseArg { /// A readonly input tensor. - Input(u32, FuseType, LayoutInfo), + Input(usize, FuseType, LayoutInfo), /// A temporary local variable. - Local(u32, FuseType), + Local(usize, FuseType), /// A readwrite output tensor. - Output(u32, FuseType, LayoutInfo), + Output(usize, FuseType, LayoutInfo), /// A global scalar. - Scalar(u32, FuseType), + Scalar(usize, FuseType), /// A global scalar used in a reshape operation. /// /// This is not a scalar defined by a user for computation, but a scalar defined as part of /// a reshape operation. - ScalarShape(u32), + ScalarShape(usize), /// Only constant that can be encoded into an u32 can be used as literal. - Literal(u32, FuseType), + Literal(usize, FuseType), /// A readonly input tensor that is reshaped. InputReshaped { original: Box, - shape: Sequence, + shape: Vec, broadcasted: bool, }, /// A readonly input tensor with swapped dimensions. InputSwapDims { original: Box, - dims: (u32, u32), + dims: (usize, usize), broadcasted: bool, }, } @@ -128,13 +127,13 @@ pub enum FuseOp { input: FuseArg, indices: FuseArg, output: FuseArg, - dim: u32, + dim: usize, }, Select { input: FuseArg, indices: FuseArg, output: FuseArg, - dim: u32, + dim: usize, }, Dequantize { values: FuseArg, @@ -197,7 +196,7 @@ impl FuseOp { pub struct GlobalArgs { pub tensors: Sequence, pub scalars: Sequence, - pub reshapes: Sequence, + pub reshapes: Sequence, } impl Default for GlobalArgsLaunch<'_, R> { @@ -238,15 +237,15 @@ impl GlobalArgsLaunch<'_, R> { RefLayout::Virtual(layout) => match layout { VirtualLayout::SwapDims(original, dims) => { let mut shape = self.shape(original); - shape.swap(dims.0 as usize, dims.1 as usize); + shape.swap(dims.0, dims.1); shape } VirtualLayout::Reshaped { reshape_pos, .. } => { - let start = *reshape_pos as usize * rank; + let start = *reshape_pos * rank; let end = start + rank; self.reshapes.values[start..end] .iter() - .map(|s| s.elem as usize) + .map(|s| s.elem) .collect() } VirtualLayout::Shape(original, _) => self.shape(original), @@ -290,7 +289,7 @@ impl GlobalArgsLaunch<'_, R> { /// # Panics /// /// If the argument doesn't have an handle. - pub fn line_size(&self, arg: &FuseArg) -> u8 { + pub fn line_size(&self, arg: &FuseArg) -> LineSize { match self.resolve_arg(arg) { TensorArg::Handle { line_size, .. } => *line_size, TensorArg::Alias { .. 
} => panic!("Unsupported yet"), @@ -304,8 +303,8 @@ impl GlobalArgsLaunch<'_, R> { /// If the argument isn't a global input or output tensor. pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg<'_, R> { match arg { - FuseArg::Input(pos, _, _) => &self.tensors.values[*pos as usize].tensor, - FuseArg::Output(pos, _, _) => &self.tensors.values[*pos as usize].tensor, + FuseArg::Input(pos, _, _) => &self.tensors.values[*pos].tensor, + FuseArg::Output(pos, _, _) => &self.tensors.values[*pos].tensor, other => panic!("Arg not found: {other:?}"), } } @@ -315,47 +314,47 @@ impl GlobalArgsLaunch<'_, R> { /// Keep track of all local variables that are used as argument in fused /// [element wise operations](ElemwiseOp). pub struct LocalArgs { - pub l_f64: Registry>, - pub l_f32: Registry>, - pub l_f16: Registry>, - pub l_bf16: Registry>, - pub l_i64: Registry>, - pub l_i32: Registry>, - pub l_i16: Registry>, - pub l_i8: Registry>, - pub l_u64: Registry>, - pub l_u32: Registry>, - pub l_u16: Registry>, - pub l_u8: Registry>, - pub l_bool: Registry>, - pub ref_shape: Slice, - pub ref_strides: Slice, + pub l_f64: Registry>, + pub l_f32: Registry>, + pub l_f16: Registry>, + pub l_bf16: Registry>, + pub l_i64: Registry>, + pub l_i32: Registry>, + pub l_i16: Registry>, + pub l_i8: Registry>, + pub l_u64: Registry>, + pub l_u32: Registry>, + pub l_u16: Registry>, + pub l_u8: Registry>, + pub l_bool: Registry>, + pub ref_shape: Slice, + pub ref_strides: Slice, #[cube(comptime)] - pub ref_line_size: u32, + pub ref_line_size: LineSize, } #[cube] impl LocalArgs { /// Creates a new [LocalArgs] container. pub fn new( - ref_shape: Slice, - ref_strides: Slice, - #[comptime] ref_line_size: u32, + ref_shape: Slice, + ref_strides: Slice, + #[comptime] ref_line_size: LineSize, ) -> LocalArgs { LocalArgs { - l_f64: Registry::>::new(), - l_f32: Registry::>::new(), - l_f16: Registry::>::new(), - l_bf16: Registry::>::new(), - l_i64: Registry::>::new(), - l_i32: Registry::>::new(), - l_i16: Registry::>::new(), - l_i8: Registry::>::new(), - l_u64: Registry::>::new(), - l_u32: Registry::>::new(), - l_u16: Registry::>::new(), - l_u8: Registry::>::new(), - l_bool: Registry::>::new(), + l_f64: Registry::>::new(), + l_f32: Registry::>::new(), + l_f16: Registry::>::new(), + l_bf16: Registry::>::new(), + l_i64: Registry::>::new(), + l_i32: Registry::>::new(), + l_i16: Registry::>::new(), + l_i8: Registry::>::new(), + l_u64: Registry::>::new(), + l_u32: Registry::>::new(), + l_u16: Registry::>::new(), + l_u8: Registry::>::new(), + l_bool: Registry::>::new(), ref_shape, ref_strides, ref_line_size, @@ -405,10 +404,10 @@ pub enum FuseType { #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Configuration that encapsulates all comptime information necessary for element wise fusion. pub struct FuseBlockConfig { - pub rank: u32, + pub rank: usize, pub ref_layout: RefLayout, - pub ops: Sequence, - pub width: u8, + pub ops: Vec, + pub width: LineSize, } #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] @@ -425,12 +424,15 @@ pub enum RefLayout { /// tensor with swap dimensions. pub enum VirtualLayout { /// Virtual tensor with the provided shape id and contiguous strides. - Reshaped { reshape_pos: u32, line_size: u32 }, + Reshaped { + reshape_pos: usize, + line_size: LineSize, + }, /// Virtual tensor with the same shape as the given input, but with swap dims and contiguous /// strides. 
- SwapDims(FuseArg, (u32, u32)), + SwapDims(FuseArg, (usize, usize)), /// Virtual tensor with the same shape as the given input, but with contiguous strides. - Shape(FuseArg, u32), + Shape(FuseArg, usize), } impl FuseArg { diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs b/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs index 7724083d13..e0cba172de 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs @@ -26,16 +26,16 @@ pub fn fuse_on_write( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, write_values: Registry>, - #[comptime] write_args: Sequence, + #[comptime] write_args: Vec, #[comptime] config: &FuseBlockConfig, ) { // Write the values given as arguments. #[unroll] - for i in 0..write_args.len() { - let arg = comptime![write_args.index(i).clone()]; - let val = write_values.find(comptime![arg.clone()]); + for i in 0..write_args.len() { + let arg = comptime![write_args[i].clone()]; + let val = write_values.find(arg.clone()); write::(inputs, outputs, locals, write_pos, val, arg, config); } @@ -62,7 +62,7 @@ pub fn fuse_on_read( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - read_pos: u32, + read_pos: usize, #[comptime] read_args: Sequence, #[comptime] config: &FuseBlockConfig, ) -> Sequence> { @@ -76,11 +76,11 @@ pub fn fuse_on_read( let value = read::(inputs, outputs, locals, read_pos, arg, config); let value_line_size = value.line_size(); - let output_line_size = comptime!(config.width as u32); + let output_line_size = config.width; // We currently don't support broadcasting __across__ blocks. if comptime!(value_line_size != output_line_size) { - let mut tmp = Line::::empty(comptime!(config.width as u32)); + let mut tmp = Line::::empty(config.width); comptime!( assert_eq!(value_line_size, 1, "The input line_size must be 1 or the same as the config width."); ); @@ -88,7 +88,7 @@ pub fn fuse_on_read( let val = value[0]; #[unroll] - for i in 0..comptime!(config.width as u32) { + for i in 0..config.width { tmp[i] = val; } @@ -116,7 +116,7 @@ pub fn init_locals( let mut ref_shape = Array::new(config.rank); let mut ref_strides = Array::new(config.rank); - match comptime![config.ref_layout.clone()] { + match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, ..) => { let layout = inputs.tensors.index(index); @@ -152,23 +152,23 @@ pub fn init_locals( }, RefLayout::Virtual(layout) => match layout { VirtualLayout::SwapDims(original, dims) => { - let layout = match comptime![original.clone()] { + let layout = match original.clone() { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), FuseArg::Output(pos, ..)
=> outputs.tensors.index(pos), _ => comptime![panic!("Unsupported")], }; - let mut stride_curr = 1u32; + let mut stride_curr = 1; #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { let reverse = reverse_index(config.rank, i); - let swap = comptime![swap_dims_transform(comptime![&reverse], dims)]; - let shape = layout.tensor.shape(comptime![swap.clone()]); + let swap = comptime![swap_dims_transform(reverse, dims)]; + let shape = layout.tensor.shape(swap.clone()); - ref_shape[comptime![reverse.clone()]] = shape; - ref_strides[comptime![reverse.clone()]] = stride_curr; + ref_shape[reverse] = shape; + ref_strides[reverse] = stride_curr; stride_curr *= ref_shape[comptime![reverse]]; } @@ -183,7 +183,7 @@ pub fn init_locals( reshape_pos, line_size, } => { - let mut stride_curr = 1u32; + let mut stride_curr = 1; let start = reshape_pos * config.rank; #[unroll] @@ -191,7 +191,7 @@ pub fn init_locals( for i in 0..config.rank { let reverse = reverse_index(config.rank, i); let arg = comptime![FuseArg::ScalarShape(start + reverse)]; - let shape = read_scalar_shape(inputs, comptime![arg.clone()]); + let shape = read_scalar_shape(inputs, arg.clone()); ref_shape[comptime![reverse]] = shape; ref_strides[comptime![reverse]] = stride_curr; @@ -202,12 +202,12 @@ pub fn init_locals( LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size) } VirtualLayout::Shape(original, line_size) => { - let layout = match comptime![original.clone()] { + let layout = match original.clone() { FuseArg::Input(pos, ..) => inputs.tensors.index(pos), FuseArg::Output(pos, ..) => outputs.tensors.index(pos), _ => comptime![panic!("Unsupported")], }; - let mut stride_curr = 1u32; + let mut stride_curr = 1; #[unroll] #[allow(clippy::clone_on_copy)] @@ -233,13 +233,13 @@ fn fuse( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - pos: u32, + pos: usize, #[comptime] config: &FuseBlockConfig, ) { #[unroll] for index in 0..config.ops.len() { - let op = comptime! { config.ops.index(index).clone() }; - set_polyfill::>(comptime![op.cmp_type()]); + let op = config.ops[index].clone(); + set_polyfill::>(op.cmp_type()); match op { FuseOp::Add(op) => { @@ -367,7 +367,7 @@ macro_rules! binary_op { inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { @@ -387,7 +387,7 @@ macro_rules! binary_func { inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { @@ -407,7 +407,7 @@ macro_rules! comparison_op { inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] op: BinaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { @@ -427,7 +427,7 @@ macro_rules! 
unary_func { inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] op: UnaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { @@ -444,7 +444,7 @@ fn assign( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] op: UnaryFuseArgs, #[comptime] config: &FuseBlockConfig, ) { @@ -458,8 +458,8 @@ fn gather( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, - #[comptime] dim: u32, + write_pos: usize, + #[comptime] dim: usize, #[comptime] input: FuseArg, #[comptime] indices: FuseArg, #[comptime] output: FuseArg, @@ -482,7 +482,7 @@ fn gather( let stride_input_dim = global_stride(inputs, dim, pos_input); - let mut index = 0u32; + let mut index = 0; let mut result = Line::empty(line_size); if comptime![dim > 0] { @@ -491,8 +491,8 @@ fn gather( outputs, locals, write_pos, - comptime!(input.clone()), - comptime![Some((0u32, dim))], + input.clone(), + comptime![Some((0, dim))], config, ); index += index_before; @@ -517,7 +517,7 @@ fn gather( locals, write_pos, indices, - comptime![Some((0u32, config.rank))], + comptime![Some((0, config.rank))], config, ); @@ -539,7 +539,7 @@ fn gather( inputs, locals, pos_input, - index + (offset[0] * stride_input_dim), + index + (offset[0] as usize * stride_input_dim), LayoutInfo::IsRef, config, None, @@ -549,7 +549,7 @@ fn gather( } } else { // Shared index for whole line - let stride_input_line = global_stride(inputs, comptime!(config.rank - 1), pos_input); + let stride_input_line = global_stride(inputs, config.rank - 1, pos_input); let offset = read_input::( inputs, @@ -561,7 +561,7 @@ fn gather( None, ); - index += offset[0] * stride_input_dim; + index += offset[0] as usize * stride_input_dim; #[unroll] for i in 0..line_size { @@ -587,8 +587,8 @@ fn select_indices( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, - #[comptime] dim: u32, + write_pos: usize, + #[comptime] dim: usize, #[comptime] input: FuseArg, #[comptime] indices: FuseArg, #[comptime] output: FuseArg, @@ -613,7 +613,7 @@ fn select_indices( let stride_input_dim = global_stride(inputs, dim, pos_input); - let mut index = 0u32; + let mut index = 0; let mut result = Line::empty(line_size_ref); if comptime![dim != config.rank - 1] { @@ -627,8 +627,8 @@ fn select_indices( outputs, locals, write_pos, - comptime!(input.clone()), - comptime![Some((0u32, dim))], + input.clone(), + comptime![Some((0, dim))], config, ); index += index_before; @@ -640,7 +640,7 @@ fn select_indices( outputs, locals, write_pos, - comptime!(input.clone()), + input.clone(), comptime![Some((dim + 1, config.rank))], config, ); @@ -660,7 +660,7 @@ fn select_indices( None, ); - index += offset_dim[0] * stride_input_dim; + index += offset_dim[0] as usize * stride_input_dim; #[unroll] for i in 0..line_size_ref { @@ -687,8 +687,8 @@ fn select_indices( outputs, locals, write_pos, - comptime!(input.clone()), - comptime![Some((0u32, dim))], + input.clone(), + comptime![Some((0, dim))], config, ); index += index_before; @@ -726,7 +726,7 @@ fn select_indices( inputs, locals, pos_input, - index + (offset_dim[0] * stride_input_dim), + index + (offset_dim[0] as usize * stride_input_dim), LayoutInfo::IsRef, config, None, @@ -743,7 +743,7 @@ fn conditional_assign( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] cond: FuseArg, #[comptime] lhs: FuseArg, #[comptime] 
rhs: FuseArg, @@ -763,7 +763,7 @@ fn clamp( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] input: FuseArg, #[comptime] min: FuseArg, #[comptime] max: FuseArg, @@ -784,7 +784,7 @@ fn dequantize( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - write_pos: u32, + write_pos: usize, #[comptime] input: FuseArg, #[comptime] scales: FuseArg, #[comptime] output: FuseArg, @@ -834,7 +834,7 @@ fn dequantize( ); let line_size = input.line_size(); - let num_quants = comptime!(scheme.num_quants() as u32); + let num_quants = scheme.num_quants(); let scales = input_as_scales_view::>( inputs, @@ -856,22 +856,15 @@ fn dequantize( } else { let mut line = Line::empty(line_size_result); - // We have to do all index work as comptime because higher line sizes removes the - // possibility to index dynamically on lines. - let mut i = comptime!(0); - #[unroll] - for _ in 0..line_size { - let mut j = comptime!(0); + for i in 0..line_size { let value = result[i]; #[unroll] - for _ in 0..num_quants { - let index = comptime!(i * num_quants + j); + for j in 0..num_quants { + let index = i * num_quants + j; line[index] = value[j]; - comptime!(j += 1); } - comptime!(i += 1); } line diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/view.rs b/crates/burn-cubecl-fusion/src/engine/codegen/view.rs index 4fc2143c8a..4d6da2c31a 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/view.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/view.rs @@ -24,7 +24,7 @@ pub struct GlobalInput { inputs: GlobalArgs, locals: LocalArgs, #[cube(comptime)] - pos: u32, + pos: usize, #[cube(comptime)] ty: StorageType, #[cube(comptime)] @@ -67,7 +67,7 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_read_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, ) -> ::ExpandType { ViewOperationsExpand::::__expand_read_unchecked_method(self, scope, pos) } @@ -76,7 +76,7 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_read_checked_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, ) -> ::ExpandType { let zero = E::__expand_cast_from(scope, 0.into()); ViewOperationsExpand::::__expand_read_masked_method(self, scope, pos, zero) @@ -86,7 +86,7 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_read_masked_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, value: ::ExpandType, ) -> ::ExpandType { let in_bounds = ViewOperationsExpand::::__expand_is_in_bounds_method( @@ -103,7 +103,7 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_read_unchecked_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, ) -> ::ExpandType { let value = read_input::expand::( scope, @@ -122,8 +122,8 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_to_linear_slice_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, - end: ExpandElementTyped, + pos: ExpandElementTyped, + end: ExpandElementTyped, ) -> SliceExpand { scope.register_type::>(self.ty); let end = add::expand(scope, end.clone(), 1.into()); @@ -136,13 +136,13 @@ impl ViewOperationsExpand for GlobalInputExpand { _scope: &mut Scope, _barrier: BarrierExpand, _shared_memory: SliceExpand, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) { panic!("Not a tensor map") } #[allow(clippy::too_many_arguments)] - fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped { + fn 
__expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped { global_buffer_len::expand(scope, self.inputs.clone(), self.pos) } @@ -150,7 +150,7 @@ impl ViewOperationsExpand for GlobalInputExpand { fn __expand_is_in_bounds_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, ) -> ExpandElementTyped { let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos); lt::expand(scope, pos, buffer_len) @@ -159,7 +159,7 @@ impl ViewOperationsExpand for GlobalInputExpand { impl Lined for GlobalInput {} impl LinedExpand for GlobalInputExpand { - fn line_size(&self) -> u32 { + fn line_size(&self) -> LineSize { let mut temp_scope = Scope::root(false); global_line_size::expand(&mut temp_scope, self.inputs.clone(), self.pos) } @@ -201,7 +201,7 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_read_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) -> as CubeType>::ExpandType { todo!() } @@ -210,7 +210,7 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_read_checked_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) -> as CubeType>::ExpandType { todo!() } @@ -219,7 +219,7 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_read_masked_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, _value: as CubeType>::ExpandType, ) -> as CubeType>::ExpandType { todo!() @@ -229,7 +229,7 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_read_unchecked_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) -> as CubeType>::ExpandType { todo!() } @@ -238,8 +238,8 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_to_linear_slice_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, - _size: ExpandElementTyped, + _pos: ExpandElementTyped, + _size: ExpandElementTyped, ) -> SliceExpand, ReadOnly> { todo!() } @@ -250,13 +250,13 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx _scope: &mut Scope, _barrier: BarrierExpand, _shared_memory: SliceExpand, ReadWrite>, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) { panic!("Not a tensor map") } #[allow(clippy::too_many_arguments)] - fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped { + fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped { ref_len::expand( scope, self.inputs.clone(), @@ -270,7 +270,7 @@ impl ViewOperationsExpand, Coords1d> for FusedOutputEx fn __expand_is_in_bounds_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, ) -> ExpandElementTyped { let buffer_len = ref_buffer_len::expand( scope, @@ -289,11 +289,11 @@ impl ViewOperationsMutExpand, Coords1d> for FusedOutpu fn __expand_write_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, value: as CubeType>::ExpandType, ) { let values = Registry::>::__expand_new(scope); - let mut args = comptime![Sequence::::new()]; + let mut args = comptime![Vec::::new()]; values .clone() @@ -316,7 +316,7 @@ impl ViewOperationsMutExpand, Coords1d> for FusedOutpu fn __expand_write_checked_method( &self, scope: &mut Scope, - pos: ExpandElementTyped, + pos: ExpandElementTyped, value: as CubeType>::ExpandType, ) { let in_bounds = ViewOperationsExpand::, Coords1d>::__expand_is_in_bounds_method( @@ -333,8 +333,8 @@ impl ViewOperationsMutExpand, Coords1d> for FusedOutpu fn 
__expand_to_linear_slice_mut_method( &self, _scope: &mut Scope, - _pos: ExpandElementTyped, - _size: ExpandElementTyped, + _pos: ExpandElementTyped, + _size: ExpandElementTyped, ) -> SliceExpand, ReadWrite> { todo!("Not yet supported") } @@ -344,7 +344,7 @@ impl ViewOperationsMutExpand, Coords1d> for FusedOutpu &self, _scope: &mut Scope, _shared_memory: SliceExpand, ReadOnly>, - _pos: ExpandElementTyped, + _pos: ExpandElementTyped, ) { panic!("Not a tensor map") } @@ -352,7 +352,7 @@ impl ViewOperationsMutExpand, Coords1d> for FusedOutpu impl Lined for FusedOutput {} impl LinedExpand for FusedOutputExpand { - fn line_size(&self) -> u32 { + fn line_size(&self) -> LineSize { self.locals.ref_line_size } } diff --git a/crates/burn-cubecl-fusion/src/engine/fuser.rs b/crates/burn-cubecl-fusion/src/engine/fuser.rs index 1e18d1f2f1..2f117dde30 100644 --- a/crates/burn-cubecl-fusion/src/engine/fuser.rs +++ b/crates/burn-cubecl-fusion/src/engine/fuser.rs @@ -276,11 +276,7 @@ impl TraceOperationFuser { } if self.fuser.fuse(|fuser| { - fuser.input_swap_dims( - &desc.input, - &desc.out, - (desc.dim1 as u32, desc.dim2 as u32), - )?; + fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?; Some(()) }) { @@ -364,7 +360,7 @@ impl TraceOperationFuser { input, indices, output, - dim: desc.dim as u32, + dim: desc.dim, }); Some(()) @@ -384,7 +380,7 @@ impl TraceOperationFuser { input, indices, output, - dim: desc.dim as u32, + dim: desc.dim, }); Some(()) diff --git a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs index 850f1cdace..c7d30e922a 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs @@ -16,8 +16,7 @@ use burn_std::DType; use cubecl::{ CubeElement, Runtime, client::ComputeClient, - prelude::{ScalarArg, Sequence, TensorArg}, - std::scalar::InputScalar, + prelude::{InputScalar, ScalarArg, TensorArg}, }; use std::marker::PhantomData; @@ -87,15 +86,15 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { let reference = match block_plan.reference { ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout), ReferenceSelection::VirtualShape { original, .. 
} => { - RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width as u32)) + RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width)) } ReferenceSelection::SwapDims { original, dims } => { RefLayout::Virtual(VirtualLayout::SwapDims(original, dims)) } ReferenceSelection::Reshaped { reshape_pos } => { RefLayout::Virtual(VirtualLayout::Reshaped { - reshape_pos: reshape_pos as u32, - line_size: block_plan.width as u32, + reshape_pos, + line_size: block_plan.width, }) } ReferenceSelection::NotFound | ReferenceSelection::Searching => { @@ -107,7 +106,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { } }; - let mut ops = Sequence::::new(); + let mut ops = Vec::::new(); for read_ops in block_plan.reads.into_values() { for op in read_ops { @@ -124,7 +123,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { } let config = FuseBlockConfig { - rank: plan.rank as u32, + rank: plan.rank, ref_layout: reference, ops, width: block_plan.width, @@ -153,7 +152,7 @@ fn register_inputs<'h, R: Runtime>( HandleInput::Normal(hi) => { let arg = hi .handle - .as_tensor_arg(&hi.global_ir.shape.dims, hi.vectorization); + .as_tensor_arg(&hi.global_ir.shape.dims, hi.line_size); inputs.tensors.push(GlobalTensorArg::new( arg, hi.precision.into_elem(), @@ -163,7 +162,7 @@ fn register_inputs<'h, R: Runtime>( HandleInput::QuantValues(hi) => { let arg = hi .handle - .as_tensor_arg(&hi.global_ir.shape.dims, hi.vectorization); + .as_tensor_arg(&hi.global_ir.shape.dims, hi.line_size); inputs .tensors .push(GlobalTensorArg::new(arg, hi.precision.into_elem(), false)); @@ -209,12 +208,12 @@ fn register_outputs<'s, BT: CubeElement, R: Runtime>( precision, handle, global_shape, - vectorization, + vectorization: line_size, #[cfg(feature = "autotune-checks")] relative_id, .. 
} => { - let arg = handle.as_tensor_arg(global_shape, *vectorization); + let arg = handle.as_tensor_arg(global_shape, *line_size); let elem = match precision { FuseType::Bool => match elem_dtype::() { @@ -269,7 +268,7 @@ fn register_scalars<'h, R: Runtime>( let global = context.tensors.get(reshaped).unwrap(); for shape in global.shape.iter() { - inputs.reshapes.push(ScalarArg::new(*shape as u32)); + inputs.reshapes.push(ScalarArg::new(*shape)); } } } diff --git a/crates/burn-cubecl-fusion/src/engine/launch/input.rs b/crates/burn-cubecl-fusion/src/engine/launch/input.rs index f712d95f6f..549e554ed6 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/input.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/input.rs @@ -84,7 +84,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { global_ir: tensor_global, precision, handle, - vectorization: 1, + line_size: 1, })); plan.handle_inputs @@ -208,7 +208,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { && shape_relative == &block.shape_ref { block_plan.potential_reference_input = Some(InputReference::Reshaped { - reshape_pos: *reshape_pos as usize, + reshape_pos: *reshape_pos, }); } return true; @@ -225,11 +225,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { } if original == &tensor_relative.id { - let shape = tensor_relative - .shape - .clone() - .swap(dims.0 as usize, dims.1 as usize) - .unwrap(); + let shape = tensor_relative.shape.clone().swap(dims.0, dims.1).unwrap(); if block_plan.potential_reference_input.is_none() && shape.dims == block.shape_ref diff --git a/crates/burn-cubecl-fusion/src/engine/launch/output.rs b/crates/burn-cubecl-fusion/src/engine/launch/output.rs index deefeaa1a5..312c4a716e 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/output.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/output.rs @@ -198,7 +198,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { let set_ref_as_concrete = |block: &mut BlockPlan<'_>| { block.reference = ReferenceSelection::Concrete { layout: FuseArg::Input( - input_pos as u32, + input_pos, reference.precision, LayoutInfo::IsRef, ), @@ -210,7 +210,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { let set_ref_as_virtual = |block: &mut BlockPlan<'_>| { block.reference = ReferenceSelection::VirtualShape { original: FuseArg::Input( - input_pos as u32, + input_pos, reference.precision, LayoutInfo::Unknown, ), @@ -243,7 +243,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { .expect("Quant can't be used in swap dims operation"); block.reference = ReferenceSelection::SwapDims { original: FuseArg::Input( - original_pos as u32, + original_pos, reference.precision, LayoutInfo::Unknown, ), @@ -407,11 +407,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { && self.blocks[block_idx].shape_ref == output.tensor_relative.shape.dims { block.reference = ReferenceSelection::Concrete { - layout: FuseArg::Output( - output.pos_original as u32, - output.precision, - LayoutInfo::IsRef, - ), + layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef), shape: tensor_global.shape.dims.clone(), strides: strides.clone(), }; @@ -553,7 +549,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { output: OutputSorted, tensor_global: TensorIr, original: TensorId, - dims: (u32, u32), + dims: (usize, usize), block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; @@ -578,7 +574,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { dtype, qparams: original_handle.handle.qparams.clone(), }; - handle.strides.swap(dims.0 as usize, dims.1 as usize); + handle.strides.swap(dims.0, dims.1); context .handles diff 
--git a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs index aae753e130..7f359be83b 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs @@ -7,7 +7,7 @@ use crate::{ }, }; use burn_ir::{TensorId, TensorIr}; -use cubecl::Runtime; +use cubecl::{Runtime, ir::LineSize}; use std::collections::BTreeMap; /// The plan is responsible to keep runtime information related to the launch of a fused kernel @@ -30,7 +30,7 @@ pub struct BlockPlan<'a> { pub reference: ReferenceSelection, pub reads: BTreeMap>, pub writes: BTreeMap, - pub width: u8, + pub width: LineSize, } #[derive(Debug)] @@ -40,7 +40,7 @@ pub enum InputReference { }, SwapDims { original_pos: usize, - dims: (u32, u32), + dims: (usize, usize), }, Reshaped { reshape_pos: usize, @@ -59,7 +59,7 @@ pub enum ReferenceSelection { }, SwapDims { original: FuseArg, - dims: (u32, u32), + dims: (usize, usize), }, Reshaped { reshape_pos: usize, @@ -137,7 +137,7 @@ pub enum HandleOutput { precision: FuseType, handle: CubeFusionHandle, global_shape: Vec, - vectorization: u8, + vectorization: LineSize, }, } @@ -147,7 +147,7 @@ pub struct NormalHandleInput { pub global_ir: TensorIr, pub precision: FuseType, pub handle: CubeFusionHandle, - pub vectorization: u8, + pub line_size: LineSize, pub broadcated: bool, // Strides can be modified during plan execution, but need to be restored on rollback pub orig_strides: Vec, @@ -159,7 +159,7 @@ pub struct QuantValuesHandleInput { pub global_ir: TensorIr, pub precision: FuseType, pub handle: CubeFusionHandle, - pub vectorization: u8, + pub line_size: LineSize, } #[derive(Debug)] @@ -200,7 +200,7 @@ impl NormalHandleInput { handle, relative_id: tensor_relative.id, global_ir: tensor_global, - vectorization: 1, + line_size: 1, broadcated: false, orig_strides: strides, } diff --git a/crates/burn-cubecl-fusion/src/engine/launch/runner.rs b/crates/burn-cubecl-fusion/src/engine/launch/runner.rs index 0f53dd8970..e08d243ef1 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/runner.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/runner.rs @@ -76,9 +76,9 @@ pub trait Vectorization { inputs: impl Iterator>, outputs: impl Iterator, reshaped: impl Iterator, - swapped: impl Iterator, - line_sizes: &[u8], - max: u8, + swapped: impl Iterator, + line_sizes: &[LineSize], + max: LineSize, axis: VectorizationAxis, ) { vectorization_default( diff --git a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs index da33e5598e..c0a45d949a 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs @@ -4,18 +4,18 @@ use crate::{ }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; -use cubecl::Runtime; +use cubecl::{Runtime, ir::LineSize}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; #[derive(Debug, Clone, Copy)] pub enum Vect { Broadcasted, - Aligned(u8), + Aligned(LineSize), } impl Vect { - pub fn line_size(&self) -> u8 { + pub fn line_size(&self) -> LineSize { match self { Vect::Broadcasted => 1, Vect::Aligned(val) => *val, @@ -29,13 +29,13 @@ impl Vect { #[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct LineSizeOverrides { - state: Option>>, - default: Option>, + state: Option>>, + default: Option>, } #[allow(unused)] impl LineSizeOverrides { - pub fn overrides(&mut self, 
tensor_id: &TensorId, line_sizes: Vec) { + pub fn overrides(&mut self, tensor_id: &TensorId, line_sizes: Vec) { let map = match &mut self.state { Some(val) => val, None => { @@ -46,7 +46,7 @@ impl LineSizeOverrides { map.insert(*tensor_id, line_sizes); } - pub fn overrides_default(&mut self, line_sizes: Vec) { + pub fn overrides_default(&mut self, line_sizes: Vec) { self.default = Some(line_sizes); } @@ -72,7 +72,7 @@ impl LineSizeOverrides { } } - pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec> { + pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec> { let map = match &self.state { Some(val) => val, None => match &self.default { @@ -97,10 +97,10 @@ pub(crate) fn vectorization_default<'a, R: Runtime>( inputs: impl Iterator>, outputs: impl Iterator, reshaped: impl Iterator, - swapped: impl Iterator, - line_sizes: &[u8], + swapped: impl Iterator, + line_sizes: &[LineSize], overrides: &LineSizeOverrides, - max: u8, + max: LineSize, axis: &VectorizationAxis, ) { let swapped: Vec<_> = swapped.collect(); @@ -148,7 +148,7 @@ pub(crate) fn vectorization_default<'a, R: Runtime>( overrides.tensor(&tensor_ir.id), ); let num_quants = match tensor_ir.dtype { - burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants() as u8, + burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(), _ => panic!(""), }; let val = match val { @@ -214,8 +214,8 @@ fn vectorization_input( handle: &CubeFusionHandle, desc: &TensorIr, axis: &VectorizationAxis, - line_sizes: &[u8], - overrides: Option<&Vec>, + line_sizes: &[LineSize], + overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(desc.id, || handle.strides.len() - 1); let shape_axis = desc.shape[axis]; @@ -229,9 +229,9 @@ fn vectorization_input( return Vect::Aligned(1); } - let inner = |s: u8| { + let inner = |s: LineSize| { // The last dimension should be a multiple of the vector size or broadcated. - if shape_axis.is_multiple_of(s as usize) { + if shape_axis.is_multiple_of(s) { return Some(Vect::Aligned(s)); } None @@ -260,15 +260,15 @@ fn vectorization_input( fn vectorization_output( desc: &TensorIr, axis: &VectorizationAxis, - line_sizes: &[u8], - max: u8, - overrides: Option<&Vec>, + line_sizes: &[LineSize], + max: LineSize, + overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(desc.id, || desc.shape.rank() - 1); - let inner = |s: u8| { + let inner = |s: LineSize| { // The dimension should be a multiple of the vector size. - if desc.shape[axis].is_multiple_of(s as usize) && s <= max { + if desc.shape[axis].is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } @@ -299,9 +299,9 @@ fn vectorization_reshape( original: &TensorIr, multi_reads: bool, axis: &VectorizationAxis, - line_sizes: &[u8], - max: u8, - overrides: Option<&Vec>, + line_sizes: &[LineSize], + max: LineSize, + overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1); let reshape_shape_axis = reshaped.shape[axis]; @@ -321,10 +321,10 @@ fn vectorization_reshape( return Vect::Aligned(1); } - let inner = |s: u8| { + let inner = |s: LineSize| { if !multi_reads { // The last dimension should be a multiple of the vector size or broadcated. - if reshape_shape_axis.is_multiple_of(s as usize) && s <= max { + if reshape_shape_axis.is_multiple_of(s) && s <= max { Some(Vect::Aligned(s)) } else { None @@ -333,8 +333,8 @@ fn vectorization_reshape( // Since the original tensor must share the same vectorization factor as the // reshaped tensor, they must have compatible shapes when both are access // independently. 
- if reshape_shape_axis.is_multiple_of(s as usize) - && original_shape_axis.is_multiple_of(s as usize) + if reshape_shape_axis.is_multiple_of(s) + && original_shape_axis.is_multiple_of(s) && s <= max { Some(Vect::Aligned(s)) @@ -370,11 +370,11 @@ fn vectorization_swapped( swapped: &TensorIr, original: &TensorIr, multi_reads: bool, - dims: &(u32, u32), - max: u8, + dims: &(usize, usize), + max: LineSize, axis: &VectorizationAxis, - line_sizes: &[u8], - overrides: Option<&Vec>, + line_sizes: &[LineSize], + overrides: Option<&Vec>, ) -> Vect { let axis = axis.get(swapped.id, || swapped.shape.rank() - 1); @@ -382,10 +382,10 @@ fn vectorization_swapped( let shape_axis = original.shape[axis]; let axis_index = axis; - let dim_index = if dims.0 as usize == axis_index { - dims.1 as usize - } else if dims.1 as usize == axis_index { - dims.0 as usize + let dim_index = if dims.0 == axis_index { + dims.1 + } else if dims.1 == axis_index { + dims.0 } else { axis_index }; @@ -406,16 +406,13 @@ fn vectorization_swapped( return Vect::Broadcasted; } - let inner = |s: u8| { + let inner = |s: LineSize| { // The last dimension should be a multiple of the vector size or broadcated. if multi_reads { - if swapped_axis.is_multiple_of(s as usize) && s <= max { + if swapped_axis.is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } - } else if swapped_axis.is_multiple_of(s as usize) - && shape_axis.is_multiple_of(s as usize) - && s <= max - { + } else if swapped_axis.is_multiple_of(s) && shape_axis.is_multiple_of(s) && s <= max { return Some(Vect::Aligned(s)); } None diff --git a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs index 6e9ebe120c..8d4476cf75 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs @@ -15,12 +15,15 @@ use crate::{ }; use burn_fusion::stream::Context; use burn_ir::TensorId; -use cubecl::quant::scheme::{QuantScheme, QuantStore, QuantValue}; use cubecl::{ Runtime, client::ComputeClient, ir::{ElemType, StorageType, UIntKind}, }; +use cubecl::{ + ir::LineSize, + quant::scheme::{QuantScheme, QuantStore, QuantValue}, +}; use std::marker::PhantomData; /// Select the best vectorization factor for each tensor handle. 
@@ -78,7 +81,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { }); let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8); - let mut quants_line_sizes: Option> = None; + let mut quants_line_sizes: Option> = None; for input in plan.handle_inputs.iter() { let elem: StorageType = match input { @@ -124,7 +127,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { Some(line_sizes) => line_sizes, None => client .io_optimized_line_sizes_unchecked(ref_elem.0.size()) - .collect::>(), + .collect::>(), }; let vectorization_axis = runner.axis(plan); @@ -154,7 +157,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { tensors_reshaped, tensors_swapped, &line_sizes, - u8::MAX, + u8::MAX as usize, vectorization_axis, ); @@ -228,13 +231,13 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { match plan.vectorizations.get(&input_global.id).unwrap() { Vect::Aligned(vect) => { - let handle = &mut plan.handle_inputs[pos as usize]; + let handle = &mut plan.handle_inputs[pos]; match handle { HandleInput::Normal(handle) => { - handle.vectorization = *vect; + handle.line_size = *vect; } HandleInput::QuantValues(handle) => { - handle.vectorization = *vect; + handle.line_size = *vect; } HandleInput::QuantParams(_) => {} } @@ -255,7 +258,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { &mut plan.handle_inputs, &mut plan.handle_outputs, block_plan, - u8::MAX, + u8::MAX as usize, ); } VectorizationSetting::SmallerOrEqualThanPreviousBlock => { @@ -291,7 +294,7 @@ enum VectorizationAction { #[derive(Debug)] struct BlockVectorization { action: VectorizationAction, - potential: u8, + potential: LineSize, broadcasted: bool, } @@ -300,7 +303,7 @@ fn apply_vectorization_block( inputs: &mut [HandleInput], outputs: &mut [HandleOutput], block_plan: &mut BlockPlan, - max: u8, + max: LineSize, ) { for item in block_vectorization { match item.action { @@ -313,11 +316,11 @@ fn apply_vectorization_block( match &mut inputs[pos] { HandleInput::Normal(input) => { - input.vectorization = vect; + input.line_size = vect; input.broadcated = br; } HandleInput::QuantValues(input) => { - input.vectorization = vect; + input.line_size = vect; } HandleInput::QuantParams(_) => { // Not vectorized @@ -348,7 +351,7 @@ fn apply_vectorization_block( fn line_sizes_quants( client: &ComputeClient, - quants_line_sizes: &mut Option>, + quants_line_sizes: &mut Option>, scheme: QuantScheme, ) { match scheme.store { @@ -361,7 +364,7 @@ fn line_sizes_quants( | QuantValue::E2M1 => { let line_sizes = client .io_optimized_line_sizes_unchecked(size_of::()) - .collect::>(); + .collect::>(); match &quants_line_sizes { Some(sizes) => { @@ -381,9 +384,9 @@ fn line_sizes_quants( QuantStore::U32 => { let mut line_sizes = client .io_optimized_line_sizes_unchecked(size_of::()) - .collect::>(); + .collect::>(); for val in line_sizes.iter_mut() { - *val *= scheme.num_quants() as u8; + *val *= scheme.num_quants(); } match &quants_line_sizes { diff --git a/crates/burn-cubecl-fusion/src/engine/trace/base.rs b/crates/burn-cubecl-fusion/src/engine/trace/base.rs index da0fbe11d0..573c48cbdf 100644 --- a/crates/burn-cubecl-fusion/src/engine/trace/base.rs +++ b/crates/burn-cubecl-fusion/src/engine/trace/base.rs @@ -169,13 +169,13 @@ pub enum TensorView { Reshape { reshaped: TensorId, original: TensorId, - reshape_pos: u32, + reshape_pos: usize, shape_relative: Vec, }, SwapDims { swapped: TensorId, original: TensorId, - dims: (u32, u32), + dims: (usize, usize), }, } @@ -227,7 +227,7 @@ impl RegisteredTensors { } /// Doesn't return quantized tensor. 
- pub fn get_index(&self, tensor_id: TensorId) -> Option { + pub fn get_index(&self, tensor_id: TensorId) -> Option { self.tensors .iter() .enumerate() @@ -236,11 +236,11 @@ impl RegisteredTensors { RegisterTensor::QuantValues(_) => false, RegisterTensor::QuantParams(_) => false, }) - .map(|(pos, _)| pos as u32) + .map(|(pos, _)| pos) } /// Get the index of a quantized tensor. - pub fn get_index_quant(&self, tensor_id: TensorId) -> Option { + pub fn get_index_quant(&self, tensor_id: TensorId) -> Option { self.tensors .iter() .enumerate() @@ -249,7 +249,7 @@ impl RegisteredTensors { RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id, RegisterTensor::QuantParams(_) => false, }) - .map(|(pos, _)| pos as u32) + .map(|(pos, _)| pos) } /// Doesn't return quantized tensor. @@ -273,12 +273,12 @@ impl RegisteredTensors { /// Insert a quantized tensor. /// /// It will return the positions for both the value tensor and param tensor. - pub fn insert_quant(&mut self, tensor: TensorIr) -> (u32, u32) { + pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) { if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val { RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor, _ => false, }) { - let values = old.0 as u32; + let values = old.0; let params = values + 1; return (values, params); } @@ -291,16 +291,16 @@ impl RegisteredTensors { let pos_params = self.len(); self.tensors.push(params); - (pos_values as u32, pos_params as u32) + (pos_values, pos_params) } /// Insert a normal tensor with the given [precision](FusePrecision) in the current block. - pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> u32 { + pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize { if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val { RegisterTensor::Normal(tensor_ir, _) => tensor_ir == &tensor, _ => false, }) { - return old.0 as u32; + return old.0; } let value = RegisterTensor::Normal(tensor, precision); @@ -308,7 +308,7 @@ impl RegisteredTensors { self.tensors.push(value); - pos as u32 + pos } /// Update the already registered tensor with the given [tensor ir](TensorIr). 
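The `RegisteredTensors` hunks above show the cast-removal pattern that recurs throughout this patch: `Iterator::enumerate()` already yields `usize`, so returning it directly drops the `as u32` narrowing here and the matching `as usize` widening at call sites. A minimal sketch of the before/after, using hypothetical stand-in types (`Entry`, `Registry`) rather than the real `RegisteredTensors`/`TensorIr`:

    // Stand-in types for illustration only; not the actual fusion structs.
    struct Entry { id: u64 }
    struct Registry { entries: Vec<Entry> }

    impl Registry {
        // Before: narrow the position to u32, forcing casts back to usize
        // wherever the result is used to index a Vec or slice.
        fn get_index_old(&self, id: u64) -> Option<u32> {
            self.entries
                .iter()
                .enumerate()
                .find(|(_, e)| e.id == id)
                .map(|(pos, _)| pos as u32) // silently truncates past u32::MAX
        }

        // After: keep the native usize from enumerate(), so call sites can
        // write `self.entries[pos]` with no cast and no truncation risk.
        fn get_index(&self, id: u64) -> Option<usize> {
            self.entries
                .iter()
                .enumerate()
                .find(|(_, e)| e.id == id)
                .map(|(pos, _)| pos)
        }
    }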
diff --git a/crates/burn-cubecl-fusion/src/engine/trace/block.rs b/crates/burn-cubecl-fusion/src/engine/trace/block.rs index 6cd295dfd1..764a50810d 100644 --- a/crates/burn-cubecl-fusion/src/engine/trace/block.rs +++ b/crates/burn-cubecl-fusion/src/engine/trace/block.rs @@ -4,7 +4,6 @@ use crate::engine::{ }; use burn_ir::{TensorId, TensorIr, TensorStatus}; use burn_std::{DType, quantization::QuantParam}; -use cubecl::prelude::Sequence; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, btree_map::Entry}; @@ -210,7 +209,7 @@ impl FuseBlockBuilder { &mut self, tensor: &TensorIr, output: &TensorIr, - dims: (u32, u32), + dims: (usize, usize), resources: &mut FuseResources, ) -> Option { if matches!(tensor.dtype, DType::QFloat(..)) { @@ -317,7 +316,7 @@ impl FuseBlockBuilder { let out = self.output(output, resources)?; let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown); - let mut shape = Sequence::new(); + let mut shape = Vec::new(); let index = resources.num_reshaped; resources.num_reshaped += 1; @@ -326,13 +325,13 @@ impl FuseBlockBuilder { for i in 0..output.shape.rank() { let id = index * rank + i; - shape.push(FuseArg::ScalarShape(id as u32)); + shape.push(FuseArg::ScalarShape(id)); } resources.views.push(TensorView::Reshape { reshaped: output.id, original: tensor.id, - reshape_pos: index as u32, + reshape_pos: index, shape_relative: output.shape.dims.clone(), }); @@ -381,11 +380,7 @@ impl FuseBlockBuilder { tensor.id, FuseOp::Assign(UnaryFuseArgs { input: local, - out: FuseArg::Output( - out_index + offset as u32, - *precision, - LayoutInfo::Unknown, - ), + out: FuseArg::Output(out_index + offset, *precision, LayoutInfo::Unknown), }), ); } @@ -657,7 +652,7 @@ impl FuseBlockBuilder { #[derive(Default, Clone, Debug)] struct LocalVariablePool { - values: BTreeMap>, + values: BTreeMap>, } impl LocalVariablePool { @@ -681,7 +676,7 @@ impl LocalVariablePool { None } - fn find_tensor_id(&self, precision: FuseType, position: u32) -> Option { + fn find_tensor_id(&self, precision: FuseType, position: usize) -> Option { if let Some(indexes) = self.values.get(&precision) { indexes .iter() @@ -694,7 +689,7 @@ impl LocalVariablePool { fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg { if let Some(indexes) = self.values.get_mut(&precision) { - let new_index = indexes.len() as u32; + let new_index = indexes.len(); indexes.insert(tensor_id, new_index); return FuseArg::Local(new_index, precision); } diff --git a/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs b/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs index a49a4dbca2..00c62c43e4 100644 --- a/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs +++ b/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs @@ -205,7 +205,7 @@ impl TraceFuser { &mut self, tensor: &TensorIr, output: &TensorIr, - dims: (u32, u32), + dims: (usize, usize), ) -> Option { if matches!(tensor.dtype, DType::QFloat(_)) { return None; @@ -238,7 +238,7 @@ impl TraceFuser { FuseType::Bool => self.bool_precision, _ => precision, }; - let new_index = self.resources.scalars.len() as u32; + let new_index = self.resources.scalars.len(); self.resources.scalars.push((precision, id.value)); FuseArg::Scalar(new_index, precision) diff --git a/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs b/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs index ab6dff4893..e6b895a914 100644 --- a/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs +++ 
b/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs @@ -86,13 +86,13 @@ impl TraceRunner for ElemwiseRunner { let config = &configs[0]; let shape = match &config.ref_layout { RefLayout::Concrete(arg) => match arg { - FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank as usize), - FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank as usize), + FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank), + FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank), _ => panic!("Invalid concreate ref layout"), }, - RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank as usize), + RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank), }; - let working_units = shape.iter().product::() / config.width as usize; + let working_units = shape.iter().product::() / config.width; let cube_dim = CubeDim::new(client, working_units); let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim); @@ -119,7 +119,7 @@ fn elemwise_fuse( ) { // We write no values for this fusion. let values = Registry::>::new(); - let args = comptime![Sequence::::new()]; + let args = comptime![Vec::::new()]; let pos = ABSOLUTE_POS; let mut locals = init_locals(inputs, outputs, config); diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs index 538c0ac9f5..4dac36ea89 100644 --- a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs +++ b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs @@ -16,7 +16,7 @@ use cubecl::{ }, tensor::{ View, ViewExpand, - layout::{Coords1d, Coords2d, Coords3d, VirtualLayout}, + layout::{Coords1d, Coords2d, VirtualLayout}, }, }, }; @@ -26,7 +26,7 @@ use cubek::matmul::{ GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout, }, definition::MatrixLayout, - launch::MatmulArgs, + launch::{BatchedCoords, MatmulArgs}, }; use serde::{Deserialize, Serialize}; use std::marker::PhantomData; @@ -70,7 +70,7 @@ impl MatmulArgs for FusedMatmulArgs { #[unroll] for i in 0..rank - 2 { - batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i])); + batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32)); batch_strides_out.push(locals.ref_strides[i]); } @@ -115,27 +115,27 @@ impl MatmulArgs for FusedMatmulArgs { fn view_lhs( state: &Self::State, - ) -> View, Coords3d> { + ) -> View, BatchedCoords> { global_view( &state.inputs, &state.locals, &state.batch_shape, comptime![state.a.clone()], comptime![state.config.clone()], - comptime![state.lhs_layout_config], + state.lhs_layout_config, ) } fn batch_lhs( state: &Self::State, - batch: u32, - ) -> u32 { + batch: usize, + ) -> usize { state.a_batch.to_source_pos(batch) } fn view_rhs( state: &Self::State, - ) -> View, Coords3d> { + ) -> View, BatchedCoords> { global_view( &state.inputs, &state.locals, @@ -148,14 +148,14 @@ impl MatmulArgs for FusedMatmulArgs { fn batch_rhs( state: &Self::State, - batch: u32, - ) -> u32 { + batch: usize, + ) -> usize { state.b_batch.to_source_pos(batch) } fn view_acc( state: &Self::State, - ) -> CubeOption, Coords3d>> { + ) -> CubeOption, BatchedCoords>> { match comptime![state.c.clone()] { Option::Some(c) => { let view = global_view( @@ -174,8 +174,8 @@ impl MatmulArgs for FusedMatmulArgs { fn batch_acc( state: &Self::State, - batch: u32, - ) -> u32 { + batch: usize, + ) -> usize { match state.c_batch { CubeOption::Some(c_batch) => c_batch.to_source_pos(batch), CubeOption::None => batch, @@ -184,11 +184,11 @@ impl 
MatmulArgs for FusedMatmulArgs { fn view_out( state: &mut Self::State, - ) -> View, Coords3d, ReadWrite> { + ) -> View, BatchedCoords, ReadWrite> { let rank = comptime![state.config.rank]; - let shape_row = state.locals.ref_shape[rank - 2]; - let shape_col = state.locals.ref_shape[rank - 1]; + let shape_row = state.locals.ref_shape[rank - 2] as u32; + let shape_col = state.locals.ref_shape[rank - 1] as u32; let stride_row = state.locals.ref_strides[rank - 2]; let stride_col = state.locals.ref_strides[rank - 1]; @@ -215,8 +215,8 @@ impl MatmulArgs for FusedMatmulArgs { fn batch_out( state: &Self::State, - batch: u32, - ) -> u32 { + batch: usize, + ) -> usize { state.out_batch.to_source_pos(batch) } } @@ -229,7 +229,7 @@ fn global_view( #[comptime] arg: MatmulArg, #[comptime] config: FuseBlockConfig, #[comptime] layout_config: GlobalLayoutConfig, -) -> View, Coords3d> { +) -> View, BatchedCoords> { let rank = comptime![config.rank]; let data = comptime![arg.data().clone()]; let data_tensor = match comptime![data.clone()] { @@ -237,13 +237,13 @@ fn global_view( _ => panic!("Input must be concrete"), }; - let mut shape_row = data_tensor.tensor.shape(rank - 2); - let mut shape_col = data_tensor.tensor.shape(rank - 1); - let mut packing = comptime![1u32]; + let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32; + let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32; + let mut packing = comptime![1]; - if comptime![arg.scheme().is_some()] { - let scheme = comptime![arg.scheme().unwrap()]; - let num_quants = comptime![scheme.num_quants() as u32]; + if arg.scheme().is_some() { + let scheme = arg.scheme().unwrap(); + let num_quants = scheme.num_quants() as u32; comptime![packing = num_quants]; match comptime![layout_config.matrix_layout] { MatrixLayout::RowMajor => shape_col *= num_quants, @@ -266,8 +266,8 @@ fn global_view( inputs, shape, batch_layout, - comptime![arg.data().clone()], - comptime![config.clone()], + arg.data().clone(), + config.clone(), data_tensor.tensor.line_size(), layout_config, packing, @@ -296,7 +296,7 @@ fn global_view( batch_layout, comptime![scales.clone()], comptime![config.clone()], - 1u32, + 1usize, layout_config, 1u32, ); @@ -319,7 +319,7 @@ fn input_batch_layout( batch_shape: &Sequence, #[comptime] arg: MatmulArg, #[comptime] config: FuseBlockConfig, -) -> VirtualLayout { +) -> VirtualLayout { let rank = comptime![config.rank]; match comptime![arg.clone()] { MatmulArg::Normal(arg) => { @@ -346,10 +346,10 @@ fn input_batch_layout( fn global_layout( inputs: &GlobalArgs, shape: Coords2d, - batch_layout: VirtualLayout, + batch_layout: VirtualLayout, #[comptime] arg: FuseArg, #[comptime] config: FuseBlockConfig, - #[comptime] line_size: u32, + #[comptime] line_size: LineSize, #[comptime] layout_config: GlobalLayoutConfig, #[comptime] packing: u32, ) -> GlobalLayout { @@ -387,7 +387,7 @@ struct CreateQuantView<'a, E: Numeric> { } impl<'a, E: Numeric> RunWithQuantType for CreateQuantView<'a, E> { - type Output = ViewExpand, Coords3d>; + type Output = ViewExpand, BatchedCoords>; fn execute(self) -> Self::Output { create_quant_view::expand::( @@ -409,7 +409,7 @@ fn create_quant_view_dynamic( scales_buf: GlobalInput, scales_layout: GlobalScaleLayout, #[comptime] scheme: QuantScheme, -) -> View, Coords3d> { +) -> View, BatchedCoords> { intrinsic!(|scope| { let func = CreateQuantView { scope, @@ -431,10 +431,10 @@ fn create_quant_view( scales_buf: GlobalInput, scales_layout: GlobalScaleLayout, #[comptime] scheme: QuantScheme, -) -> View, Coords3d> { - let 
data_view: View, Coords3d> = +) -> View, BatchedCoords> { + let data_view: View, BatchedCoords> = View::new::(&data_buf, data_layout); - let scales_view: View = + let scales_view: View = View::new::(&scales_buf, scales_layout); QuantizedView::new(data_view, scales_view, scheme).view() } @@ -474,10 +474,10 @@ impl FusedMatmulState { inputs: &FusedMatmulInput, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - a_batch: VirtualLayout, - b_batch: VirtualLayout, - c_batch: CubeOption>, - out_batch: VirtualLayout, + a_batch: VirtualLayout, + b_batch: VirtualLayout, + c_batch: CubeOption>, + out_batch: VirtualLayout, batch_shape: Sequence, #[comptime] config: &FuseBlockConfig, #[comptime] lhs_layout_config: GlobalLayoutConfig, diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs b/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs index f5f4dbad16..1805b68557 100644 --- a/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs +++ b/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs @@ -411,7 +411,7 @@ impl FusedMatmulLaunch<'_> { ) -> Result<(), FusedMatmulError> { let lhs_shape = inputs.shape(self.matmul.lhs.data()); let rhs_shape = inputs.shape(self.matmul.rhs.data()); - let out_shape = outputs.shape_ref(&config.ref_layout, config.rank as usize); + let out_shape = outputs.shape_ref(&config.ref_layout, config.rank); let lhs_strides = inputs.strides(self.matmul.lhs.data()); let rhs_strides = inputs.strides(self.matmul.rhs.data()); @@ -447,10 +447,10 @@ impl FusedMatmulLaunch<'_> { } if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs { - line_sizes.lhs *= scheme.num_quants() as u8; + line_sizes.lhs *= scheme.num_quants(); } if let MatmulArg::Quantized { scheme, .. } = self.matmul.rhs { - line_sizes.rhs *= scheme.num_quants() as u8; + line_sizes.rhs *= scheme.num_quants(); } let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape); diff --git a/crates/burn-cubecl-fusion/src/optim/reduce/args.rs b/crates/burn-cubecl-fusion/src/optim/reduce/args.rs index e56fcac60c..acf901494b 100644 --- a/crates/burn-cubecl-fusion/src/optim/reduce/args.rs +++ b/crates/burn-cubecl-fusion/src/optim/reduce/args.rs @@ -65,8 +65,8 @@ impl ReduceArgs for FusedReduceArgs { FusedReduceState::new(input, output, &mut locals_read, &mut locals_write) } - fn read_input(state: &Self::State
, index: u32) -> Line { - *fuse_on_read::( + fn read_input(state: &Self::State, index: usize) -> Line { + fuse_on_read::( unsafe { &(*state.inputs) }, unsafe { &mut (*state.outputs) }, unsafe { &mut (*state.locals_on_read) }, @@ -77,17 +77,16 @@ impl ReduceArgs for FusedReduceArgs { sequence }, &state.config_on_read, - ) - .index(0) + )[0] } - fn read_output(_state: &Self::State, _index: u32) -> Line { - Line::empty(1_u32) + fn read_output(_state: &Self::State, _index: usize) -> Line { + Line::empty(1usize) } - fn write_output(state: &mut Self::State, index: u32, value: Line) { + fn write_output(state: &mut Self::State, index: usize, value: Line) { let mut values = Registry::>::new(); - let mut args = comptime![Sequence::::new()]; + let mut args = comptime![Vec::::new()]; values.insert(comptime![state.out.clone()], value); comptime![args.push(state.out.clone())]; @@ -103,7 +102,7 @@ impl ReduceArgs for FusedReduceArgs { ); } - fn len_input(state: &Self::State) -> u32 { + fn len_input(state: &Self::State) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -112,7 +111,7 @@ ) } - fn len_output(state: &Self::State) -> u32 { + fn len_output(state: &Self::State) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -121,7 +120,7 @@ ) } - fn buffer_len_input(state: &Self::State) -> u32 { + fn buffer_len_input(state: &Self::State) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -130,7 +129,7 @@ ) } - fn buffer_len_output(state: &Self::State) -> u32 { + fn buffer_len_output(state: &Self::State) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -139,35 +138,35 @@ ) } - fn rank_input(state: &Self::State) -> u32 { + fn rank_input(state: &Self::State) -> usize { state.config_on_read.rank.runtime() } - fn rank_output(state: &Self::State) -> u32 { + fn rank_output(state: &Self::State) -> usize { state.config_on_write.rank.runtime() } - fn shape_input(state: &Self::State, dim: u32) -> u32 { + fn shape_input(state: &Self::State, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_read) }, dim) } - fn shape_output(state: &Self::State, dim: u32) -> u32 { + fn shape_output(state: &Self::State, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_write) }, dim) } - fn stride_input(state: &Self::State, dim: u32) -> u32 { + fn stride_input(state: &Self::State, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_read) }, dim) } - fn stride_output(state: &Self::State, dim: u32) -> u32 { + fn stride_output(state: &Self::State, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_write) }, dim) } - fn line_size_input(state: &Self::State) -> comptime_type!(u32) { + fn line_size_input(state: &Self::State) -> comptime_type!(LineSize) { ref_line_size(unsafe { &(*state.locals_on_read) }) } - fn line_size_output(state: &Self::State) -> comptime_type!(u32) { + fn line_size_output(state: &Self::State
) -> comptime_type!(LineSize) { ref_line_size(unsafe { &(*state.locals_on_write) }) } } diff --git a/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs b/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs index 5a86858475..8081588e42 100644 --- a/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs +++ b/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs @@ -279,17 +279,17 @@ impl TraceRunner for FusedReduceLaunch<'_> { let [config_read, config_write] = [&configs[0], &configs[1]]; let shape = match &config_read.ref_layout { RefLayout::Concrete(FuseArg::Output(..)) => { - outputs.shape_ref(&config_read.ref_layout, config_read.rank as usize) + outputs.shape_ref(&config_read.ref_layout, config_read.rank) } - _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank as usize), + _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank), }; - let reduce_count: u32 = shape + let reduce_count: usize = shape .iter() .enumerate() - .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s as u32 }) + .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s }) .product(); - let line_mode = match self.reduce.axis == config_read.rank as usize - 1 { + let line_mode = match self.reduce.axis == config_read.rank - 1 { true => LineMode::Parallel, false => LineMode::Perpendicular, }; @@ -300,9 +300,9 @@ impl TraceRunner for FusedReduceLaunch<'_> { line_size_output: config_write.width, }; let problem = ReduceProblem { - vector_size: shape[self.reduce.axis] as u32, + vector_size: shape[self.reduce.axis], vector_count: reduce_count, - axis: self.reduce.axis as u32, + axis: self.reduce.axis, dtypes: ReduceDtypes { input: self.reduce.op.input.dtype.into(), output: self.reduce.op.out.dtype.into(), @@ -329,7 +329,7 @@ impl TraceRunner for FusedReduceLaunch<'_> { client, inputs, outputs, - axis: self.reduce.axis as u32, + axis: self.reduce.axis, config_fuse_read: config_read.clone(), config_fuse_write: config_write.clone(), input: self.reduce.input.clone(), @@ -356,7 +356,7 @@ struct ReduceKwArgs<'a, 'b, Run: Runtime> { client: &'b ComputeClient, inputs: GlobalArgsLaunch<'a, Run>, outputs: GlobalArgsLaunch<'a, Run>, - axis: u32, + axis: usize, blueprint: ReduceBlueprint, settings: ReduceLaunchSettings, config_fuse_read: FuseBlockConfig, @@ -413,7 +413,7 @@ fn launch_reduce( pub fn reduce_kernel( input: &FusedReduceInput, output: &mut FusedReduceOutput, - axis_reduce: u32, + axis_reduce: usize, #[comptime] blueprint: ReduceBlueprint, #[comptime] config: ReduceOperationConfig, #[define(In)] _input_dtype: StorageType, diff --git a/crates/burn-cubecl/src/kernel/binary.rs b/crates/burn-cubecl/src/kernel/binary.rs index 97267ee379..0ea46ff4bc 100644 --- a/crates/burn-cubecl/src/kernel/binary.rs +++ b/crates/burn-cubecl/src/kernel/binary.rs @@ -6,9 +6,7 @@ use crate::{ }; use burn_backend::{bf16, f16}; use cubecl::{ - calculate_cube_count_elemwise, intrinsic, - prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, + calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView, }; pub(crate) trait BinaryOpFamily: Send + Sync + 'static { diff --git a/crates/burn-cubecl/src/kernel/binary_int.rs b/crates/burn-cubecl/src/kernel/binary_int.rs index 315b271512..07a6310bb1 100644 --- a/crates/burn-cubecl/src/kernel/binary_int.rs +++ b/crates/burn-cubecl/src/kernel/binary_int.rs @@ -4,11 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; -use cubecl::{ - calculate_cube_count_elemwise, - 
prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, -}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { type BinaryOp: BinaryOpInt; diff --git a/crates/burn-cubecl/src/kernel/clamp.rs b/crates/burn-cubecl/src/kernel/clamp.rs index 589188d941..46059aaf52 100644 --- a/crates/burn-cubecl/src/kernel/clamp.rs +++ b/crates/burn-cubecl/src/kernel/clamp.rs @@ -1,4 +1,4 @@ -use cubecl::{prelude::*, std::scalar::InputScalar}; +use cubecl::prelude::*; use crate::{ CubeRuntime, diff --git a/crates/burn-cubecl/src/kernel/comparison.rs b/crates/burn-cubecl/src/kernel/comparison.rs index 557b97af99..646c68489e 100644 --- a/crates/burn-cubecl/src/kernel/comparison.rs +++ b/crates/burn-cubecl/src/kernel/comparison.rs @@ -5,11 +5,7 @@ use crate::{ tensor::CubeTensor, }; use burn_backend::DType; -use cubecl::{ - calculate_cube_count_elemwise, - prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, -}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; #[cube] pub(crate) trait ComparisonOpFamily: 'static + Send + Sync { diff --git a/crates/burn-cubecl/src/kernel/contiguous.rs b/crates/burn-cubecl/src/kernel/contiguous.rs index b58b6e22ce..95f7e3e69f 100644 --- a/crates/burn-cubecl/src/kernel/contiguous.rs +++ b/crates/burn-cubecl/src/kernel/contiguous.rs @@ -14,7 +14,7 @@ pub fn into_contiguous(tensor: CubeTensor) -> CubeTensor { return into_contiguous_quantized(tensor, AllocationKind::Contiguous); } - let output = cubecl::std::tensor::into_contiguous( + let output = cubecl::std::tensor::into_contiguous_ref( &tensor.client, &tensor.as_handle_ref(), tensor.dtype.into(), @@ -46,7 +46,7 @@ pub fn into_contiguous_aligned(tensor: CubeTensor) -> CubeTen return into_contiguous_quantized(tensor, AllocationKind::Optimized); } - let output = cubecl::std::tensor::into_contiguous_pitched( + let output = cubecl::std::tensor::into_contiguous_pitched_ref( &tensor.client, &tensor.as_handle_ref(), tensor.dtype.into(), @@ -83,7 +83,7 @@ fn into_contiguous_quantized( &values.as_handle_ref(), &out_values.as_handle_ref(), &tensor.shape, - scheme.num_quants() as u32, + scheme.num_quants(), DType::U32.into(), ) .expect("Kernel to never fail"); @@ -102,7 +102,7 @@ fn into_contiguous_quantized( .expect("Kernel to never fail"); } QuantStore::Native => { - cubecl::std::tensor::into_contiguous_ref( + cubecl::std::tensor::copy_into( &values.client, &values.as_handle_ref(), &out_values.as_handle_ref(), @@ -112,7 +112,7 @@ fn into_contiguous_quantized( } } - cubecl::std::tensor::into_contiguous_ref( + cubecl::std::tensor::copy_into( &scales.client, &scales.as_handle_ref(), &out_scales.as_handle_ref(), diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs index 0818532706..c0bf4db27f 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs @@ -209,17 +209,17 @@ fn col2im( bias.as_tensor_arg(vectorization), out.as_tensor_arg(vectorization), Col2ImArgsLaunch::new( - ScalarArg::new(out_h as u32), - ScalarArg::new(out_w as u32), - ScalarArg::new(kernel_h as u32), - ScalarArg::new(kernel_w as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as 
u32), - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(col_size_1 as u32), + ScalarArg::new(out_h), + ScalarArg::new(out_w), + ScalarArg::new(kernel_h), + ScalarArg::new(kernel_w), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(col_size_1), ), has_bias, dtype.into(), @@ -229,20 +229,20 @@ fn col2im( #[derive(CubeLaunch, CubeType)] struct Col2ImArgs { - out_h: u32, - out_w: u32, + out_h: usize, + out_w: usize, - kernel_h: u32, - kernel_w: u32, + kernel_h: usize, + kernel_w: usize, - pad_h: u32, - pad_w: u32, - dilation_h: u32, - dilation_w: u32, - stride_h: u32, - stride_w: u32, + pad_h: usize, + pad_w: usize, + dilation_h: usize, + dilation_w: usize, + stride_h: usize, + stride_w: usize, - col_size_1: u32, + col_size_1: usize, } #[cube(launch_unchecked)] @@ -271,13 +271,13 @@ fn col2im_kernel( let x_col_start = if im_x >= kernel_extent_w { (im_x - kernel_extent_w) / args.stride_w + 1 } else { - 0u32.runtime() + 0usize.runtime() }; let x_col_end = Min::min(im_x / args.stride_w + 1, args.out_w); let y_col_start = if im_y >= kernel_extent_h { (im_y - kernel_extent_h) / args.stride_h + 1 } else { - 0u32.runtime() + 0usize.runtime() }; let y_col_end = Min::min(im_y / args.stride_h + 1, args.out_h); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs index d0abd6e967..3d61210108 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs @@ -13,13 +13,13 @@ use cubek::convolution::components::ConvSetupError; #[derive(CubeLaunch, CubeType)] struct ConvArgs { - conv_stride_0: u32, - conv_stride_1: u32, - dilation_0: u32, - dilation_1: u32, - padding_0: u32, - padding_1: u32, - groups: u32, + conv_stride_0: usize, + conv_stride_1: usize, + dilation_0: usize, + dilation_1: usize, + padding_0: usize, + padding_1: usize, + groups: usize, } #[cube(launch)] @@ -61,10 +61,10 @@ fn conv_transpose2d_direct_kernel( let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i; let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i; - let y_end = Min::min(Max::max(kms_h + y_start + 1, 0) as u32, input.shape(2)); - let x_end = Min::min(Max::max(kms_w + x_start + 1, 0) as u32, input.shape(3)); - let y_start = Max::max(y_start, 0) as u32; - let x_start = Max::max(x_start, 0) as u32; + let y_end = Min::min(Max::max(kms_h + y_start + 1, 0) as usize, input.shape(2)); + let x_end = Min::min(Max::max(kms_w + x_start + 1, 0) as usize, input.shape(3)); + let y_start = Max::max(y_start, 0) as usize; + let x_start = Max::max(x_start, 0) as usize; let idx_input_batch = batch * input.stride(0); let idx_weight_oc = out_c * weight.stride(1); @@ -183,13 +183,13 @@ pub fn conv_transpose2d_direct( bias.as_tensor_arg(1), output.as_tensor_arg(1), ConvArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.groups as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + 
ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.groups), ), input.dtype.into(), )?; diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs index 329ede2383..96fdfac9fa 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs @@ -13,16 +13,16 @@ use burn_backend::{Shape, ops::ConvTransposeOptions}; #[derive(CubeLaunch, CubeType)] struct ConvArgs { - conv_stride_0: u32, - conv_stride_1: u32, - conv_stride_2: u32, - dilation_0: u32, - dilation_1: u32, - dilation_2: u32, - padding_0: u32, - padding_1: u32, - padding_2: u32, - groups: u32, + conv_stride_0: usize, + conv_stride_1: usize, + conv_stride_2: usize, + dilation_0: usize, + dilation_1: usize, + dilation_2: usize, + padding_0: usize, + padding_1: usize, + padding_2: usize, + groups: usize, } #[cube(launch)] @@ -68,13 +68,13 @@ fn conv_transpose3d_kernel( let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i; let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i; - let z_end = Min::min(Max::max(kernel_d + z_start + 1, 0) as u32, input.shape(2)); - let y_end = Min::min(Max::max(kernel_h + y_start + 1, 0) as u32, input.shape(3)); - let x_end = Min::min(Max::max(kernel_w + x_start + 1, 0) as u32, input.shape(4)); + let z_end = Min::min(Max::max(kernel_d + z_start + 1, 0) as usize, input.shape(2)); + let y_end = Min::min(Max::max(kernel_h + y_start + 1, 0) as usize, input.shape(3)); + let x_end = Min::min(Max::max(kernel_w + x_start + 1, 0) as usize, input.shape(4)); - let z_start = Max::max(z_start, 0) as u32; - let y_start = Max::max(y_start, 0) as u32; - let x_start = Max::max(x_start, 0) as u32; + let z_start = Max::max(z_start, 0) as usize; + let y_start = Max::max(y_start, 0) as usize; + let x_start = Max::max(x_start, 0) as usize; let index_input_batch = batch * input.stride(0); let index_weight_out_c = out_channel * weight.stride(1); @@ -218,16 +218,16 @@ pub(crate) fn conv_transpose3d( bias.as_tensor_arg(1), output.as_tensor_arg(1), ConvArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.stride[2] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.dilation[2] as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.padding[2] as u32), - ScalarArg::new(options.groups as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.stride[2]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.dilation[2]), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.padding[2]), + ScalarArg::new(options.groups), ), input.dtype.into(), )?; diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs index 8a2b1059b8..28ca0baf42 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs @@ -1,4 +1,4 @@ -use cubecl::{calculate_cube_count_elemwise, prelude::*, std::scalar::InputScalar}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; use 
diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs
index 8a2b1059b8..28ca0baf42 100644
--- a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs
+++ b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs
@@ -1,4 +1,4 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*, std::scalar::InputScalar};
+use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubek::convolution::components::ConvSetupError;
 
 use burn_backend::{
@@ -21,20 +21,20 @@ use crate::{
 
 #[derive(CubeLaunch, CubeType)]
 struct DeformConv2dArgs {
-    conv_stride_h: u32,
-    conv_stride_w: u32,
-    dilation_h: u32,
-    dilation_w: u32,
+    conv_stride_h: usize,
+    conv_stride_w: usize,
+    dilation_h: usize,
+    dilation_w: usize,
     padding_h: InputScalar,
     padding_w: InputScalar,
-    offset_groups: u32,
+    offset_groups: usize,
 
-    kernel_height: u32,
-    kernel_width: u32,
-    out_h: u32,
-    out_w: u32,
+    kernel_height: usize,
+    kernel_width: usize,
+    out_h: usize,
+    out_w: usize,
 
-    col_stride_0: u32,
+    col_stride_0: usize,
 }
 
 #[cube(launch)]
@@ -44,8 +44,8 @@ fn deform_im2col_kernel(
     mask: &Tensor<F>,
     columns: &mut Tensor<F>,
     args: &DeformConv2dArgs,
-    #[comptime] kernel_h_unroll: Option<u32>,
-    #[comptime] kernel_w_unroll: Option<u32>,
+    #[comptime] kernel_h_unroll: Option<usize>,
+    #[comptime] kernel_w_unroll: Option<usize>,
     #[comptime] use_mask: bool,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -131,11 +131,11 @@ fn deform_im2col_kernel(
 #[cube]
 pub(crate) fn bilinear_interpolate(
     input: &Tensor<F>,
-    height: u32,
-    width: u32,
+    height: usize,
+    width: usize,
     y: F,
     x: F,
-    offset: u32,
+    offset: usize,
 ) -> F {
     // To simplify code
     let y = f32::cast_from(y);
@@ -143,26 +143,26 @@ pub(crate) fn bilinear_interpolate(
     let mut result = F::new(0.0);
     if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
-        let in_w = u32::cast_from(width);
+        let in_w = width;
         let y_low = f32::floor(y);
         let x_low = f32::floor(x);
-        let y_high = (y_low + 1.) as u32;
-        let x_high = (x_low + 1.) as u32;
+        let y_high = (y_low + 1.) as usize;
+        let x_high = (x_low + 1.) as usize;
 
         let zero = F::new(0.0);
         let v1: F = if y_low >= 0. && x_low >= 0. {
-            input[offset + y_low as u32 * in_w + x_low as u32]
+            input[offset + y_low as usize * in_w + x_low as usize]
         } else {
             zero
         };
         let v2: F = if y_low >= 0. && x_high < width {
-            input[offset + y_low as u32 * in_w + x_high]
+            input[offset + y_low as usize * in_w + x_high]
         } else {
             zero
         };
         let v3: F = if y_high < height && x_low >= 0. {
-            input[offset + y_high * in_w + x_low as u32]
+            input[offset + y_high * in_w + x_low as usize]
         } else {
             zero
         };
@@ -237,10 +237,10 @@ pub(crate) fn deform_im2col(
         mask.as_handle_ref().as_tensor_arg(1),
         output.as_handle_ref().as_tensor_arg(1),
         DeformConv2dArgsLaunch::new(
-            ScalarArg::new(options.stride[0] as u32),
-            ScalarArg::new(options.stride[1] as u32),
-            ScalarArg::new(options.dilation[0] as u32),
-            ScalarArg::new(options.dilation[1] as u32),
+            ScalarArg::new(options.stride[0]),
+            ScalarArg::new(options.stride[1]),
+            ScalarArg::new(options.dilation[0]),
+            ScalarArg::new(options.dilation[1]),
             {
                 let val = options.padding[0] as f32;
                 InputScalar::new(val, dtype)
             },
@@ -249,15 +249,15 @@ pub(crate) fn deform_im2col(
                 let val = options.padding[1] as f32;
                 InputScalar::new(val, dtype)
             },
-            ScalarArg::new(options.offset_groups as u32),
-            ScalarArg::new(kernel_height as u32),
-            ScalarArg::new(kernel_width as u32),
-            ScalarArg::new(out_height as u32),
-            ScalarArg::new(out_width as u32),
-            ScalarArg::new(output.strides[0] as u32),
+            ScalarArg::new(options.offset_groups),
+            ScalarArg::new(kernel_height),
+            ScalarArg::new(kernel_width),
+            ScalarArg::new(out_height),
+            ScalarArg::new(out_width),
+            ScalarArg::new(output.strides[0]),
         ),
-        Some(kernel_height as u32),
-        Some(kernel_width as u32),
+        Some(kernel_height),
+        Some(kernel_width),
         use_mask,
         dtype.into(),
    )?;
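Note: the visible hunks show corner taps v1–v3; the fourth tap and the weighting follow standard bilinear interpolation. A CPU sketch of what `bilinear_interpolate` computes, on a flat row-major `f32` buffer (not from the patch; v4 and the weights are inferred from the usual formulation):

fn bilinear(input: &[f32], height: usize, width: usize, y: f32, x: f32) -> f32 {
    let mut result = 0.0;
    if y > -1.0 && (height as f32) > y && x > -1.0 && (width as f32) > x {
        let y_low = y.floor();
        let x_low = x.floor();
        let y_high = (y_low + 1.0) as usize;
        let x_high = (x_low + 1.0) as usize;

        // Each corner tap is zeroed when it falls outside the image.
        let v1 = if y_low >= 0.0 && x_low >= 0.0 { input[y_low as usize * width + x_low as usize] } else { 0.0 };
        let v2 = if y_low >= 0.0 && x_high < width { input[y_low as usize * width + x_high] } else { 0.0 };
        let v3 = if y_high < height && x_low >= 0.0 { input[y_high * width + x_low as usize] } else { 0.0 };
        let v4 = if y_high < height && x_high < width { input[y_high * width + x_high] } else { 0.0 };

        // Fractional parts become the interpolation weights.
        let (ly, lx) = (y - y_low, x - x_low);
        let (hy, hx) = (1.0 - ly, 1.0 - lx);
        result = hy * hx * v1 + hy * lx * v2 + ly * hx * v3 + ly * lx * v4;
    }
    result
}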
diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
index 94f040a33e..b199d80d71 100644
--- a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
+++ b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
@@ -16,7 +16,6 @@ use crate::{
 use burn_backend::{DType, Shape, ops::DeformConvOptions};
 use cubecl::{
     CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube, features::TypeUsage, prelude::*,
-    std::scalar::InputScalar,
 };
 use cubek::{
     convolution::components::ConvSetupError,
@@ -273,15 +272,15 @@ fn compute_offset_and_mask_gradient(
         grad_offset.as_handle_ref().as_tensor_arg(1),
         grad_mask.as_handle_ref().as_tensor_arg(1),
         DeformConv2dCol2ImgCoordArgsLaunch::new(
-            ScalarArg::new(options.stride[0] as u32),
-            ScalarArg::new(options.stride[1] as u32),
-            ScalarArg::new(options.dilation[0] as u32),
-            ScalarArg::new(options.dilation[1] as u32),
+            ScalarArg::new(options.stride[0]),
+            ScalarArg::new(options.stride[1]),
+            ScalarArg::new(options.dilation[0]),
+            ScalarArg::new(options.dilation[1]),
             InputScalar::new(options.padding[0] as f32, dtype.elem_type()),
             InputScalar::new(options.padding[1] as f32, dtype.elem_type()),
-            ScalarArg::new(options.offset_groups as u32),
-            ScalarArg::new(kernel_height as u32),
-            ScalarArg::new(kernel_width as u32),
+            ScalarArg::new(options.offset_groups),
+            ScalarArg::new(kernel_height),
+            ScalarArg::new(kernel_width),
         ),
         use_mask,
         dtype,
@@ -294,15 +293,15 @@ fn compute_offset_and_mask_gradient(
 
 #[derive(CubeLaunch, CubeType)]
 struct DeformConv2dCol2ImgCoordArgs {
-    stride_h: u32,
-    stride_w: u32,
-    dilation_h: u32,
-    dilation_w: u32,
+    stride_h: usize,
+    stride_w: usize,
+    dilation_h: usize,
+    dilation_w: usize,
     pad_h: InputScalar,
     pad_w: InputScalar,
-    offset_groups: u32,
-    kernel_height: u32,
-    kernel_width: u32,
+    offset_groups: usize,
+    kernel_height: usize,
+    kernel_width: usize,
 }
 
 #[expect(clippy::collapsible_if)]
@@ -365,7 +364,7 @@ fn deform_col2img_coord_kernel(
     let c_bound = channels_per_offset_group * kernel_h * kernel_w;
 
-    for col_c in range_stepped(offset_c / 2, c_bound, col_step) {
+    for col_c in range_stepped(offset_c / 2, c_bound, col_step as u32) {
         let col_pos = (((col_c * batch_size + b) * out_h) + h) * out_w + w;
 
         let out_x = col_pos % out_w;
@@ -431,8 +430,8 @@ fn deform_col2img_coord_kernel(
 #[cube]
 fn get_coordinate_weight(
     input: &Slice<F>,
-    height: u32,
-    width: u32,
+    height: usize,
+    width: usize,
     y: F,
     x: F,
     is_y_direction: bool,
@@ -453,22 +452,22 @@ fn get_coordinate_weight(
     let valid_x_high = x_high >= 0. && x_high < width as f32;
 
     let bottom_left = if valid_y_low && valid_x_low {
-        input[y_low as u32 * stride_y + x_low as u32]
+        input[y_low as usize * stride_y + x_low as usize]
     } else {
         F::new(0.0)
     };
     let bottom_right = if valid_y_low && valid_x_high {
-        input[y_low as u32 * stride_y + x_high as u32]
+        input[y_low as usize * stride_y + x_high as usize]
     } else {
         F::new(0.0)
     };
     let top_left = if valid_y_high && valid_x_low {
-        input[y_high as u32 * stride_y + x_low as u32]
+        input[y_high as usize * stride_y + x_low as usize]
     } else {
         F::new(0.0)
     };
     let top_right = if valid_y_high && valid_x_high {
-        input[y_high as u32 * stride_y + x_high as u32]
+        input[y_high as usize * stride_y + x_high as usize]
     } else {
         F::new(0.0)
     };
@@ -549,19 +548,19 @@ fn compute_input_grad(
         columns.as_tensor_arg(1),
         grad_arg,
         DeformConv2dCol2ImgArgsLaunch::new(
-            ScalarArg::new(options.stride[0] as u32),
-            ScalarArg::new(options.stride[1] as u32),
-            ScalarArg::new(options.dilation[0] as u32),
-            ScalarArg::new(options.dilation[1] as u32),
+            ScalarArg::new(options.stride[0]),
+            ScalarArg::new(options.stride[1]),
+            ScalarArg::new(options.dilation[0]),
+            ScalarArg::new(options.dilation[1]),
             InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()),
             InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()),
-            ScalarArg::new(options.offset_groups as u32),
-            ScalarArg::new(batch_size as u32),
-            ScalarArg::new(in_channels as u32),
-            ScalarArg::new(height as u32),
-            ScalarArg::new(width as u32),
-            ScalarArg::new(kernel_height as u32),
-            ScalarArg::new(kernel_width as u32),
+            ScalarArg::new(options.offset_groups),
+            ScalarArg::new(batch_size),
+            ScalarArg::new(in_channels),
+            ScalarArg::new(height),
+            ScalarArg::new(width),
+            ScalarArg::new(kernel_height),
+            ScalarArg::new(kernel_width),
         ),
         use_mask,
         dtypes,
@@ -577,19 +576,19 @@ fn compute_input_grad(
 
 #[derive(CubeLaunch, CubeType)]
 struct DeformConv2dCol2ImgArgs {
-    stride_h: u32,
-    stride_w: u32,
-    dilation_h: u32,
-    dilation_w: u32,
+    stride_h: usize,
+    stride_w: usize,
+    dilation_h: usize,
+    dilation_w: usize,
     pad_h: InputScalar,
     pad_w: InputScalar,
-    offset_groups: u32,
-    batch_size: u32,
-    in_channels: u32,
-    height: u32,
-    width: u32,
-    kernel_height: u32,
-    kernel_width: u32,
+    offset_groups: usize,
+    batch_size: usize,
+    in_channels: usize,
+    height: usize,
+    width: usize,
+    kernel_height: usize,
+    kernel_width: usize,
 }
 
 #[cube(launch_unchecked)]
@@ -669,8 +668,8 @@ fn deform_col2img_kernel(
         && F::abs(x - xp) < F::new(1.0)
     {
         let gradient_pos =
-            ((batch * n_in_channels + in_channel) * height + u32::cast_from(yp)) * width
-                + u32::cast_from(xp);
+            ((batch * n_in_channels + in_channel) * height + usize::cast_from(yp)) * width
+                + usize::cast_from(xp);
 
         let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp));
diff --git a/crates/burn-cubecl/src/kernel/conv/direct.rs b/crates/burn-cubecl/src/kernel/conv/direct.rs
index 71148eb13d..2cb42a7395 100644
--- a/crates/burn-cubecl/src/kernel/conv/direct.rs
+++ b/crates/burn-cubecl/src/kernel/conv/direct.rs
@@ -1,15 +1,12 @@
 use crate::ops::numeric::empty_device_optimized_dtype;
 use crate::{
     CubeRuntime,
-    kernel::{
-        into_contiguous_aligned,
-        utils::{linear_view, shape_divmod},
-    },
+    kernel::{into_contiguous_aligned, utils::linear_view},
     ops::max_line_size,
     tensor::CubeTensor,
 };
 use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes};
-use cubecl::std::{CubeOption, CubeOptionExpand, FastDivmod};
+use cubecl::std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs};
 use cubecl::{
     calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView,
     tensor_line_size_parallel,
@@ -50,20 +47,20 @@ fn direct_conv2d_kernel(
     let line_size_out = output.line_size();
     let pos = ABSOLUTE_POS * line_size_out;
 
-    let in_c_per_group = weight.shape(weight.rank() - 1);
+    let in_c_per_group = weight.shape(weight.rank() - 1) as u32;
 
-    let (rem, out_c) = shape_out_c.div_mod(pos);
+    let (rem, out_c) = shape_out_c.div_mod(pos as u32);
     let (b, spatial_pos) = div_mod_seq(rem, &shape_out);
 
     let g = out_c / args.channels_per_group;
     let ic_start = in_c_per_group * g;
 
     let mut sum = match bias {
-        CubeOption::Some(bias) => bias[out_c / line_size_out],
+        CubeOption::Some(bias) => bias[out_c as usize / line_size_out],
        CubeOption::None => Line::empty(line_size_out).fill(E::from_int(0)),
     };
 
-    let in_offs = b * input.stride(0) + ic_start;
+    let in_offs = b as usize * input.stride(0) + ic_start as usize;
 
     let stride_oc = weight.stride(0);
 
@@ -74,13 +71,13 @@ fn direct_conv2d_kernel(
 
     #[unroll]
     for i in 0..n_spatial {
-        in_shape.push(input.shape(i + 1));
+        in_shape.push(input.shape(i + 1) as u32);
         in_strides.push(input.stride(i + 1));
-        kernel_shape.push(weight.shape(i + 1));
+        kernel_shape.push(weight.shape(i + 1) as u32);
         kernel_strides.push(weight.stride(i + 1));
     }
 
-    let weight_offs = out_c * stride_oc;
+    let weight_offs = out_c as usize * stride_oc;
 
     let loop_params = LoopParams {
         out_pos: spatial_pos,
@@ -101,7 +98,7 @@ fn direct_conv2d_kernel(
         true,
         weight_offs,
         &loop_params,
-        0u32,
+        0usize,
         has_padding,
     );
 
@@ -112,13 +109,13 @@ fn direct_conv2d_kernel(
 struct LoopParams {
     out_pos: Sequence<u32>,
     in_shape: Sequence<u32>,
-    in_strides: Sequence<u32>,
+    in_strides: Sequence<usize>,
     kernel_shape: Sequence<u32>,
-    kernel_strides: Sequence<u32>,
+    kernel_strides: Sequence<usize>,
     conv_params: Sequence<ConvParam>,
     in_c_per_group: u32,
-    stride_oc: u32,
+    stride_oc: usize,
 }
 
 #[cube]
@@ -126,11 +123,11 @@ fn kernel_loop(
     input: &Tensor<Line<E>>,
     weight: &Tensor<Line<E>>,
     sum: &mut Line<E>,
-    in_offs: u32,
+    in_offs: usize,
     in_bounds: bool,
-    weight_offs: u32,
+    weight_offs: usize,
     params: &LoopParams,
-    #[comptime] kernel_dim: u32,
+    #[comptime] kernel_dim: usize,
     #[comptime] has_padding: bool,
 ) {
     if comptime![kernel_dim < params.kernel_shape.len()] {
@@ -142,8 +139,8 @@ fn kernel_loop(
         for pos in 0..*params.kernel_shape.index(kernel_dim) {
             let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding;
 
-            let in_offs = in_offs + in_pos as u32 * stride;
-            let weight_offs = weight_offs + pos * k_stride;
+            let in_offs = in_offs + in_pos as usize * stride;
+            let weight_offs = weight_offs + pos as usize * k_stride;
 
             let mut in_bounds = in_bounds;
             if has_padding {
@@ -181,19 +178,19 @@ fn kernel_loop_inner(
     input: &Tensor<Line<E>>,
     weight: &Tensor<Line<E>>,
     sum: &mut Line<E>,
-    in_offs: u32,
+    in_offs: usize,
     in_bounds: bool,
-    weight_offs: u32,
+    weight_offs: usize,
     in_c_per_group: u32,
-    stride_oc: u32,
+    stride_oc: usize,
 ) {
     let line_size_in = input.line_size();
     let line_size_out = sum.size();
 
     if in_bounds {
-        for in_c in range_stepped(0, in_c_per_group, line_size_in) {
-            let in_pos = in_offs + in_c;
-            let mut weight_pos = weight_offs + in_c;
+        for in_c in range_stepped(0, in_c_per_group, line_size_in as u32) {
+            let in_pos = in_offs + in_c as usize;
+            let mut weight_pos = weight_offs + in_c as usize;
 
             let val = input[in_pos / line_size_in];
 
@@ -225,6 +222,7 @@ pub fn conv_direct(
     bias: Option<CubeTensor<R>>,
     options: ConvOptions<N>,
 ) -> Result<CubeTensor<R>, ConvSetupError> {
+    let client = input.client.clone();
     let out_dtype = input.dtype;
     let rank = input.shape.num_dims();
     let dim_c = rank - 1;
@@ -276,9 +274,9 @@ pub fn conv_direct(
     // Use channels_per_group instead of in_channels to avoid issues here
     let line_size_in = max_line_size(&weight);
 
-    let mut shape_out = shape_divmod(&output);
-    shape_out.values.remove(0);
-    let shape_out_c = shape_out.values.pop().unwrap();
+    let shape_out = output.shape[1..dim_c]
+        .iter()
+        .map(|s| FastDivmodArgs::<R, u32>::new(&client, *s as u32))
+        .collect();
+    let shape_out_c = FastDivmodArgs::<R, u32>::new(&client, out_channels as u32);
 
     let mut conv_params = SequenceArg::new();
 
@@ -292,7 +292,7 @@ pub fn conv_direct(
 
     let bias = bias.as_ref().map(|b| b.as_tensor_arg(line_size_out));
 
-    let working_units = output.shape.num_elements() / line_size_out as usize;
+    let working_units = output.shape.num_elements() / line_size_out;
 
     let cube_dim = CubeDim::new(&input.client, working_units);
     let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim);
diff --git a/crates/burn-cubecl/src/kernel/conv/im2col.rs b/crates/burn-cubecl/src/kernel/conv/im2col.rs
index 028eb36a4e..08be204461 100644
--- a/crates/burn-cubecl/src/kernel/conv/im2col.rs
+++ b/crates/burn-cubecl/src/kernel/conv/im2col.rs
@@ -5,7 +5,7 @@ use burn_backend::{
 use core::iter;
 use cubecl::{
     prelude::*,
-    std::tensor::{TensorHandle, into_contiguous_pitched},
+    std::tensor::{TensorHandle, into_contiguous_pitched_ref},
 };
 use cubek::convolution::components::ConvSetupError;
 
@@ -30,7 +30,7 @@ pub(crate) fn batches_per_run(
     let cube_count_per_batch = out_shape.div_ceil(plane_size);
     let max_cube_count = u16::MAX as usize;
-    let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
+    let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);
     if max_simultaneous == 0 {
         return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static(
             cube_count_per_batch as u32,
@@ -141,7 +141,7 @@ fn reshape_input(mut input: CubeTensor<R>) -> CubeTensor<R> {
 
     if !is_spatial_contiguous(&input.shape, &input.strides) {
         let contiguous =
-            into_contiguous_pitched(&input.client, &input.as_handle_ref(), dtype.into())
+            into_contiguous_pitched_ref(&input.client, &input.as_handle_ref(), dtype.into())
                 .expect("Kernel to never fail");
         input = from_handle(&input.client, &input.device, contiguous, dtype);
     }
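Note: `batches_per_run` caps how many batches one launch may cover by the `u16::MAX` cube-count limit. A plain-Rust sketch of that calculation (not from the patch; the real function returns a setup error where this returns `None`):

fn batches_per_run(batch_size: usize, out_shape: usize, plane_size: usize) -> Option<usize> {
    let cube_count_per_batch = out_shape.div_ceil(plane_size);
    let max_cube_count = u16::MAX as usize;
    // Clamp simultaneous batches to what the cube-count budget allows.
    let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size);
    if max_simultaneous == 0 {
        // Even a single batch would exceed the limit.
        return None;
    }
    Some(max_simultaneous)
}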
diff --git a/crates/burn-cubecl/src/kernel/index/flip.rs b/crates/burn-cubecl/src/kernel/index/flip.rs
index 3a9daa8a87..85648a7a9a 100644
--- a/crates/burn-cubecl/src/kernel/index/flip.rs
+++ b/crates/burn-cubecl/src/kernel/index/flip.rs
@@ -2,7 +2,6 @@ use crate::{
     CubeRuntime, kernel::into_contiguous, ops::numeric::empty_device_dtype, tensor::CubeTensor,
 };
 use burn_backend::DType;
-use cubecl::std::scalar::InputScalar;
 use cubecl::{calculate_cube_count_elemwise, prelude::*};
 
 #[cube(launch_unchecked)]
@@ -10,7 +9,7 @@ fn flip_kernel(
     input: &Tensor<E>,
     output: &mut Tensor<E>,
     indices: Sequence<bool>,
-    #[comptime] rank: u32,
+    #[comptime] rank: usize,
     #[define(E, Bool)] _dtypes: [StorageType; 2],
 ) {
     if ABSOLUTE_POS >= output.len() {
@@ -80,7 +79,7 @@ pub(crate) fn flip_on_output(
         tensor.as_tensor_arg(1),
         output.as_tensor_arg(1),
         indices_sequence,
-        ndims as u32,
+        ndims,
         [dtype_input.into(), dtype_bool.into()],
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/kernel/index/gather.rs b/crates/burn-cubecl/src/kernel/index/gather.rs
index e83307223d..e78f38918e 100644
--- a/crates/burn-cubecl/src/kernel/index/gather.rs
+++ b/crates/burn-cubecl/src/kernel/index/gather.rs
@@ -21,7 +21,7 @@ fn gather_kernel(
     indices: &LinearView<Line<I>>,
     output: &mut Tensor<Line<T>>,
     out_layout: LinearLayout,
-    dim: &u32,
+    dim: usize,
     #[define(T, I)] _dtypes: [StorageType; 2],
 ) {
     if !indices.is_in_bounds(ABSOLUTE_POS) {
@@ -31,17 +31,17 @@ fn gather_kernel(
     let index = indices[ABSOLUTE_POS];
     let out_pos = out_layout.to_source_pos(ABSOLUTE_POS);
 
-    let stride = input.stride(*dim);
-    let mut offset = u32::cast_from(index);
+    let stride = input.stride(dim);
+    let mut offset = usize::cast_from(index);
 
     offset *= stride;
 
-    if *dim > 0 {
-        let offset_before = index_offset_with_layout(input, output, out_pos, 0, *dim, false);
+    if dim > 0 {
+        let offset_before = index_offset_with_layout(input, output, out_pos, 0, dim, false);
         offset += offset_before;
     }
 
     let offset_after =
-        index_offset_with_layout(input, output, out_pos, *dim + 1, input.rank(), false);
+        index_offset_with_layout(input, output, out_pos, dim + 1, input.rank(), false);
     offset += offset_after;
 
     output[out_pos] = input[offset];
 }
@@ -72,7 +72,7 @@ pub(crate) fn gather(
         linear_view(&indices, 1),
         output.as_tensor_arg(1),
         linear_layout(&output, 1),
-        ScalarArg::new(dim as u32),
+        ScalarArg::new(dim),
         [tensor.dtype.into(), indices.dtype.into()],
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs
index 768973c4fc..a75cbd7f54 100644
--- a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs
+++ b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs
@@ -5,7 +5,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*};
 fn repeat_dim_kernel(
     input: &Tensor<E>,
     output: &mut Tensor<E>,
-    dim: u32,
+    dim: usize,
     #[define(E)] _dtype: StorageType,
 ) {
     if ABSOLUTE_POS >= output.len() {
@@ -56,7 +56,7 @@ pub(crate) fn repeat_dim(
         cube_dim,
         input.as_tensor_arg(1),
         output.as_tensor_arg(1),
-        ScalarArg::new(dim as u32),
+        ScalarArg::new(dim),
         output.dtype.into(),
     )
     .expect("Kernel to never fail");
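Note: the gather kernel rebuilds a strided input offset from an output position, substituting the looked-up index along `dim`. A CPU sketch of the same offset arithmetic (not from the patch; illustrative names):

fn gather_offset(in_strides: &[usize], out_coords: &[usize], dim: usize, index: usize) -> usize {
    // The indexed dimension uses the looked-up coordinate.
    let mut offset = index * in_strides[dim];
    for d in 0..in_strides.len() {
        if d != dim {
            // All other dimensions copy the output coordinate through.
            offset += out_coords[d] * in_strides[d];
        }
    }
    offset
}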
diff --git a/crates/burn-cubecl/src/kernel/index/scatter.rs b/crates/burn-cubecl/src/kernel/index/scatter.rs
index 8f6cc62d9c..c8dea8cbe6 100644
--- a/crates/burn-cubecl/src/kernel/index/scatter.rs
+++ b/crates/burn-cubecl/src/kernel/index/scatter.rs
@@ -11,18 +11,18 @@ fn scatter_kernel(
     input: &mut Tensor<T>,
     indices: &Tensor<I>,
     value: &Tensor<T>,
-    dim: &u32,
+    dim: usize,
     #[define(T, I)] _dtypes: [StorageType; 2],
 ) {
-    let stride_input = input.stride(*dim);
-    let shape_value = value.shape(*dim);
+    let stride_input = input.stride(dim);
+    let shape_value = value.shape(dim);
 
     let mut offset_input = 0;
     let mut offset_value = 0;
     let mut num_elems = 1;
 
     for i in 0..value.rank() {
-        let shouldnt_skip = i != *dim;
+        let shouldnt_skip = i != dim;
         if shouldnt_skip {
             let shape_input_loop = input.shape(i);
             let shape_value_loop = value.shape(i);
@@ -54,7 +54,7 @@ fn scatter_kernel(
         idx += offset_value;
 
         let result_value = value[idx];
-        let result_indices = u32::cast_from(indices[idx]);
+        let result_indices = usize::cast_from(indices[idx]);
 
         let mut index_input = stride_input * result_indices;
         index_input += offset_input;
@@ -121,7 +121,7 @@ pub(crate) fn scatter(
         tensor.as_tensor_arg(1),
         indices.as_tensor_arg(1),
         value.as_tensor_arg(1),
-        ScalarArg::new(dim as u32),
+        ScalarArg::new(dim),
         [tensor.dtype.into(), indices.dtype.into()],
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/kernel/index/select.rs b/crates/burn-cubecl/src/kernel/index/select.rs
index 64347c7787..42483ac4d1 100644
--- a/crates/burn-cubecl/src/kernel/index/select.rs
+++ b/crates/burn-cubecl/src/kernel/index/select.rs
@@ -9,7 +9,7 @@ fn select_kernel(
     input: &Tensor<T>,
     indices: &Tensor<I>,
     output: &mut Tensor<T>,
-    dim: u32,
+    dim: usize,
     #[define(T, I)] _dtypes: [StorageType; 2],
 ) {
     if ABSOLUTE_POS >= output.len() {
@@ -22,7 +22,7 @@ fn select_kernel(
         let mut offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i);
 
         if i == dim {
-            offset_local = u32::cast_from(indices[offset_local]);
+            offset_local = usize::cast_from(indices[offset_local]);
         }
 
         offset_input += offset_local * input.stride(i);
@@ -69,7 +69,7 @@ pub(crate) fn select(
             indices.dtype.size(),
         ),
         output.as_tensor_arg(1),
-        ScalarArg::new(dim as u32),
+        ScalarArg::new(dim),
         [tensor.dtype.into(), indices.dtype.into()],
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/kernel/index/select_assign.rs b/crates/burn-cubecl/src/kernel/index/select_assign.rs
index 3ad68fb054..31e756927e 100644
--- a/crates/burn-cubecl/src/kernel/index/select_assign.rs
+++ b/crates/burn-cubecl/src/kernel/index/select_assign.rs
@@ -8,13 +8,12 @@ fn select_assign_kernel(
     tensor: &mut Tensor<F>,
     indices: &Tensor<I>,
     value: &Tensor<F>,
-    dim: &u32,
+    dim: usize,
     #[define(F, I)] _dtypes: [StorageType; 2],
 ) {
-    let dim = *dim;
-    let mut offset_tensor = 0u32;
-    let mut offset_value = 0u32;
-    let mut num_elems = 1u32;
+    let mut offset_tensor = 0;
+    let mut offset_value = 0;
+    let mut num_elems = 1;
 
     // Calculate offsets and num_elems
     for i in 0..tensor.rank() {
@@ -39,7 +38,7 @@ fn select_assign_kernel(
 
     // Main operation
     for i in 0..value.shape(dim) {
-        let index_tensor = u32::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
+        let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor;
         let index_value = i * strides_value_dim + offset_value;
 
         let value = Op::BinaryOp::<F>::execute(
@@ -104,7 +103,7 @@ pub(crate) fn select_assign(
             indices.dtype.size(),
         ),
         value.as_tensor_arg(1),
-        ScalarArg::new(dim as u32),
+        ScalarArg::new(dim),
         [tensor.dtype.into(), indices.dtype.into()],
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/kernel/index/slice.rs b/crates/burn-cubecl/src/kernel/index/slice.rs
index 9041ad1642..a872478713 100644
--- a/crates/burn-cubecl/src/kernel/index/slice.rs
+++ b/crates/burn-cubecl/src/kernel/index/slice.rs
@@ -58,8 +58,8 @@ pub fn slice(tensor: CubeTensor<R>, indices: &[Range<usize>]) ->
 fn slice_kernel(
     input: &Tensor<E>,
     output: &mut LinearView<E, ReadWrite>,
-    out_shape: Sequence<FastDivmod>,
-    indices: Sequence<u32>,
+    out_shape: Sequence<FastDivmod<usize>>,
+    indices: Sequence<usize>,
     #[define(E)] _dtype: StorageType,
 ) {
     if !output.is_in_bounds(ABSOLUTE_POS) {
@@ -73,11 +73,10 @@ fn slice_kernel(
     #[unroll]
     for i in 0..rank {
         // Iterate in reverse to use divmod
-        let i = unwrap(i);
-        let dim = comptime![rank - i - 1];
+        let dim = rank - i - 1;
 
-        let range_start = *indices.index(dim);
-        let (rem, offset_local) = out_shape.index(dim).div_mod(offset_output);
+        let range_start = indices[dim];
+        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
 
         offset_output = rem;
         let offset_local = offset_local + range_start;
@@ -94,11 +93,11 @@ pub(crate) fn slice_on_output(
     indices: &[Range<usize>],
 ) -> CubeTensor<R> {
     let ndims = tensor.shape.num_dims();
-    let mut indices_sequence = SequenceArg::<R, u32>::new();
+    let mut indices_sequence = SequenceArg::<R, usize>::new();
 
     for i in 0..ndims {
         let start = indices.get(i).map(|index| index.start).unwrap_or(0);
-        indices_sequence.push(ScalarArg::new(start as u32));
+        indices_sequence.push(ScalarArg::new(start));
     }
 
     let working_units = output.shape.num_elements();
@@ -127,9 +126,9 @@ pub(crate) fn slice_on_output(
 fn slice_with_steps_kernel(
     input: &Tensor<E>,
     output: &mut LinearView<E, ReadWrite>,
-    out_shape: Sequence<FastDivmod>,
-    starts: Sequence<u32>,
-    ends: Sequence<u32>,
+    out_shape: Sequence<FastDivmod<usize>>,
+    starts: Sequence<usize>,
+    ends: Sequence<usize>,
     steps: Sequence<i32>,
     #[define(E)] _dtype: StorageType,
 ) {
@@ -145,21 +144,20 @@ fn slice_with_steps_kernel(
     #[unroll]
     for i in 0..rank {
         // Iterate in reverse to use divmod
-        let i = unwrap(i);
-        let dim = comptime![rank - i - 1];
-        let start = *starts.index(dim);
-        let end = *ends.index(dim);
-        let step = *steps.index(dim);
+        let dim = rank - i - 1;
+        let start = starts[dim];
+        let end = ends[dim];
+        let step = steps[dim];
 
-        let (rem, output_idx) = out_shape.index(dim).div_mod(output_offset);
+        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
         output_offset = rem;
 
         let input_idx = if step > 0 {
             // Forward stepping
-            start + output_idx * (step as u32)
+            start + output_idx * (step as usize)
         } else {
             // Backward stepping - start from end-1
-            let abs_step = (-step) as u32;
+            let abs_step = (-step) as usize;
             let end_minus_1 = end - 1;
             end_minus_1 - output_idx * abs_step
         };
@@ -197,21 +195,21 @@ pub fn slice_with_steps(tensor: CubeTensor<R>, slices: &[Slice])
     );
 
     // Prepare three separate sequences for kernel
-    let mut starts = SequenceArg::<R, u32>::new();
-    let mut ends = SequenceArg::<R, u32>::new();
+    let mut starts = SequenceArg::<R, usize>::new();
+    let mut ends = SequenceArg::<R, usize>::new();
     let mut steps = SequenceArg::<R, i32>::new();
 
     for (dim, slice) in slices.iter().enumerate() {
         let range = slice.to_range(tensor.shape[dim]);
-        starts.push(ScalarArg::new(range.start as u32));
-        ends.push(ScalarArg::new(range.end as u32));
+        starts.push(ScalarArg::new(range.start));
+        ends.push(ScalarArg::new(range.end));
         steps.push(ScalarArg::new(slice.step as i32));
     }
 
     // Pad with default values if needed to match tensor dimensions
     for dim in slices.len()..tensor.shape.num_dims() {
         starts.push(ScalarArg::new(0));
-        ends.push(ScalarArg::new(tensor.shape[dim] as u32));
+        ends.push(ScalarArg::new(tensor.shape[dim]));
         steps.push(ScalarArg::new(1));
     }
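Note: negative steps are handled by anchoring at `end - 1` and walking backwards, which keeps all arithmetic unsigned. A plain-Rust sketch of the per-dimension mapping (not from the patch):

fn slice_input_idx(start: usize, end: usize, step: i32, output_idx: usize) -> usize {
    if step > 0 {
        // Forward stepping.
        start + output_idx * step as usize
    } else {
        // Backward stepping: anchor at end-1 so the subtraction stays in range.
        let abs_step = (-step) as usize;
        (end - 1) - output_idx * abs_step
    }
}

For example, `slice_input_idx(0, 5, -2, 1)` yields 2: the selected elements are 4, 2, 0.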
diff --git a/crates/burn-cubecl/src/kernel/index/slice_assign.rs b/crates/burn-cubecl/src/kernel/index/slice_assign.rs
index 26b11172c2..527011c41c 100644
--- a/crates/burn-cubecl/src/kernel/index/slice_assign.rs
+++ b/crates/burn-cubecl/src/kernel/index/slice_assign.rs
@@ -13,8 +13,8 @@ use cubecl::{
 fn slice_assign_kernel(
     input: &mut Tensor<Line<E>>,
     value: &LinearView<Line<E>>,
-    slice_shape: Sequence<FastDivmod>,
-    slice_offsets: Sequence<u32>,
+    slice_shape: Sequence<FastDivmod<usize>>,
+    slice_offsets: Sequence<usize>,
     #[define(E)] _dtype: StorageType,
 ) {
     if !value.is_in_bounds(ABSOLUTE_POS) {
@@ -27,21 +27,17 @@ fn slice_assign_kernel(
     let mut offset_remainder = ABSOLUTE_POS * line_size;
     let mut offset_input = 0;
 
-    let mut i = comptime![0];
-
-    #[allow(clippy::explicit_counter_loop)]
     #[unroll]
-    for _ in 0..rank {
-        let dim = comptime![rank - i - 1];
-        let (rem, offset_local) = slice_shape.index(dim).div_mod(offset_remainder);
+    for i in 0..rank {
+        let dim = rank - i - 1;
+        let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);
 
-        let range_start = *slice_offsets.index(dim);
+        let range_start = slice_offsets[dim];
         let offset_local_input = offset_local + range_start;
 
         offset_input += offset_local_input * input.stride(dim);
         offset_remainder = rem;
-
-        comptime![i += 1;]
     }
 
     // Value tensor is accessed linearly since it's a LinearView
@@ -53,9 +49,9 @@ fn slice_assign_kernel(
 fn slice_assign_with_steps_kernel(
     input: &mut Tensor<E>,
     value: &LinearView<E>,
-    value_shape: Sequence<FastDivmod>,
-    starts: Sequence<u32>,
-    ends: Sequence<u32>,
+    value_shape: Sequence<FastDivmod<usize>>,
+    starts: Sequence<usize>,
+    ends: Sequence<usize>,
     steps: Sequence<i32>,
     #[define(E)] _dtype: StorageType,
 ) {
@@ -71,22 +67,21 @@ fn slice_assign_with_steps_kernel(
     #[unroll]
     for i in 0..rank {
         // Iterate in reverse to use divmod
-        let i = unwrap(i);
-        let dim = comptime![rank - i - 1];
-        let start = *starts.index(dim);
-        let end = *ends.index(dim);
-        let step = *steps.index(dim);
+        let dim = rank - i - 1;
+        let start = starts[dim];
+        let end = ends[dim];
+        let step = steps[dim];
 
-        let (rem, value_idx) = value_shape.index(dim).div_mod(value_offset);
+        let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
         value_offset = rem;
 
         let input_idx = if step > 0 {
             // Forward stepping
-            start + value_idx * (step as u32)
+            start + value_idx * (step as usize)
         } else if step < 0 {
             // Backward stepping - start from end-1
             // For negative steps, we iterate backwards through the selected indices
-            let abs_step = (-step) as u32;
+            let abs_step = (-step) as usize;
             let end_minus_1 = end - 1;
             end_minus_1 - value_idx * abs_step
         } else {
@@ -135,7 +130,7 @@ pub(crate) fn slice_assign(
         *R::supported_line_sizes()
             .iter()
             .filter(|it| {
-                let it = **it as usize;
+                let it = **it;
                 shape.is_multiple_of(it)
                     && strides_compatible(&tensor.strides, it)
                     && strides_compatible(&value.strides, it)
@@ -147,8 +142,8 @@ pub(crate) fn slice_assign(
         1
     };
 
-    let mut shape = SequenceArg::<R, FastDivmod>::new();
-    let mut offsets = SequenceArg::<R, u32>::new();
+    let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
+    let mut offsets = SequenceArg::<R, usize>::new();
 
     for i in 0..ndims {
         let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
@@ -160,11 +155,11 @@ pub(crate) fn slice_assign(
         let end = slice.end.unwrap_or(tensor.shape[i] as isize);
         let length = (end - slice.start) as usize;
 
-        shape.push(FastDivmodArgs::new(&client, length as u32));
-        offsets.push(ScalarArg::new(start as u32));
+        shape.push(FastDivmodArgs::<R, usize>::new(&client, length));
+        offsets.push(ScalarArg::new(start));
     }
 
-    let working_units = value.shape.num_elements() / line_size as usize;
+    let working_units = value.shape.num_elements() / line_size;
 
     let cube_dim = CubeDim::new(&tensor.client, working_units);
     let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
@@ -205,21 +200,21 @@ pub(crate) fn slice_assign_with_steps(
     };
 
     // Prepare sequences for kernel
-    let mut starts = SequenceArg::<R, u32>::new();
-    let mut ends = SequenceArg::<R, u32>::new();
+    let mut starts = SequenceArg::<R, usize>::new();
+    let mut ends = SequenceArg::<R, usize>::new();
     let mut steps = SequenceArg::<R, i32>::new();
 
     for (dim, slice) in slices.iter().enumerate() {
         let range = slice.to_range(tensor.shape[dim]);
-        starts.push(ScalarArg::new(range.start as u32));
-        ends.push(ScalarArg::new(range.end as u32));
+        starts.push(ScalarArg::new(range.start));
+        ends.push(ScalarArg::new(range.end));
         steps.push(ScalarArg::new(slice.step as i32));
     }
 
     // Pad with default values if needed to match tensor dimensions
     for dim in slices.len()..tensor.shape.num_dims() {
         starts.push(ScalarArg::new(0));
-        ends.push(ScalarArg::new(tensor.shape[dim] as u32));
+        ends.push(ScalarArg::new(tensor.shape[dim]));
         steps.push(ScalarArg::new(1));
     }
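Note: all of these kernels peel coordinates off a linear offset by repeated div/mod from the innermost dimension outwards (`FastDivmod` precomputes the divisor constants). The equivalent plain-Rust decomposition (not from the patch):

fn decompose(mut offset: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0; shape.len()];
    for dim in (0..shape.len()).rev() {
        // coord = offset mod shape[dim]; carry the quotient outwards.
        coords[dim] = offset % shape[dim];
        offset /= shape[dim];
    }
    coords
}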
diff --git a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
index 98e5624698..84e71d8f9c 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_bicubic_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -26,9 +26,9 @@ fn interpolate_bicubic_kernel(
     let line_size = input.line_size();
     let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
 
-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);
 
     let input_height = input.shape(1) - 1;
     let output_height = f32::cast_from(Max::max(output.shape(1) - 1, 1));
@@ -36,7 +36,7 @@ fn interpolate_bicubic_kernel(
     let frac = f32::cast_from(numerator / output_height);
     let y_in_f = Floor::floor(frac);
-    let y_in = u32::cast_from(y_in_f);
+    let y_in = usize::cast_from(y_in_f);
     let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f));
 
     let y0 = select(y_in != 0, y_in - 1, 0);
@@ -49,7 +49,7 @@ fn interpolate_bicubic_kernel(
     let numerator = f32::cast_from(x * input_width);
     let frac = numerator / output_width;
     let x_in_f = Floor::floor(frac);
-    let x_in = u32::cast_from(x_in_f);
+    let x_in = usize::cast_from(x_in_f);
     let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f));
 
     let x0 = select(x_in != 0, x_in - 1, 0);
diff --git a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
index 50cf80ac29..216817a5dc 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_bilinear_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -26,9 +26,9 @@ fn interpolate_bilinear_kernel(
     let line_size = input.line_size();
     let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);
 
-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);
 
     let numerator = (input.shape(1) - 1) as f32;
     let denominator = Max::max(output.shape(1) - 1, 1) as f32;
@@ -42,8 +42,8 @@ fn interpolate_bilinear_kernel(
     let yw_ = Line::empty(line_size).fill(F::new(1.0) - yw);
     let yw = Line::empty(line_size).fill(yw);
     let y0_ok = v0 >= 0.0;
-    let y0 = v0 as u32;
-    let y1 = v1 as u32;
+    let y0 = v0 as usize;
+    let y1 = v1 as usize;
 
     let numerator = f32::cast_from(input.shape(2) - 1);
     let denominator = f32::cast_from(Max::max(output.shape(2) - 1, 1));
@@ -55,8 +55,8 @@ fn interpolate_bilinear_kernel(
     let xw_ = Line::empty(line_size).fill(F::new(1.0) - xw);
     let xw = Line::empty(line_size).fill(xw);
     let x0_ok = v0 >= 0.0;
-    let x0 = v0 as u32;
-    let x1 = v1 as u32;
+    let x0 = v0 as usize;
+    let x1 = v1 as usize;
 
     let index_base = b * input.stride(0) + c * input.stride(3);
diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
index 8ebc2a3f6a..a60f61e743 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_nearest_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -31,16 +31,16 @@ fn interpolate_nearest_kernel(
     let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32);
     let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32);
 
-    let (rem, c) = shape_out.index(3).div_mod(out_pos);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(out_pos);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);
 
     let y = y as f32 * (h_in / h_out);
     let x = x as f32 * (w_in / w_out);
 
     let in_idx = b * input.stride(0)
-        + y as u32 * input.stride(1)
-        + x as u32 * input.stride(2)
+        + y as usize * input.stride(1)
+        + x as usize * input.stride(2)
         + c * input.stride(3);
 
     output[out_idx] = input[in_idx / line_size];
diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
index 82ab2c6072..b0dab1d6aa 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_nearest_backward_kernel(
     grad: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -31,9 +31,9 @@ fn interpolate_nearest_backward_kernel(
     let grad_h = grad.shape(1);
     let grad_w = grad.shape(2);
 
-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, out_x) = shape_out.index(2).div_mod(rem);
-    let (b, out_y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, out_x) = shape_out[2].div_mod(rem);
+    let (b, out_y) = shape_out[1].div_mod(rem);
 
     let grad_y_start = start_index::<F>(out_y, grad_h, out_h);
     let grad_y_end = end_index::<F>(out_y, grad_h, out_h);
@@ -56,18 +56,18 @@ fn interpolate_nearest_backward_kernel(
 }
 
 #[cube]
-fn start_index(input_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
     let numerator = F::cast_from(input_index * output_size);
     let div: F = Ceil::ceil(numerator / F::cast_from(input_size));
 
-    u32::cast_from(div)
+    usize::cast_from(div)
 }
 
 #[cube]
-fn end_index(input_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
     let numerator = F::cast_from((input_index + 1) * output_size);
     let div: F = Ceil::ceil(numerator / F::cast_from(input_size));
-    let index = u32::cast_from(div);
+    let index = usize::cast_from(div);
 
     Min::min(output_size, index)
 }
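Note: `start_index`/`end_index` invert the nearest-neighbour mapping — they bound the range of grad positions that round back into a given output cell. The kernels route the ceil through the float type `F`; an exact integer version of the same bounds (not from the patch):

fn start_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
    // ceil(input_index * output_size / input_size)
    (input_index * output_size).div_ceil(input_size)
}

fn end_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
    let index = ((input_index + 1) * output_size).div_ceil(input_size);
    index.min(output_size)
}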
diff --git a/crates/burn-cubecl/src/kernel/mask/base.rs b/crates/burn-cubecl/src/kernel/mask/base.rs
index df08bed25f..fbedacd78b 100644
--- a/crates/burn-cubecl/src/kernel/mask/base.rs
+++ b/crates/burn-cubecl/src/kernel/mask/base.rs
@@ -1,5 +1,5 @@
 use burn_backend::DType;
-use cubecl::std::scalar::InputScalar;
+use cubecl::prelude::InputScalar;
 
 use super::{MaskFillStrategy, mask_where::MaskWhereStrategy};
 use crate::{CubeRuntime, tensor::CubeTensor};
diff --git a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
index 5c1c61e372..ba5e80c00f 100644
--- a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
+++ b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
@@ -1,9 +1,5 @@
 use burn_backend::DType;
-use cubecl::{
-    calculate_cube_count_elemwise,
-    prelude::*,
-    std::{scalar::InputScalar, tensor::layout::linear::LinearView},
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};
 
 use crate::{
     CubeRuntime,
diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
index a3fa067cc3..1a7f9a6f3b 100644
--- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
@@ -63,12 +63,12 @@ fn adaptive_avg_pool2d_direct(
 }
 
 #[cube]
-fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     (output_size_index * input_size) / output_size
 }
 
 #[cube]
-fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     let index = (output_size_index + 1) * input_size;
     let index = index.div_ceil(output_size);
diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
index bbf9fa4e5a..4c991f142d 100644
--- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
@@ -68,12 +68,12 @@ fn adaptive_avg_pool2d_backward_direct(
 }
 
 #[cube]
-fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     (output_size_index * input_size) / output_size
 }
 
 #[cube]
-fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     let index = (output_size_index + 1) * input_size;
     let index = index.div_ceil(output_size);
diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
index 546cf10328..4f9a0c19b6 100644
--- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
@@ -36,7 +36,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {
 
     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         let sum = Line::empty(line_size).fill(N::from_int(0));
         // Count will be set dynamically: either by accumulate (count_include_pad=false)
@@ -49,7 +49,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {
     fn accumulate(
         #[comptime] config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        _index: u32,
+        _index: usize,
         result: Line<N>,
     ) {
         let (sum, count) = accumulator;
@@ -78,7 +78,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {
 
     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         _output_indices: &mut (),
         accumulator: Self::Accumulator,
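Note: adaptive pooling derives each output cell's input window from the size ratio alone. The pair above in plain Rust (not from the patch; the trailing clamp is added for safety):

fn start_index(out_idx: usize, output_size: usize, input_size: usize) -> usize {
    (out_idx * input_size) / output_size
}

fn end_index(out_idx: usize, output_size: usize, input_size: usize) -> usize {
    ((out_idx + 1) * input_size).div_ceil(output_size).min(input_size)
}

For input_size = 10 and output_size = 4 the windows are [0,3), [2,5), [5,8), [7,10): they may overlap rather than partition, which is why the average divides by each window's own length.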
diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
index 5f7a429ba5..7494ba0ebf 100644
--- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
@@ -35,9 +35,9 @@ fn avg_pool2d_backward_kernel(
     let channel_lines = output.shape(3) / line_size;
     let channel = (ABSOLUTE_POS % channel_lines) * output.line_size();
     let pos = ABSOLUTE_POS / channel_lines;
-    let iw = pos % output.shape(2);
+    let iw = pos as u32 % output.shape(2) as u32;
     let pos = pos / output.shape(2);
-    let ih = pos % output.shape(1);
+    let ih = pos as u32 % output.shape(1) as u32;
     let batch = pos / output.shape(1);
 
     let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
@@ -45,8 +45,8 @@ fn avg_pool2d_backward_kernel(
     let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
         ih as i32,
         iw as i32,
-        grad.shape(1),
-        grad.shape(2),
+        grad.shape(1) as u32,
+        grad.shape(2) as u32,
         args,
         kernel_size_0,
         kernel_size_1,
@@ -60,8 +60,8 @@ fn avg_pool2d_backward_kernel(
     let kernel_size_1 = comptime![kernel_size_1 as u32];
 
     let index_base = batch * grad.stride(0) + channel * grad.stride(3);
-    let border_bottom = output.shape(1) + padding_0;
-    let border_right = output.shape(2) + padding_1;
+    let border_bottom = output.shape(1) as u32 + padding_0;
+    let border_right = output.shape(2) as u32 + padding_1;
     let begin_h = ih + padding_0;
     let begin_w = iw + padding_1;
@@ -72,7 +72,8 @@ fn avg_pool2d_backward_kernel(
         if begin_h >= ih_start && ih < ih_end {
             for ow in ow_start..ow_end {
-                let index = index_base + oh * grad.stride(1) + ow * grad.stride(2);
+                let index =
+                    index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
 
                 let iw_start = ow * stride_1;
                 let iw_end = Min::min(iw_start + kernel_size_1, border_right);
diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
index 3ceaf66869..6f54b906d3 100644
--- a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
@@ -33,7 +33,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {
 
     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         Line::empty(line_size).fill(N::min_value())
     }
@@ -41,7 +41,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {
     fn accumulate(
         #[comptime] _config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        _index: u32,
+        _index: usize,
         result: Line<N>,
     ) {
         *accumulator = Max::max(*accumulator, result);
@@ -57,7 +57,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {
 
     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         _output_indices: &mut (),
         accumulator: Self::Accumulator,
@@ -74,7 +74,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {
 
     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         let val = Line::empty(line_size).fill(N::min_value());
         let idx = Line::empty(line_size).fill(0i32);
@@ -84,7 +84,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {
     fn accumulate(
         #[comptime] _config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        index: u32,
+        index: usize,
         result: Line<N>,
     ) {
         let indices = Line::cast_from(index);
@@ -102,7 +102,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {
 
     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         output_indices: &mut Tensor<Line<I>>,
         accumulator: Self::Accumulator,
diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
index a639894bb8..a6f272d59d 100644
--- a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
@@ -38,8 +38,8 @@ fn max_pool2d_with_indices_backward_kernel(
     let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
         ih as i32,
         iw as i32,
-        grad.shape(1),
-        grad.shape(2),
+        grad.shape(1) as u32,
+        grad.shape(2) as u32,
         args,
         kernel_size_0,
         kernel_size_1,
@@ -51,7 +51,7 @@ fn max_pool2d_with_indices_backward_kernel(
 
     for oh in oh_start..oh_end {
         for ow in ow_start..ow_end {
-            let index = index_base + oh * grad.stride(1) + ow * grad.stride(2);
+            let index = index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);
 
             let index_max = Line::<u32>::cast_from(indices[index / line_size]);
             grad_acc += select_many(
diff --git a/crates/burn-cubecl/src/kernel/pool/pool2d.rs b/crates/burn-cubecl/src/kernel/pool/pool2d.rs
index f342b15ba1..a25b3f7cd1 100644
--- a/crates/burn-cubecl/src/kernel/pool/pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/pool2d.rs
@@ -16,13 +16,13 @@ pub(crate) trait Pool2dDirectStrategy: Send + Sync + 'static {
 
     fn initialize(
         #[comptime] config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator;
 
     fn accumulate(
         #[comptime] config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        index: u32,
+        index: usize,
         result: Line<N>,
     );
 
@@ -38,7 +38,7 @@ pub(crate) trait Pool2dDirectStrategy: Send + Sync + 'static {
 
     fn store(
         #[comptime] config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         output_indices: &mut Self::Indices,
         accumulator: Self::Accumulator,
@@ -77,13 +77,13 @@ pub fn pool2d_direct(
         input.stride(2),
         input.stride(3),
     );
-    let (in_h, in_w) = (input.shape(1), input.shape(2));
+    let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);
 
     let c = (ABSOLUTE_POS % channel_lines) * input.line_size();
     let pos = ABSOLUTE_POS / channel_lines;
-    let ow = pos % out_w;
+    let ow = pos as u32 % out_w as u32;
     let pos = pos / out_w;
-    let oh = pos % out_h;
+    let oh = pos as u32 % out_h as u32;
     let b = pos / out_h;
 
     let mut accumulator = S::Pool2d::<N>::initialize(config, input.line_size());
@@ -110,15 +110,15 @@ pub fn pool2d_direct(
             let ih_pad = ih - args.padding_0;
             let iw_pad = iw - args.padding_1;
 
-            let in_h_off = ih_pad * in_stride_h;
-            let in_w_off = iw_pad * in_stride_w;
+            let in_h_off = ih_pad as usize * in_stride_h;
+            let in_w_off = iw_pad as usize * in_stride_w;
 
             let index_input = in_b_off + in_c_off + in_h_off + in_w_off;
 
             S::Pool2d::<N>::accumulate(
                 config,
                 &mut accumulator,
-                ih_pad * in_w + iw_pad,
+                ih_pad as usize * in_w as usize + iw_pad as usize,
                 input[index_input / input.line_size()],
            );
        }
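Note: `pool2d_direct` keeps window coordinates in `u32` (cheap on GPU) and widens to `usize` only where they multiply a stride. A CPU analogue of that discipline (not from the patch; names and the stride layout are illustrative):

fn input_offset(
    pos: usize,
    channel_lines: usize,
    out_w: u32,
    out_h: u32,
    strides: [usize; 4], // (batch, h, w, channel) strides
) -> usize {
    // Decompose the flat position into (b, oh, ow, channel-line).
    let c = pos % channel_lines;
    let pos = pos / channel_lines;
    let ow = pos as u32 % out_w;
    let pos = pos / out_w as usize;
    let oh = pos as u32 % out_h;
    let b = pos / out_h as usize;
    // Widen the narrow coordinates only at the stride multiplication.
    b * strides[0] + oh as usize * strides[1] + ow as usize * strides[2] + c * strides[3]
}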
diff --git a/crates/burn-cubecl/src/kernel/utils.rs b/crates/burn-cubecl/src/kernel/utils.rs
index 20a5770bf7..67fa3c85f1 100644
--- a/crates/burn-cubecl/src/kernel/utils.rs
+++ b/crates/burn-cubecl/src/kernel/utils.rs
@@ -1,5 +1,6 @@
 use burn_backend::Shape;
 use cubecl::{
+    ir::LineSize,
     prelude::ArrayArg,
     std::{
         FastDivmod, FastDivmodArgs,
@@ -10,17 +11,19 @@ use cubecl::{prelude::SequenceArg, std::tensor::layout::linear::LinearLayout};
 
 use crate::{CubeRuntime, tensor::CubeTensor};
 
-pub fn shape_divmod<'a, R: CubeRuntime>(tensor: &CubeTensor<R>) -> SequenceArg<'a, R, FastDivmod> {
+pub fn shape_divmod<'a, R: CubeRuntime>(
+    tensor: &CubeTensor<R>,
+) -> SequenceArg<'a, R, FastDivmod<usize>> {
     let mut arg = SequenceArg::new();
     for dim in tensor.shape.iter() {
-        arg.push(FastDivmodArgs::new(&tensor.client, *dim as u32));
+        arg.push(FastDivmodArgs::<R, usize>::new(&tensor.client, *dim));
     }
     arg
 }
 
 pub fn linear_layout<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearLayoutArgs<'a, R> {
     LinearLayoutArgs::from_shape_strides(&tensor.client, &tensor.shape, &tensor.strides, line_size)
 }
@@ -28,7 +31,7 @@ pub fn linear_layout<'a, R: CubeRuntime>(
 pub fn linear_layout_ref<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
     reference: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearLayoutArgs<'a, R> {
     LinearLayoutArgs::from_shape_strides_with_reference(
         &tensor.client,
@@ -41,7 +44,7 @@ pub fn linear_layout_ref<'a, R: CubeRuntime>(
 
 pub fn linear_view<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearViewLaunch<'a, R> {
     let len = tensor.shape.iter().product::<usize>();
     let layout = linear_layout(tensor, line_size);
@@ -54,7 +57,7 @@ pub fn linear_view<'a, R: CubeRuntime>(
 pub fn linear_view_ref<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
     reference: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearViewLaunch<'a, R> {
     let len = tensor.shape.iter().product::<usize>();
     let layout = linear_layout_ref(tensor, reference, line_size);
@@ -66,7 +69,7 @@ pub fn linear_view_ref<'a, R: CubeRuntime>(
 
 pub fn linear_view_alias<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
     pos: usize,
 ) -> LinearViewLaunch<'a, R> {
     let layout = linear_layout(tensor, line_size);
diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs
index e5da9bdabb..ebddccba86 100644
--- a/crates/burn-cubecl/src/ops/base.rs
+++ b/crates/burn-cubecl/src/ops/base.rs
@@ -5,7 +5,7 @@ use burn_backend::{
 };
 use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
 use burn_std::tensor::{ReshapeAction, contiguous_strides, reshape_action};
-use cubecl::server::CopyDescriptor;
+use cubecl::{ir::LineSize, server::CopyDescriptor};
 use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};
 
 pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
@@ -333,7 +333,7 @@ pub fn q_reshape(mut tensor: CubeTensor<R>, shape: Shape) -> CubeTensor<R> {
     tensor
 }
 
-pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> u8 {
+pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
     tensor_line_size_parallel(
         tensor
             .client
@@ -344,7 +344,10 @@ pub(crate) fn max_line_size<R: CubeRuntime>(
     )
 }
 
-pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], axis: usize) -> u8 {
+pub(crate) fn max_line_size_many<R: CubeRuntime>(
+    tensors: &[&CubeTensor<R>],
+    axis: usize,
+) -> LineSize {
     let vec = tensors
         .iter()
         .map(|tensor| {
diff --git a/crates/burn-cubecl/src/ops/bool_ops.rs b/crates/burn-cubecl/src/ops/bool_ops.rs
index 8132d9f429..22d39a2b0d 100644
--- a/crates/burn-cubecl/src/ops/bool_ops.rs
+++ b/crates/burn-cubecl/src/ops/bool_ops.rs
@@ -9,7 +9,7 @@ use burn_backend::{
     tensor::{BoolTensor, Device, FloatTensor, IntTensor},
 };
 use burn_backend::{Shape, TensorData, tensor::BoolElem};
-use cubecl::std::scalar::InputScalar;
+use cubecl::prelude::InputScalar;
 use std::ops::Range;
 
 use super::{expand, numeric, permute, unfold};
diff --git a/crates/burn-cubecl/src/ops/float_ops.rs b/crates/burn-cubecl/src/ops/float_ops.rs
index 6b8af09a2f..7de1530860 100644
--- a/crates/burn-cubecl/src/ops/float_ops.rs
+++ b/crates/burn-cubecl/src/ops/float_ops.rs
@@ -15,7 +15,6 @@ use burn_backend::{Backend, ExecutionError};
 use burn_backend::{DType, ElementConversion, FloatDType, Slice};
 use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
 use cubecl::prelude::*;
-use cubecl::std::scalar::InputScalar;
 use cubek::reduce::components::instructions::ReduceOperationConfig;
 use std::ops::Range;
diff --git a/crates/burn-cubecl/src/ops/int_ops.rs b/crates/burn-cubecl/src/ops/int_ops.rs
index 5ad13d86f7..d73790d754 100644
--- a/crates/burn-cubecl/src/ops/int_ops.rs
+++ b/crates/burn-cubecl/src/ops/int_ops.rs
@@ -20,8 +20,8 @@ use burn_backend::ExecutionError;
 use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
 use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
 use burn_backend::{Distribution, ElementConversion, Shape, TensorData};
+use cubecl::frontend::Numeric;
 use cubecl::prelude::*;
-use cubecl::{frontend::Numeric, std::scalar::InputScalar};
 use cubek::reduce::components::instructions::ReduceOperationConfig;
 use std::ops::Range;
diff --git a/crates/burn-cubecl/src/ops/numeric.rs b/crates/burn-cubecl/src/ops/numeric.rs
index 69889ae8cc..da4acc7641 100644
--- a/crates/burn-cubecl/src/ops/numeric.rs
+++ b/crates/burn-cubecl/src/ops/numeric.rs
@@ -11,7 +11,7 @@ use crate::{
     ops::max_line_size,
 };
 use burn_backend::{DType, Shape};
-use cubecl::std::{FastDivmod, scalar::InputScalar, tensor::layout::linear::LinearView};
+use cubecl::std::{FastDivmod, tensor::layout::linear::LinearView};
 use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{client::ComputeClient, server::Allocation};
 
@@ -343,8 +343,8 @@ impl CumulativeOp for MinOp {
 fn cumulative_kernel(
     input: &Tensor<C>,
     output: &mut LinearView<C, ReadWrite>,
-    shape: Sequence<FastDivmod>,
-    #[comptime] dim: u32,
+    shape: Sequence<FastDivmod<usize>>,
+    #[comptime] dim: usize,
     #[define(C)] _dtype: StorageType,
 ) {
     if !output.is_in_bounds(ABSOLUTE_POS) {
@@ -428,7 +428,7 @@ fn cumulative_op(
         input.as_tensor_arg(1),
         linear_view(&output, 1),
         shape_divmod(&input),
-        dim as u32,
+        dim,
         output.dtype.into(),
     )
     .expect("Kernel to never fail");
diff --git a/crates/burn-cubecl/src/template/base.rs b/crates/burn-cubecl/src/template/base.rs
index 68e61c9f8e..bedb44c643 100644
--- a/crates/burn-cubecl/src/template/base.rs
+++ b/crates/burn-cubecl/src/template/base.rs
@@ -23,6 +23,7 @@ impl CubeTask for SourceKernel {
         _compiler: &mut C,
         _options: &C::CompilationOptions,
         _mode: ExecutionMode,
+        _address_type: StorageType,
     ) -> Result<CompiledKernel<C>, CompilationError> {
         let source_template = self.kernel_source.source();
         let source = source_template.complete();
@@ -42,6 +43,10 @@ impl KernelMetadata for SourceKernel {
     fn id(&self) -> KernelId {
         self.kernel_source.id()
     }
+
+    fn address_type(&self) -> StorageType {
+        u32::as_type_native_unchecked()
+    }
 }
 
 /// Generates kernel source code by replacing some information using templating.
diff --git a/crates/burn-cubecl/src/tensor/base.rs b/crates/burn-cubecl/src/tensor/base.rs
index 8562e2eff3..dec4d239d8 100644
--- a/crates/burn-cubecl/src/tensor/base.rs
+++ b/crates/burn-cubecl/src/tensor/base.rs
@@ -206,7 +206,7 @@ where
     }
 
     /// Return the reference to a tensor argument.
-    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
+    pub fn as_tensor_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
         let size = self.dtype.size();
 
         let handle: TensorHandleRef<'a, R> = self.as_handle_ref();
@@ -222,12 +222,12 @@ where
     }
 
     /// Return the reference to an array argument.
-    pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> {
+    pub fn as_array_arg(&self, line_size: LineSize) -> ArrayArg<'_, R> {
         unsafe {
             ArrayArg::from_raw_parts::<E>(
                 &self.handle,
                 self.handle.size() as usize / core::mem::size_of::<E>(),
-                vectorisation,
+                line_size,
             )
         }
    }
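Note: the connected-components file that follows resolves labels with a lock-free union-find over an atomic label tensor. A sequential sketch of its find-and-link loop (not from the patch; 1-based labels as in the kernel, where 0 marks an unassigned pixel):

fn find(labels: &[u32], mut label: u32) -> u32 {
    // labels[i] holds the parent label + 1; a root points at itself.
    while label != labels[label as usize] - 1 {
        label = labels[label as usize] - 1;
    }
    label
}

fn merge(labels: &mut [u32], label_1: u32, label_2: u32) {
    let (a, b) = (find(labels, label_1), find(labels, label_2));
    let (root, child) = (a.min(b), a.max(b));
    // The kernel does this with Atomic::min so concurrent merges converge.
    labels[child as usize] = root + 1;
}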
                Atomic::store(
-                    &labels[labels_index],
+                    &labels[labels_index as usize],
                     I::cast_from(labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1),
                 );
             }
@@ -117,7 +124,7 @@
             sync_cube();
 
             if UNIT_POS_X == 0 {
-                shared_pixels[UNIT_POS_Y] = pixels_y;
+                shared_pixels[UNIT_POS_Y as usize] = pixels_y;
             }
 
             sync_cube();
@@ -125,7 +132,7 @@
             // Requires if and not select, because `select` may execute the then branch even if the
             // condition is false (on non-CUDA backends), which can lead to OOB reads.
             let pixels_y_1 = if UNIT_POS_Y > 0 {
-                shared_pixels[UNIT_POS_Y - 1]
+                shared_pixels[(UNIT_POS_Y - 1) as usize]
             } else {
                 0u32.runtime()
             };
@@ -198,14 +205,14 @@ fn strip_merge(
     #[comptime] connectivity: Connectivity,
 ) {
     let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;
-    let y = (CUBE_POS_Y + 1) * BLOCK_H;
+    let y = (CUBE_POS_Y + 1) * BLOCK_H as u32;
     let x = plane_start_x + UNIT_POS_X;
 
-    let img_step = img.stride(0);
-    let labels_step = labels.stride(0);
-    let cols = img.shape(1);
+    let img_step = img.stride(0) as u32;
+    let labels_step = labels.stride(0) as u32;
+    let cols = img.shape(1) as u32;
 
-    if y < labels.shape(0) && x < labels.shape(1) {
+    if y < labels.shape(0) as u32 && x < labels.shape(1) as u32 {
         let mut mask = 0xffffffffu32;
         if cols - plane_start_x < 32 {
             mask >>= 32 - (cols - plane_start_x);
@@ -217,8 +224,8 @@ fn strip_merge(
         let img_index_up = img_index - img_step;
         let labels_index_up = labels_index - labels_step;
 
-        let p = bool::cast_from(img[img_index]);
-        let p_up = bool::cast_from(img[img_index_up]);
+        let p = bool::cast_from(img[img_index as usize]);
+        let p_up = bool::cast_from(img[img_index_up as usize]);
 
         let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;
         let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;
@@ -234,27 +241,27 @@
             }
         }
         Connectivity::Eight => {
-            let mut last_dist_vec = SharedMemory::<u32>::new(32);
-            let mut last_dist_up_vec = SharedMemory::<u32>::new(32);
+            let mut last_dist_vec = SharedMemory::<u32>::new(32usize);
+            let mut last_dist_up_vec = SharedMemory::<u32>::new(32usize);
 
             let s_dist = start_distance(pixels, UNIT_POS_X);
             let s_dist_up = start_distance(pixels_up, UNIT_POS_X);
 
             if UNIT_POS_PLANE == PLANE_DIM - 1 {
-                last_dist_vec[UNIT_POS_Z] = start_distance(pixels, 32);
-                last_dist_up_vec[UNIT_POS_Z] = start_distance(pixels_up, 32);
+                last_dist_vec[UNIT_POS_Z as usize] = start_distance(pixels, 32);
+                last_dist_up_vec[UNIT_POS_Z as usize] = start_distance(pixels_up, 32);
             }
 
             sync_cube();
 
             if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {
                 let last_dist = if UNIT_POS_Z > 0 {
-                    last_dist_vec[UNIT_POS_Z - 1]
+                    last_dist_vec[(UNIT_POS_Z - 1) as usize]
                 } else {
                     0u32.runtime()
                 };
                 let last_dist_up = if UNIT_POS_Z > 0 {
-                    last_dist_up_vec[UNIT_POS_Z - 1]
+                    last_dist_up_vec[(UNIT_POS_Z - 1) as usize]
                 } else {
                     0u32.runtime()
                 };
@@ -300,10 +307,10 @@ fn relabeling(img: &Tensor<u8>, labels: &mut Tensor<Atomic<I>>) {
     let y = ABSOLUTE_POS_Y;
     let x = ABSOLUTE_POS_X;
 
-    let cols = labels.shape(1);
-    let rows = labels.shape(0);
-    let img_step = img.stride(0);
-    let labels_step = labels.stride(0);
+    let cols = labels.shape(1) as u32;
+    let rows = labels.shape(0) as u32;
+    let img_step = img.stride(0) as u32;
+    let labels_step = labels.stride(0) as u32;
 
     if x < cols && y < rows {
         let mut mask = 0xffffffffu32;
@@ -363,7 +370,7 @@ fn analysis(
     let img_index = y * img_step + x;
     let labels_index = y * labels_step + x;
-    let p = bool::cast_from(img[img_index]);
+    let p = bool::cast_from(img[img_index as usize]);
     let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
     let s_dist = start_distance(pixels, UNIT_POS_X);
     let count = end_distance(pixels, UNIT_POS_X);
@@ -372,29 +379,29 @@
         let mut label = 0u32;
 
         if p && s_dist == 0 {
-            label = u32::cast_from(labels[labels_index]) - 1;
-            while label != u32::cast_from(labels[label]) - 1 {
-                label = u32::cast_from(labels[label]) - 1;
+            label = u32::cast_from(labels[labels_index as usize]) - 1;
+            while label != u32::cast_from(labels[label as usize]) - 1 {
+                label = u32::cast_from(labels[label as usize]) - 1;
             }
             label += 1;
 
-            Atomic::add(&area[label], I::cast_from(count));
+            Atomic::add(&area[label as usize], I::cast_from(count));
             if opts.bounds_enabled {
-                Atomic::min(&left[label], I::cast_from(x));
-                Atomic::min(&top[label], I::cast_from(y));
-                Atomic::max(&right[label], I::cast_from(max_x));
-                Atomic::max(&bottom[label], I::cast_from(y));
+                Atomic::min(&left[label as usize], I::cast_from(x));
+                Atomic::min(&top[label as usize], I::cast_from(y));
+                Atomic::max(&right[label as usize], I::cast_from(max_x));
+                Atomic::max(&bottom[label as usize], I::cast_from(y));
             }
             if comptime!(opts.max_label_enabled || opts.compact_labels) {
                 Atomic::max(&max_label[0], I::cast_from(label));
             }
         }
 
-        label = plane_broadcast(label, UNIT_POS_X - s_dist);
+        label = plane_shuffle(label, UNIT_POS_X - s_dist);
 
         if p {
-            labels[labels_index] = I::cast_from(label);
+            labels[labels_index as usize] = I::cast_from(label);
         }
     }
 }
@@ -408,16 +415,16 @@ fn compact_labels(
     let x = ABSOLUTE_POS_X;
     let y = ABSOLUTE_POS_Y;
 
-    let labels_pos = y * labels.stride(0) + x;
+    let labels_pos = y * labels.stride(0) as u32 + x;
 
-    if labels_pos >= labels.len() {
+    if labels_pos as usize >= labels.len() {
         terminate!();
     }
 
-    let label = u32::cast_from(labels[labels_pos]);
+    let label = u32::cast_from(labels[labels_pos as usize]);
     if label != 0 {
-        let new_label = remap[label];
-        labels[labels_pos] = new_label;
+        let new_label = remap[label as usize];
+        labels[labels_pos as usize] = new_label;
         Atomic::max(&max_label[0], new_label);
     }
 }
@@ -437,23 +444,23 @@ fn compact_stats(
     remap: &Tensor<I>,
 ) {
     let label = ABSOLUTE_POS_X;
-    if label >= remap.len() {
+    if label as usize >= remap.len() {
         terminate!();
     }
 
-    let area = area[label];
+    let area = area[label as usize];
     if area == I::new(0) {
         terminate!();
     }
 
-    let new_label = u32::cast_from(remap[label]);
+    let new_label = u32::cast_from(remap[label as usize]);
 
-    area_new[new_label] = area;
+    area_new[new_label as usize] = area;
 
     // This should be gated but there's a problem with the Eq bound only being implemented for tuples
     // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.
-    top_new[new_label] = top[label];
-    left_new[new_label] = left[label];
-    right_new[new_label] = right[label];
-    bottom_new[new_label] = bottom[label];
+    top_new[new_label as usize] = top[label as usize];
+    left_new[new_label as usize] = left[label as usize];
+    right_new[new_label as usize] = right[label as usize];
+    bottom_new[new_label as usize] = bottom[label as usize];
 }
 
 #[allow(clippy::type_complexity)]
@@ -502,7 +509,7 @@ pub fn hardware_accelerated(
diff --git a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
--- a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
+++ b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
@@ -22,11 +22,11 @@ fn prefix_sum_kernel(
     scan_out: &mut Tensor<Line<I>>,
     scan_bump: &Tensor<Atomic<I>>,
     reduction: &Tensor<Atomic<I>>,
-    cube_count_x: u32,
+    cube_count_x: usize,
 ) {
-    let mut broadcast = SharedMemory::<I>::new(1);
+    let mut broadcast = SharedMemory::<I>::new(1usize);
     let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);
-    let batch = CUBE_POS_Z;
+    let batch = CUBE_POS_Z as usize;
     let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
     let nums_per_cube = CUBE_SIZE * line_spt;
     let v_last = comptime!(scan_in.line_size() - 1);
@@ -36,11 +36,11 @@ fn prefix_sum_kernel(
         broadcast[0] = Atomic::add(&scan_bump[batch], I::new(1));
     }
     sync_cube();
-    let part_id = u32::cast_from(broadcast[0]);
+    let part_id = usize::cast_from(broadcast[0]);
     let plane_id = UNIT_POS_X / PLANE_DIM;
     let dev_offs = part_id * nums_per_cube;
-    let plane_offs = plane_id * PLANE_DIM * line_spt;
+    let plane_offs = UNIT_POS_X as usize * line_spt;
 
     // Exit if full plane is out of bounds
     if dev_offs + plane_offs >= scan_in.shape(1) {
@@ -56,9 +56,9 @@ fn prefix_sum_kernel(
     let red_offs = batch * reduction.stride(0);
     let scan_offs = batch * scan_in.stride(0);
 
-    let mut t_scan = Array::<Line<I>>::vectorized(line_spt, scan_in.line_size());
+    let mut t_scan = Array::<Line<I>>::lined(line_spt, scan_in.line_size());
     {
-        let mut i = dev_offs + plane_offs + UNIT_POS_PLANE;
+        let mut i = dev_offs + plane_offs + UNIT_POS_PLANE as usize;
 
         if part_id < cube_count_x - 1 {
             for k in 0..line_spt {
@@ -70,7 +70,7 @@ fn prefix_sum_kernel(
                     scan[v] += prev;
                 }
                 t_scan[k] = scan;
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }
 
@@ -87,7 +87,7 @@ fn prefix_sum_kernel(
                     }
                     t_scan[k] = scan;
                 }
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }
 
@@ -95,13 +95,13 @@ fn prefix_sum_kernel(
         let plane_mask = PLANE_DIM - 1;
         let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;
         for k in 0..line_spt {
-            let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
+            let t = plane_shuffle(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
             t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);
-            prev += plane_broadcast(t, 0);
+            prev += plane_broadcast(t, 0u32);
         }
 
         if UNIT_POS_PLANE == 0 {
-            reduce[plane_id] = prev;
+            reduce[plane_id as usize] = prev;
         }
     }
     sync_cube();
@@ -118,9 +118,9 @@ fn prefix_sum_kernel(
         while j <= aligned_size {
             let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0;
             let pred_0 = i_0 < spine_size;
-            let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0], zero));
+            let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0 as usize], zero));
             if pred_0 {
-                reduce[i_0] = t_0;
+                reduce[i_0 as usize] = t_0;
             }
             sync_cube();
 
@@ -129,9 +129,13 @@ fn prefix_sum_kernel(
                 let i_1 = UNIT_POS_X + rshift;
                 if (i_1 & (j - 1)) >= rshift {
                     let pred_1 = i_1 < spine_size;
-                    let t_1 = select(pred_1, reduce[((i_1 >> offset_1) << offset_1) - 1], zero);
+                    let t_1 = select(
+                        pred_1,
+                        reduce[(((i_1 >> offset_1) << offset_1) - 1) as usize],
+                        zero,
+                    );
                     if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 {
-                        reduce[i_1] += t_1;
+                        reduce[i_1 as usize] += t_1;
                     }
                 }
             } else {
@@ -148,7 +152,7 @@ fn prefix_sum_kernel(
     if UNIT_POS_X == 0 {
         Atomic::store(
             &reduction[part_id + red_offs],
-            (reduce[spine_size - 1] << I::new(2))
+            (reduce[(spine_size - 1) as usize] << I::new(2))
                 | select(part_id != 0, flag_reduction, flag_inclusive),
         )
     }
@@ -164,7 +168,8 @@ fn prefix_sum_kernel(
                 prev_reduction += flag_payload >> I::new(2);
                 Atomic::store(
                     &reduction[part_id + red_offs],
-                    ((prev_reduction + reduce[spine_size - 1]) << I::new(2)) | flag_inclusive,
+                    ((prev_reduction + reduce[(spine_size - 1) as usize]) << I::new(2))
+                        | flag_inclusive,
                 );
                 broadcast[0] = prev_reduction;
                 break;
@@ -181,19 +186,19 @@ fn prefix_sum_kernel(
     {
         let prev = if plane_id != 0 {
-            reduce[plane_id - 1]
+            reduce[(plane_id - 1) as usize]
         } else {
             zero
         };
         let prev = Line::cast_from(broadcast[0] + prev);
-        let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt;
+        let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt as u32;
         let dev_offset = part_id * nums_per_cube;
-        let mut i = s_offset + dev_offset;
+        let mut i = s_offset as usize + dev_offset;
 
         if part_id < cube_count_x - 1 {
             for k in 0..line_spt {
                 scan_out[i + scan_offs] = t_scan[k] + prev;
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }
 
@@ -202,7 +207,7 @@ fn prefix_sum_kernel(
                 if i < scan_out.shape(1) {
                     scan_out[i + scan_offs] = t_scan[k] + prev;
                 }
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }
     }
@@ -217,27 +222,27 @@ fn count_trailing_zeros(num: u32) -> u32 {
 
 pub fn prefix_sum<R: CubeRuntime, I: IntElement>(input: CubeTensor<R>) -> CubeTensor<R> {
     let client = input.client.clone();
     let device = input.device.clone();
-    let num_elems = input.shape.num_elements() as u32;
-    let numbers = *input.shape.last().unwrap() as u32;
+    let num_elems = input.shape.num_elements();
+    let numbers = *input.shape.last().unwrap();
     let batches = num_elems / numbers;
-    let input = reshape(input, Shape::new([batches as usize, numbers as usize]));
+    let input = reshape(input, Shape::new([batches, numbers]));
 
     let out = empty_device::<R, I>(client.clone(), device.clone(), input.shape.clone());
 
     let cubes = numbers.div_ceil(PART_SIZE);
-    let cube_dim = CubeDim::new_1d(CUBE_SIZE);
-    let cube_count = CubeCount::new_3d(cubes, 1, batches);
+    let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32);
+    let cube_count = CubeCount::new_3d(cubes as u32, 1, batches as u32);
 
     let bump = zeros_client::<R>(
         client.clone(),
         device.clone(),
-        Shape::new([batches as usize]),
+        Shape::new([batches]),
         I::dtype(),
     );
     let reduction = zeros_client::<R>(
         client.clone(),
         device.clone(),
-        Shape::new([batches as usize, cubes as usize]),
+        Shape::new([batches, cubes]),
         I::dtype(),
     );
diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs
index 08d4ded4d7..872feb5df7 100644
--- a/examples/custom-cubecl-kernel/src/kernel.rs
+++ b/examples/custom-cubecl-kernel/src/kernel.rs
@@ -8,9 +8,9 @@ pub fn fused_matmul_add_relu_kernel(
     bias: &Tensor<F>,
     output: &mut Tensor<F>,
 ) {
-    let row = ABSOLUTE_POS_X;
-    let col = ABSOLUTE_POS_Y;
-    let batch = ABSOLUTE_POS_Z;
+    let row = ABSOLUTE_POS_X as usize;
+    let col = ABSOLUTE_POS_Y as usize;
+    let batch = ABSOLUTE_POS_Z as usize;
 
     let n_rows = output.shape(output.rank() - 2);
     let n_cols = output.shape(output.rank() - 1);

From e60a2d365eea087d3d5d79abccd5d97dc93ad1fc Mon Sep 17 00:00:00 2001
From: Genna Wingert <1058083+wingertge@users.noreply.github.com>
Date: Mon, 5 Jan 2026 01:24:16 +0100
Subject: [PATCH 2/4] Fix merge

---
 .../src/optim/matmul/args.rs                  |  8 ++---
 .../kernel/conv/deform_conv_transpose2d.rs    |  2 +-
 crates/burn-cubecl/src/kernel/conv/direct.rs  |  6 ++--
 .../src/kernel/grid_sample/base.rs            | 36 +++++++++----------
 .../src/kernel/grid_sample/bilinear.rs        | 31 ++++++++--------
 5 files changed, 43 insertions(+), 40 deletions(-)

diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
index 4dac36ea89..ccdfa4e3c8 100644
--- a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
+++ b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
@@ -225,7 +225,7 @@ impl MatmulArgs for FusedMatmulArgs {
 fn global_view(
     inputs: &GlobalArgs,
     locals: &LocalArgs,
-    batch_shape: &Sequence<FastDivmod>,
+    batch_shape: &Sequence<FastDivmod<u32>>,
     #[comptime] arg: MatmulArg,
     #[comptime] config: FuseBlockConfig,
     #[comptime] layout_config: GlobalLayoutConfig,
@@ -316,7 +316,7 @@ fn global_view(
 #[cube]
 fn input_batch_layout(
     inputs: &GlobalArgs,
-    batch_shape: &Sequence<FastDivmod>,
+    batch_shape: &Sequence<FastDivmod<u32>>,
     #[comptime] arg: MatmulArg,
     #[comptime] config: FuseBlockConfig,
 ) -> VirtualLayout {
@@ -464,7 +464,7 @@ pub struct FusedMatmulState {
     rhs_layout_config: GlobalLayoutConfig,
     #[cube(comptime)]
     out_layout_config: GlobalLayoutConfig,
-    batch_shape: Sequence<FastDivmod>,
+    batch_shape: Sequence<FastDivmod<u32>>,
 }
 
 #[cube]
@@ -478,7 +478,7 @@ impl FusedMatmulState {
         b_batch: VirtualLayout,
         c_batch: CubeOption<VirtualLayout>,
         out_batch: VirtualLayout,
-        batch_shape: Sequence<FastDivmod>,
+        batch_shape: Sequence<FastDivmod<u32>>,
         #[comptime] config: &FuseBlockConfig,
         #[comptime] lhs_layout_config: GlobalLayoutConfig,
         #[comptime] rhs_layout_config: GlobalLayoutConfig,
diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
index b199d80d71..9787dfe53a 100644
--- a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
+++ b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs
@@ -364,7 +364,7 @@ fn deform_col2img_coord_kernel(
 
     let c_bound = channels_per_offset_group * kernel_h * kernel_w;
 
-    for col_c in range_stepped(offset_c / 2, c_bound, col_step as u32) {
+    for col_c in range_stepped(offset_c / 2, c_bound, col_step) {
         let col_pos = (((col_c * batch_size + b) * out_h) + h) * out_w + w;
 
         let out_x = col_pos % out_w;
diff --git a/crates/burn-cubecl/src/kernel/conv/direct.rs b/crates/burn-cubecl/src/kernel/conv/direct.rs
index 36bf1d3a0e..e4a268a80d 100644
--- a/crates/burn-cubecl/src/kernel/conv/direct.rs
+++ b/crates/burn-cubecl/src/kernel/conv/direct.rs
@@ -33,8 +33,8 @@ fn direct_conv2d_kernel(
     bias: CubeOption<LinearView<Line<E>>>,
     output: &mut LinearView<Line<E>, ReadWrite>,
     args: Conv2dArgs,
-    shape_out: Sequence<FastDivmod>,
-    shape_out_c: FastDivmod,
+    shape_out: Sequence<FastDivmod<u32>>,
+    shape_out_c: FastDivmod<u32>,
     #[comptime] has_padding: bool,
     #[define(E)] _dtype: StorageType,
 ) {
@@ -317,7 +317,7 @@ pub fn conv_direct(
 }
 
 #[cube]
-pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod>) -> (u32, Sequence<u32>) {
+pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence<FastDivmod<u32>>) -> (u32, Sequence<u32>) {
     let rank = comptime![shape.len()];
     let mut offs = pos;
     let mut out = Sequence::new();
diff --git a/crates/burn-cubecl/src/kernel/grid_sample/base.rs b/crates/burn-cubecl/src/kernel/grid_sample/base.rs
index 4f7d0e15ba..7bfdabdbb9 100644
--- a/crates/burn-cubecl/src/kernel/grid_sample/base.rs
+++ b/crates/burn-cubecl/src/kernel/grid_sample/base.rs
@@ -45,9 +45,9 @@ impl From for PaddingMode {
 #[cube]
 pub(crate) fn fetch_value<F: Float>(
     input: &Tensor<F>,
-    base: u32,
-    stride_h: u32,
-    stride_w: u32,
+    base: usize,
+    stride_h: usize,
+    stride_w: usize,
     y: i32,
     x: i32,
     h: i32,
@@ -67,17 +67,17 @@
 #[cube]
 pub(crate) fn fetch_with_zeros<F: Float>(
     input: &Tensor<F>,
-    base: u32,
-    stride_h: u32,
-    stride_w: u32,
+    base: usize,
+    stride_h: usize,
+    stride_w: usize,
     y: i32,
     x: i32,
     h: i32,
     w: i32,
 ) -> F {
     let in_bounds = x >= 0 && x < w && y >= 0 && y < h;
-    let x_clamped = Min::min(Max::max(x, 0), w - 1) as u32;
-    let y_clamped = Min::min(Max::max(y, 0), h - 1) as u32;
+    let x_clamped = Min::min(Max::max(x, 0), w - 1) as usize;
+    let y_clamped = Min::min(Max::max(y, 0), h - 1) as usize;
     let idx = base + y_clamped * stride_h + x_clamped * stride_w;
     select(in_bounds, input[idx], F::new(0.0))
 }
@@ -86,16 +86,16 @@
 #[cube]
 pub(crate) fn fetch_with_border<F: Float>(
     input: &Tensor<F>,
-    base: u32,
-    stride_h: u32,
-    stride_w: u32,
+    base: usize,
+    stride_h: usize,
+    stride_w: usize,
     y: i32,
     x: i32,
     h: i32,
     w: i32,
 ) -> F {
-    let x_clamped = Min::min(Max::max(x, 0), w - 1) as u32;
-    let y_clamped = Min::min(Max::max(y, 0), h - 1) as u32;
+    let x_clamped = Min::min(Max::max(x, 0), w - 1) as usize;
+    let y_clamped = Min::min(Max::max(y, 0), h - 1) as usize;
     let idx = base + y_clamped * stride_h + x_clamped * stride_w;
     input[idx]
 }
@@ -105,9 +105,9 @@
 #[cube]
 pub(crate) fn fetch_with_reflection<F: Float>(
     input: &Tensor<F>,
-    base: u32,
-    stride_h: u32,
-    stride_w: u32,
+    base: usize,
+    stride_h: usize,
+    stride_w: usize,
     y: i32,
     x: i32,
     h: i32,
@@ -122,7 +122,7 @@
 /// Reflect an integer index that may be out of bounds.
 /// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear).
 #[cube]
-fn reflect_coord_bounded(idx: i32, size: i32) -> u32 {
+fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
     let max_idx = size - 1;
     let neg_reflected = -idx - 1;
     let pos_reflected = 2 * max_idx + 1 - idx;
@@ -131,7 +131,7 @@ fn reflect_coord_bounded(idx: i32, size: i32) -> u32 {
         neg_reflected,
         select(idx > max_idx, pos_reflected, idx),
     );
-    Min::min(Max::max(result, 0), max_idx) as u32
+    Min::min(Max::max(result, 0), max_idx) as usize
 }
 
 /// Reflect a float coordinate into the valid sampling range.
diff --git a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs
index 7e92301fc5..5ec1e2a0e4 100644
--- a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs
+++ b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs
@@ -14,10 +14,10 @@ use super::base::{PaddingMode, fetch_value, reflect_coord};
 /// 3. For each channel: fetch 4 corner values, interpolate, and write output
 #[cube(launch)]
 fn grid_sample_bilinear_kernel(
-    input: &Tensor<F>,                   // [N, C, H_in, W_in]
-    grid: &Tensor<F>,                    // [N, H_out, W_out, 2]
-    output: &mut Tensor<F>,              // [N, C, H_out, W_out]
-    shape_spatial: Sequence<FastDivmod>, // [N, H_out, W_out] for thread decomposition
+    input: &Tensor<F>,                        // [N, C, H_in, W_in]
+    grid: &Tensor<F>,                         // [N, H_out, W_out, 2]
+    output: &mut Tensor<F>,                   // [N, C, H_out, W_out]
+    shape_spatial: Sequence<FastDivmod<u32>>, // [N, H_out, W_out] for thread decomposition
     #[comptime] align_corners: bool,
     #[comptime] pad_mode: PaddingMode,
     #[define(F)] _dtype: StorageType,
@@ -30,15 +30,17 @@ fn grid_sample_bilinear_kernel(
     }
 
     // Decompose spatial index into (n, h_out, w_out)
-    let (rem, w_out) = shape_spatial.index(2).div_mod(spatial_idx);
-    let (n, h_out) = shape_spatial.index(1).div_mod(rem);
+    let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx as u32);
+    let (n, h_out) = shape_spatial[1].div_mod(rem);
 
-    let channels = input.shape(1);
-    let h_in = input.shape(2);
-    let w_in = input.shape(3);
+    let channels = input.shape(1) as u32;
+    let h_in = input.shape(2) as u32;
+    let w_in = input.shape(3) as u32;
 
     // Read grid coordinates once per spatial position
-    let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2);
+    let grid_offset = n as usize * grid.stride(0)
+        + h_out as usize * grid.stride(1)
+        + w_out as usize * grid.stride(2);
     let gx = grid[grid_offset]; // x coordinate in [-1, 1]
     let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1]
@@ -95,12 +97,13 @@ fn grid_sample_bilinear_kernel(
     let out_stride_w = output.stride(3);
 
     // Base offsets for this spatial position
-    let in_base_n = n * stride_n;
-    let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w;
+    let in_base_n = n as usize * stride_n;
+    let out_base_spatial =
+        n as usize * out_stride_n + h_out as usize * out_stride_h + w_out as usize * out_stride_w;
 
     // Loop over all channels - grid coords and weights are reused
     for c in 0..channels {
-        let in_base = in_base_n + c * stride_c;
+        let in_base = in_base_n + c as usize * stride_c;
 
         let v00 = fetch_value(
             input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode,
@@ -118,7 +121,7 @@ fn grid_sample_bilinear_kernel(
 
         // Bilinear interpolation
         let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11;
-        let out_idx = out_base_spatial + c * out_stride_c;
+        let out_idx = out_base_spatial + c as usize * out_stride_c;
         output[out_idx] = result;
     }
 }

From 6f2c8d4f01c67943da3d43f247cb1c51294029e6 Mon Sep 17 00:00:00 2001
From: Genna Wingert <1058083+wingertge@users.noreply.github.com>
Date: Wed, 7 Jan 2026 23:42:09 +0100
Subject: [PATCH 3/4] Update cubecl

---
 Cargo.lock | 17 +++++++++++++++++
 Cargo.toml |  8 ++++----
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 6fdf721cf3..31463292f2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1993,11 +1993,13 @@
 [[package]]
 name = "cubecl"
 version = "0.9.0-pre.6"
+source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab"
 dependencies = [
  "cubecl-core",
  "cubecl-cpu",
  "cubecl-cuda",
  "cubecl-hip",
+ "cubecl-ir",
  "cubecl-runtime",
  "cubecl-std",
  "cubecl-wgpu",
@@ -2007,6 +2009,7 @@
 [[package]]
 name = "cubecl-common"
 version = "0.9.0-pre.6"
+source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab"
"git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "backtrace", "bytemuck", @@ -2042,6 +2045,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2066,6 +2070,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2081,6 +2086,7 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2101,6 +2107,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2118,6 +2125,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2146,6 +2154,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2165,6 +2174,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "darling 0.21.3", @@ -2179,6 +2189,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -2189,6 +2200,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "cubecl-core", @@ -2205,6 +2217,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "async-channel", "bytemuck", @@ -2233,6 +2246,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2247,6 +2261,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ 
"cubecl-common", "cubecl-core", @@ -2263,6 +2278,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.9.0-pre.6" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "ash", "async-channel", @@ -2272,6 +2288,7 @@ dependencies = [ "cubecl-common", "cubecl-core", "cubecl-cpp", + "cubecl-ir", "cubecl-runtime", "cubecl-spirv", "derive-new", diff --git a/Cargo.toml b/Cargo.toml index 4154c0aca5..5d53c719c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -180,12 +180,12 @@ portable-atomic = { version = "1.11.1" } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "48ff83f19952d053b80ab5762baf387f451e5c63" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "48ff83f19952d053b80ab5762baf387f451e5c63" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } # cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "0d9a635229d3cabfa8297ddc967ff4e783be348c" } ### For local development. ### -cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.9.0-pre.6", default-features = false } From e0d9252bdde4b55721746e470b58359fe38c4b40 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 8 Jan 2026 17:43:12 +0100 Subject: [PATCH 4/4] Update cubek --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d5b2fa258e..a1f9bd725c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -182,11 +182,11 @@ portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } -# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "0d9a635229d3cabfa8297ddc967ff4e783be348c" } +cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "8097a621dfd3a6a89f8d0433994e8c1adba377c2" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } -cubek = { path = "../cubek/crates/cubek", default-features = false } +# cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.9.0-pre.6", default-features = false } # cubecl-common = { version = "=0.9.0-pre.6", default-features = false }