diff --git a/Cargo.lock b/Cargo.lock index ea43fd7a0f..16d43547e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,7 +172,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -251,7 +251,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -262,7 +262,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" dependencies = [ "bytes", "futures-core", @@ -419,9 +419,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.8.1" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" +checksum = "7d809780667f4410e7c41b07f52439b94d2bdf8528eeedc287fa38d3b7f95d82" [[package]] name = "bincode" @@ -460,7 +460,7 @@ dependencies = [ "regex", "rustc-hash 2.1.1", "shlex", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -589,7 +589,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -910,7 +910,7 @@ dependencies = [ "derive-new", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -948,7 +948,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.111", + "syn 2.0.114", "thiserror 2.0.17", "tracing-core", "tracing-subscriber", @@ -1169,7 +1169,7 @@ version = "0.20.0-pre.6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1249,7 +1249,7 @@ checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1519,9 +1519,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.53" +version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" +checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" dependencies = [ "clap_builder", "clap_derive", @@ -1529,9 +1529,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.53" +version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" +checksum = "fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" dependencies = [ "anstream", "anstyle", @@ -1549,7 +1549,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -1914,7 +1914,7 @@ dependencies = [ "mio", "parking_lot", "rustix 0.38.44", - "signal-hook", + "signal-hook 0.3.18", "signal-hook-mio", "winapi", ] @@ -1993,12 +1993,13 @@ dependencies = [ [[package]] name = "cubecl" version = "0.9.0-pre.6" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-core", "cubecl-cpu", "cubecl-cuda", "cubecl-hip", + "cubecl-ir", "cubecl-runtime", "cubecl-std", "cubecl-wgpu", @@ -2008,7 +2009,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "backtrace", "bytemuck", @@ -2044,7 +2045,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2069,7 +2070,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2085,7 +2086,7 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2106,7 +2107,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2124,7 +2125,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bytemuck", "cubecl-common", @@ -2153,7 +2154,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2173,7 +2174,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.9.0-pre.6" -source = 
"git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "darling 0.21.3", @@ -2182,24 +2183,24 @@ dependencies = [ "prettyplease 0.2.37", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "cubecl-macros-internal" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "darling 0.21.3", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "cubecl-opt" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "cubecl-core", @@ -2216,7 +2217,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "async-channel", "bytemuck", @@ -2245,7 +2246,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2260,7 +2261,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "cubecl-common", "cubecl-core", @@ -2277,7 +2278,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.9.0-pre.6" -source = "git+https://github.com/tracel-ai/cubecl?rev=88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35#88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" +source = "git+https://github.com/tracel-ai/cubecl?rev=7cd8e9c869d363edcbf989fd905eb4b28df938ab#7cd8e9c869d363edcbf989fd905eb4b28df938ab" dependencies = [ "ash", "async-channel", @@ -2287,6 +2288,7 @@ dependencies = [ "cubecl-common", "cubecl-core", "cubecl-cpp", + "cubecl-ir", "cubecl-runtime", "cubecl-spirv", "derive-new", @@ -2303,7 +2305,6 @@ dependencies = [ [[package]] name = "cubek" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "cubecl", "cubek-attention", @@ -2317,7 +2318,6 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.1.0-pre.1" -source = 
"git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "bytemuck", "cubecl", @@ -2331,7 +2331,6 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "bytemuck", "cubecl", @@ -2346,7 +2345,6 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "bytemuck", "cubecl", @@ -2358,7 +2356,6 @@ dependencies = [ [[package]] name = "cubek-quant" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "cubecl", "cubecl-common", @@ -2369,7 +2366,6 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "cubecl", "cubecl-common", @@ -2382,7 +2378,6 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.1.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=ec439f0c648708c6d436d31eb1a3da9cb634d77c#ec439f0c648708c6d436d31eb1a3da9cb634d77c" dependencies = [ "cubecl", "half", @@ -2508,6 +2503,16 @@ dependencies = [ "darling_macro 0.21.3", ] +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core 0.23.0", + "darling_macro 0.23.0", +] + [[package]] name = "darling_core" version = "0.20.11" @@ -2519,7 +2524,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2533,7 +2538,20 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.114", ] [[package]] @@ -2544,7 +2562,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core 0.20.11", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2555,7 +2573,18 @@ checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ "darling_core 0.21.3", "quote", - "syn 2.0.111", + "syn 2.0.114", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core 0.23.0", + "quote", + "syn 2.0.114", ] [[package]] @@ -2626,7 +2655,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2637,7 +2666,7 @@ checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2658,7 +2687,7 @@ 
dependencies = [ "darling 0.20.11", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2668,29 +2697,29 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "derive_more" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" dependencies = [ "convert_case 0.10.0", "proc-macro2", "quote", "rustc_version", - "syn 2.0.111", + "syn 2.0.114", "unicode-xid", ] @@ -2758,7 +2787,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2783,7 +2812,7 @@ checksum = "9556bc800956545d6420a640173e5ba7dfa82f38d3ea5a167eb555bc69ac3323" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2877,7 +2906,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2972,7 +3001,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -2992,7 +3021,7 @@ checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3014,7 +3043,7 @@ dependencies = [ "darling 0.21.3", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3057,7 +3086,7 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3176,7 +3205,7 @@ checksum = "a0aca10fb742cb43f9e7bb8467c91aa9bcb8e3ffbc6a6f7389bb93ffc920577d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3298,7 +3327,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3401,7 +3430,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -3732,9 +3761,9 @@ checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" [[package]] name = "gix-features" -version = "0.45.0" +version = "0.45.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba0ba40b1ca17f2cb3987c8d54e596aba924201cd8e5947098b441067e6686a0" +checksum = "d56aad357ae016449434705033df644ac6253dfcf1281aad3af3af9e907560d1" dependencies = [ "gix-trace", "gix-utils", @@ -3743,9 +3772,9 @@ dependencies = [ [[package]] name = "gix-fs" -version = "0.18.0" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95b160a13547a64d67a02d894e4f5502a2a5f98635c89931f6bb9c7a4c80c7db" +checksum = 
"785b9c499e46bc78d7b81c148c21b3fca18655379ee729a856ed19ce50d359ec" dependencies = [ "bstr", "fastrand", @@ -3769,24 +3798,24 @@ dependencies = [ [[package]] name = "gix-tempfile" -version = "20.0.0" +version = "20.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "816bbb99bbf8cd329e38342594528506f224c4937a6341dbd1d16ee4082f621c" +checksum = "ad89218e74850f42d364ed3877c7291f0474c8533502df91bb877ecc5cb0dd40" dependencies = [ "dashmap", "gix-fs", "libc", "parking_lot", - "signal-hook", + "signal-hook 0.4.1", "signal-hook-registry", "tempfile", ] [[package]] name = "gix-trace" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edd971cd6961fb1ebb29a0052a4ab04d8498dbf363c122e137b04753a3bbb5c3" +checksum = "6e42a4c2583357721ba2d887916e78df504980f22f1182df06997ce197b89504" [[package]] name = "gix-utils" @@ -3932,9 +3961,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" dependencies = [ "atomic-waker", "bytes", @@ -4171,7 +4200,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots 1.0.4", + "webpki-roots 1.0.5", ] [[package]] @@ -4380,7 +4409,7 @@ dependencies = [ "rgb", "tiff", "zune-core 0.5.0", - "zune-jpeg 0.5.7", + "zune-jpeg 0.5.8", ] [[package]] @@ -4430,9 +4459,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", @@ -4486,9 +4515,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.45.0" +version = "1.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b76866be74d68b1595eb8060cb9191dca9c021db2316558e52ddc5d55d41b66c" +checksum = "1b66886d14d18d420ab5052cbff544fc5d34d0b2cdd35eb5976aaa10a4a472e5" dependencies = [ "console 0.15.11", "once_cell", @@ -4498,15 +4527,15 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6778b0196eefee7df739db78758e5cf9b37412268bfa5650bfeed028aed20d9c" +checksum = "357b7205c6cd18dd2c86ed312d1e70add149aea98e7ef72b9fdf0270e555c11d" dependencies = [ - "darling 0.20.11", + "darling 0.23.0", "indoc", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4517,7 +4546,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4541,9 +4570,9 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "iri-string" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f867b9d1d896b67beb18518eda36fdb77a32ea590de864f1325b294a6d14397" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" dependencies = [ "memchr", "serde", @@ -4575,15 +4604,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.17" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" [[package]] name = "jiff" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49cce2b81f2098e7e3efc35bc2e0a6b7abec9d34128283d7a26fa8f32a6dbb35" +checksum = "e67e8da4c49d6d9909fe03361f9b620f58898859f5c7aded68351e85e71ecf50" dependencies = [ "jiff-static", "log", @@ -4594,13 +4623,13 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "980af8b43c3ad5d8d349ace167ec8170839f753a42d233ba19e08afe1850fa69" +checksum = "e0c84ee7f197eca9a86c6fd6cb771e55eb991632f15f2bc3ca6ec838929e6e78" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -4719,13 +4748,13 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df15f6eac291ed1cf25865b1ee60399f57e7c227e7f51bdbd4c5270396a9ed50" +checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" dependencies = [ "bitflags 2.10.0", "libc", - "redox_syscall 0.6.0", + "redox_syscall 0.7.0", ] [[package]] @@ -4741,9 +4770,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15413ef615ad868d4d65dce091cb233b229419c7c0c4bcaa746c0901c49ff39c" +checksum = "c10501e7805cee23da17c7790e59df2870c0d4043ec6d03f67d31e2b53e77415" dependencies = [ "zlib-rs", ] @@ -4832,9 +4861,9 @@ dependencies = [ [[package]] name = "lzma-rust2" -version = "0.15.4" +version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48172246aa7c3ea28e423295dd1ca2589a24617cc4e588bb8cfe177cb2c54d95" +checksum = "17f7337d278fec032975dc884152491580dd23750ee957047856735fe0e61ede" dependencies = [ "crc", "sha2", @@ -4865,7 +4894,7 @@ dependencies = [ "darling 0.20.11", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5092,7 +5121,7 @@ checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5156,7 +5185,7 @@ dependencies = [ "libc", "log", "openssl", - "openssl-probe", + "openssl-probe 0.1.6", "openssl-sys", "schannel", "security-framework 2.11.1", @@ -5290,9 +5319,9 @@ dependencies = [ [[package]] name = "ntapi" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a3895c6391c39d7fe7ebc444a87eb2991b2a0bc718fdabd071eec617fc68e4" +checksum = "c70f219e21142367c70c0b30c6a9e3a14d55b4d12a204d897fbec83a0363f081" dependencies = [ "winapi", ] @@ -5354,7 +5383,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5427,7 +5456,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5647,7 +5676,7 @@ version = "0.20.0-pre.6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5710,7 +5739,7 @@ checksum = 
"a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -5719,6 +5748,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" +[[package]] +name = "openssl-probe" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f50d9b3dabb09ecd771ad0aa242ca6894994c130308ca3d7684634df8037391" + [[package]] name = "openssl-sys" version = "0.9.111" @@ -5976,7 +6011,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -6170,7 +6205,7 @@ dependencies = [ "parking_lot", "polars-arrow-format", "regex", - "signal-hook", + "signal-hook 0.3.18", "simdutf8", ] @@ -6606,7 +6641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -6620,9 +6655,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.104" +version = "1.0.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9695f8df41bb4f3d222c95a67532365f569318332d03d5f3f67f37b20e6ebdf0" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" dependencies = [ "unicode-ident", ] @@ -6643,7 +6678,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" dependencies = [ "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -6666,7 +6701,7 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -6860,9 +6895,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.42" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" dependencies = [ "proc-macro2", ] @@ -7126,7 +7161,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -7146,9 +7181,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96166dafa0886eb81fe1c0a388bece180fbef2135f97c1e2cf8302e74b43b5" +checksum = "49f3fe0889e69e2ae9e41f4d6c4c0181701d00e4697b356fb1f74173a5e0ee27" dependencies = [ "bitflags 2.10.0", ] @@ -7213,9 +7248,9 @@ checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" [[package]] name = "reqwest" -version = "0.12.26" +version = "0.12.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b4c14b2d9afca6a60277086b0cc6a6ae0b568f6f7916c943a8cdc79f8be240f" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", @@ -7252,7 +7287,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 1.0.4", + "webpki-roots 1.0.5", ] [[package]] @@ -7280,13 +7315,11 @@ dependencies = [ [[package]] name = "rmp" -version = 
"0.8.14" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" dependencies = [ - "byteorder", "num-traits", - "paste", ] [[package]] @@ -7324,7 +7357,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.111", + "syn 2.0.114", "unicode-ident", ] @@ -7408,9 +7441,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.35" +version = "0.23.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" dependencies = [ "log", "once_cell", @@ -7423,11 +7456,11 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9980d917ebb0c0536119ba501e90834767bffc3d60641457fd84a1f3fd337923" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" dependencies = [ - "openssl-probe", + "openssl-probe 0.2.0", "rustls-pki-types", "schannel", "security-framework 3.5.1", @@ -7462,9 +7495,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" [[package]] name = "safetensors" @@ -7662,14 +7695,14 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "serde_json" -version = "1.0.148" +version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3084b546a1dd6289475996f182a22aba973866ea8e8b02c51d9f46b1336a22da" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ "itoa", "memchr", @@ -7733,11 +7766,12 @@ dependencies = [ [[package]] name = "serial_test" -version = "3.2.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b258109f244e1d6891bf1053a55d63a5cd4f8f4c30cf9a1280989f80e7a1fa9" +checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555" dependencies = [ - "futures", + "futures-executor", + "futures-util", "log", "once_cell", "parking_lot", @@ -7747,13 +7781,13 @@ dependencies = [ [[package]] name = "serial_test_derive" -version = "3.2.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" +checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -7812,6 +7846,16 @@ dependencies = [ "signal-hook-registry", ] +[[package]] +name = "signal-hook" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a37d01603c37b5466f808de79f845c7116049b0579adb70a6b7d47c1fa3a952" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-mio" version = "0.2.5" @@ -7820,15 +7864,16 @@ checksum = 
"b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" dependencies = [ "libc", "mio", - "signal-hook", + "signal-hook 0.3.18", ] [[package]] name = "signal-hook-registry" -version = "1.4.7" +version = "1.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" dependencies = [ + "errno", "libc", ] @@ -8067,7 +8112,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8079,7 +8124,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8101,9 +8146,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.111" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -8127,7 +8172,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8335,7 +8380,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8346,7 +8391,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8508,7 +8553,7 @@ checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8538,9 +8583,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" dependencies = [ "futures-core", "pin-project-lite", @@ -8561,9 +8606,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -8785,7 +8830,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.111", + "syn 2.0.114", "tracel-llvm-bundler", "tracel-tblgen-rs", "unindent", @@ -8853,7 +8898,7 @@ checksum = "ce80fb81ad70dc91a536a3a47d0dd003de54fba384fd1ea62c3af07d41e764dc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -8888,7 +8933,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -9202,9 +9247,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.7" +version = "2.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" dependencies = [ "form_urlencoded", "idna", @@ -9268,7 +9313,7 @@ 
checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -9374,7 +9419,7 @@ dependencies = [ "bumpalo", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "wasm-bindgen-shared", ] @@ -9433,9 +9478,9 @@ dependencies = [ [[package]] name = "webpki-root-certs" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" +checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" dependencies = [ "rustls-pki-types", ] @@ -9446,14 +9491,14 @@ version = "0.26.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" dependencies = [ - "webpki-roots 1.0.4", + "webpki-roots 1.0.5", ] [[package]] name = "webpki-roots" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2878ef029c47c6e8cf779119f20fcf52bde7ad42a731b2a304bc221df17571e" +checksum = "12bed680863276c63889429bfd6cab3b99943659923822de1c8a39c49e4d722c" dependencies = [ "rustls-pki-types", ] @@ -9760,7 +9805,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -9771,7 +9816,7 @@ checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -9782,7 +9827,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -9793,7 +9838,7 @@ checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -10071,7 +10116,7 @@ dependencies = [ "darling 0.20.11", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -10155,7 +10200,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] @@ -10167,28 +10212,28 @@ checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] [[package]] name = "zerocopy" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.31" +version = "0.8.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -10208,7 +10253,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", "synstructure", ] @@ -10223,13 +10268,13 @@ dependencies = [ [[package]] name = "zeroize_derive" -version = 
"1.4.2" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -10262,7 +10307,7 @@ checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -10330,15 +10375,15 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51f936044d677be1a1168fae1d03b583a285a5dd9d8cbf7b24c23aa1fc775235" +checksum = "40990edd51aae2c2b6907af74ffb635029d5788228222c4bb811e9351c0caad3" [[package]] name = "zmij" -version = "1.0.2" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f4a4e8e9dc5c62d159f04fcdbe07f4c3fb710415aab4754bf11505501e3251d" +checksum = "2fc5a66a20078bf1251bde995aa2fdcc4b800c70b5d92dd2c62abc5c60f679f8" [[package]] name = "zopfli" @@ -10431,9 +10476,9 @@ dependencies = [ [[package]] name = "zune-jpeg" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d915729b0e7d5fe35c2f294c5dc10b30207cc637920e5b59077bfa3da63f28" +checksum = "e35aee689668bf9bd6f6f3a6c60bb29ba1244b3b43adfd50edd554a371da37d5" dependencies = [ "zune-core 0.5.0", ] diff --git a/Cargo.toml b/Cargo.toml index af85d9226e..a1f9bd725c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -180,9 +180,9 @@ portable-atomic = { version = "1.13.0" } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "88c0c6f781f70ad2f6e9981fd0cbe2e87e153a35" } -cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "ec439f0c648708c6d436d31eb1a3da9cb634d77c" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "7cd8e9c869d363edcbf989fd905eb4b28df938ab" } +cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "8097a621dfd3a6a89f8d0433994e8c1adba377c2" } ### For local development. 
### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-backend-tests/tests/cubecl/mask_fill.rs b/crates/burn-backend-tests/tests/cubecl/mask_fill.rs index 730bc10777..d36e0c23ea 100644 --- a/crates/burn-backend-tests/tests/cubecl/mask_fill.rs +++ b/crates/burn-backend-tests/tests/cubecl/mask_fill.rs @@ -2,7 +2,7 @@ use super::*; use burn_cubecl::kernel::{MaskFillStrategy, mask_fill}; use burn_tensor::Tolerance; use burn_tensor::{Bool, Distribution, Element, Tensor, TensorPrimitive, backend::Backend}; -use cubecl::std::scalar::InputScalar; +use cubecl::prelude::InputScalar; #[test] fn mask_fill_should_match_reference_backend() { diff --git a/crates/burn-backend-tests/tests/tensor/float/ops/trig.rs b/crates/burn-backend-tests/tests/tensor/float/ops/trig.rs index 7c198bc346..6e437f2506 100644 --- a/crates/burn-backend-tests/tests/tensor/float/ops/trig.rs +++ b/crates/burn-backend-tests/tests/tensor/float/ops/trig.rs @@ -1,3 +1,5 @@ +#![allow(clippy::approx_constant)] + use super::*; use burn_tensor::TensorData; use burn_tensor::Tolerance; diff --git a/crates/burn-cubecl-fusion/src/base.rs b/crates/burn-cubecl-fusion/src/base.rs index 92ada55fdf..376d622e71 100644 --- a/crates/burn-cubecl-fusion/src/base.rs +++ b/crates/burn-cubecl-fusion/src/base.rs @@ -1,12 +1,15 @@ use burn_fusion::stream::Context; use burn_std::{DType, quantization::QParamTensor}; -use cubecl::quant::scheme::{QuantParam, QuantScheme}; use cubecl::{ CubeElement, Runtime, client::ComputeClient, ir::ElemType, prelude::{TensorArg, TensorHandleRef}, }; +use cubecl::{ + ir::LineSize, + quant::scheme::{QuantParam, QuantScheme}, +}; use std::marker::PhantomData; /// Defines a fallback operation when fusion isn't possible. @@ -73,7 +76,11 @@ impl CubeFusionHandle { } } /// Return the reference to a tensor argument. - pub fn as_tensor_arg<'a>(&'a self, shape: &'a [usize], vectorisation: u8) -> TensorArg<'a, R> { + pub fn as_tensor_arg<'a>( + &'a self, + shape: &'a [usize], + line_size: LineSize, + ) -> TensorArg<'a, R> { let handle: TensorHandleRef<'a, R> = self.as_handle_ref(shape); unsafe { @@ -81,7 +88,7 @@ impl CubeFusionHandle { handle.handle, handle.strides, handle.shape, - vectorisation, + line_size, self.dtype.size(), ) } diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/io.rs b/crates/burn-cubecl-fusion/src/engine/codegen/io.rs index 05dc19e679..66901430dd 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/io.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/io.rs @@ -18,11 +18,11 @@ pub enum Transform { /// /// This enum entry contains a sequence of [arguments](FuseArg) that points to global scalars representing the /// new shape for the current tensor. - Reshape(Sequence), + Reshape(Vec), /// Two axes have been swapped on a tensor. /// /// The enum entry contains those two axes. - SwapDims(u32, u32), + SwapDims(usize, usize), } /// Reads the value from the [arg](FuseArg) and cast it to the generic cube primitive. 
@@ -39,7 +39,7 @@ pub fn read( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - ref_pos: u32, + ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, ) -> Line { @@ -48,7 +48,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!global.broadcasted && line_size != config.width as u32] { + if comptime![!global.broadcasted && line_size != config.width] { read_input_aligned(inputs, locals, pos, ref_pos, layout, config, None) } else { read_input(inputs, locals, pos, ref_pos, layout, config, None) @@ -90,7 +90,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != config.width] { read_input_aligned( inputs, locals, @@ -123,7 +123,7 @@ pub fn read( let global = inputs.tensors.index(pos); let line_size = global.tensor.line_size(); - if comptime![!broadcasted && line_size != config.width as u32] { + if comptime![!broadcasted && line_size != config.width] { read_input_aligned( inputs, locals, @@ -158,15 +158,15 @@ pub fn read( fn index_offset_with_quant_layout( tensor: &GlobalTensor, locals: &LocalArgs, - index: u32, - #[comptime] rank: u32, + index: usize, + #[comptime] rank: usize, #[comptime] scheme: QuantScheme, -) -> u32 { - let (start, end) = comptime![(0u32, rank - 1)]; - let num_quants = comptime!(scheme.num_quants() as u32); +) -> usize { + let (start, end) = (0, rank - 1); + let num_quants = scheme.num_quants(); let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { @@ -192,7 +192,7 @@ fn index_offset_with_quant_layout( pub fn read_quantized( inputs: &GlobalArgs, locals: &LocalArgs, - ref_pos: u32, + ref_pos: usize, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, #[comptime] scheme: QuantScheme, @@ -224,9 +224,9 @@ pub fn read_scalar(inputs: &GlobalArgs, #[comptime] arg: FuseA /// Reads a global scalar that is used as a reshape position. #[cube] -pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> u32 { +pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> usize { match arg { - FuseArg::ScalarShape(pos) => *inputs.reshapes.index(pos), + FuseArg::ScalarShape(pos) => inputs.reshapes[pos], _ => comptime![panic!("Not a scalar shape")], } } @@ -236,8 +236,8 @@ pub fn read_scalar_shape(inputs: &GlobalArgs, #[comptime] arg: FuseArg) -> u32 { pub fn read_input( inputs: &GlobalArgs, locals: &LocalArgs, - #[comptime] pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, @@ -255,9 +255,9 @@ pub fn read_input( #[cube] pub fn read_input_window( inputs: &GlobalArgs, - #[comptime] pos: u32, - start: u32, - end: u32, + #[comptime] pos: usize, + start: usize, + end: usize, ) -> Slice { let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.slice(start, end); @@ -266,7 +266,7 @@ pub fn read_input_window( /// Returns the input as a slice. 
#[cube] -pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: u32) -> Slice { +pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: usize) -> Slice { let tensor = inputs.tensors.index(pos); let slice = tensor.tensor.to_slice(); slice.try_cast_unchecked() @@ -276,11 +276,11 @@ pub fn input_as_slice(inputs: &GlobalArgs, #[comptime] pos: u3 #[cube] pub fn input_as_scales_view( inputs: &GlobalArgs, - #[comptime] pos: u32, - #[comptime] tensor_pos: u32, + #[comptime] pos: usize, + #[comptime] tensor_pos: usize, #[comptime] level: QuantLevel, #[comptime] config: &FuseBlockConfig, -) -> View { +) -> View { set_polyfill_typed::>(); let tensor = inputs.tensors.index(tensor_pos); let scales = inputs.tensors.index(pos); @@ -289,7 +289,7 @@ pub fn input_as_scales_view( let layout = match level { QuantLevel::Tensor => ScalesLayout::new_PerTensor(PerTensorLayout::new(tensor_len)), QuantLevel::Block(block_size) => { - let block_size = comptime![block_size.to_dim_vec(rank as usize)]; + let block_size = comptime![block_size.to_dim_vec(rank)]; let mut tensor_shape = Sequence::new(); let mut scales_strides = Sequence::new(); #[unroll] @@ -308,7 +308,7 @@ pub fn input_as_scales_view( ScalesLayout::new_BlockScaled(layout) } }; - View::new::, u32>(&scales.tensor.to_slice().try_cast_unchecked(), layout) + View::new::, usize>(&scales.tensor.to_slice().try_cast_unchecked(), layout) } /// Reads the input tensor aligned. @@ -316,22 +316,22 @@ pub fn input_as_scales_view( pub fn read_input_aligned( inputs: &GlobalArgs, locals: &LocalArgs, - #[comptime] pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, ) -> Line { - let mut result: Line = Line::::empty(comptime![config.width as u32]); + let mut result: Line = Line::::empty(config.width); let tensor = inputs.tensors.index(pos); - match comptime![transform.clone()] { + match transform.clone() { Some(Transform::Reshape(shape)) => { // Very brute force, not really efficient, but not easy to optimize and not a very // frequent workflow. 
- let ref_pos = ref_pos * comptime![config.width as u32]; + let ref_pos = ref_pos * config.width; #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = reshaped_index( inputs, locals, @@ -346,11 +346,11 @@ pub fn read_input_aligned( Some(Transform::SwapDims(dim1, dim2)) => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); - let i = comptime![swap_dims_transform(&(config.rank - 1), (dim1, dim2))]; - let stride = tensor.tensor.stride(comptime![i]); + let i = comptime![swap_dims_transform(config.rank - 1, (dim1, dim2))]; + let stride = tensor.tensor.stride(i); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -358,9 +358,9 @@ pub fn read_input_aligned( None => { let offset = get_offset_aligned(inputs, locals, tensor, ref_pos, layout, config, transform); - let stride = tensor.tensor.stride(comptime![config.rank - 1]); + let stride = tensor.tensor.stride(config.rank - 1); #[unroll] - for i in 0u32..comptime!(config.width as u32) { + for i in 0..config.width { let index = offset + i * stride; result[i] = C::cast_from(tensor.tensor[index][0]) } @@ -377,11 +377,11 @@ pub fn get_offset_aligned( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, - ref_pos: u32, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, -) -> u32 { +) -> usize { match layout { LayoutInfo::SameAsRef | LayoutInfo::IsRef => { (ref_pos * locals.ref_line_size) / tensor.tensor.line_size() @@ -404,8 +404,8 @@ pub fn read_output( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - pos: u32, - ref_pos: u32, + #[comptime] pos: usize, + ref_pos: usize, #[comptime] layout: LayoutInfo, #[comptime] config: &FuseBlockConfig, ) -> Line { @@ -424,7 +424,7 @@ pub fn write( inputs: &GlobalArgs, outputs: &mut GlobalArgs, locals: &mut LocalArgs, - ref_pos: u32, + ref_pos: usize, value: Line, #[comptime] arg: FuseArg, #[comptime] config: &FuseBlockConfig, @@ -468,11 +468,11 @@ pub(crate) fn global_offset( inputs: &GlobalArgs, outputs: &GlobalArgs, locals: &LocalArgs, - index: u32, + index: usize, #[comptime] arg: FuseArg, - #[comptime] range: Option<(u32, u32)>, + #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, -) -> u32 { +) -> usize { match arg { FuseArg::Input(pos, _precision, _layout) => { let tensor = inputs.tensors.index(pos); @@ -491,11 +491,11 @@ fn get_offset( inputs: &GlobalArgs, locals: &LocalArgs, tensor: &GlobalTensor, - ref_pos: u32, - #[comptime] range: Option<(u32, u32)>, + ref_pos: usize, + #[comptime] range: Option<(usize, usize)>, #[comptime] config: &FuseBlockConfig, #[comptime] transform: Option, -) -> u32 { +) -> usize { index_offset_with_layout( inputs, tensor, @@ -509,28 +509,28 @@ fn get_offset( #[cube] /// Gets the line size for a global tensor. -pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: u32) -> comptime_type!(u32) { +pub fn global_line_size(global: &GlobalArgs, #[comptime] pos: usize) -> comptime_type!(LineSize) { let tensor = global.tensors.index(pos); tensor.tensor.line_size() } #[cube] /// Gets the rank for a global tensor. 
-pub fn global_rank(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_rank(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.rank() } #[cube] /// Gets the length for a global tensor. -pub fn global_len(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.len() } #[cube] /// Gets the buffer length for a global tensor. -pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: u32) -> u32 { +pub fn global_buffer_len(global: &GlobalArgs, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.buffer_len() } @@ -542,8 +542,8 @@ pub fn ref_len( outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, -) -> u32 { - match comptime![config.ref_layout.clone()] { +) -> usize { + match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_len(inputs, index), FuseArg::Output(index, _, _) => global_len(outputs, index), @@ -560,8 +560,8 @@ pub fn ref_buffer_len( outputs: &GlobalArgs, locals: &LocalArgs, #[comptime] config: &FuseBlockConfig, -) -> u32 { - match comptime![config.ref_layout.clone()] { +) -> usize { + match config.ref_layout.clone() { RefLayout::Concrete(arg) => match comptime![arg] { FuseArg::Input(index, _, _) => global_buffer_len(inputs, index), FuseArg::Output(index, _, _) => global_buffer_len(outputs, index), @@ -579,8 +579,8 @@ pub fn ref_buffer_len( #[cube] /// Gets the reference number of elements. -pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> u32 { - let mut length = 1u32; +pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> usize { + let mut length = 1; for i in 0..config.rank { length *= locals.ref_shape[i]; @@ -591,32 +591,32 @@ pub fn num_elements(locals: &LocalArgs, #[comptime] config: &FuseBlockConfig) -> #[cube] /// Gets the reference axis shape. -pub fn ref_shape(locals: &LocalArgs, axis: u32) -> u32 { +pub fn ref_shape(locals: &LocalArgs, axis: usize) -> usize { locals.ref_shape[axis] } #[cube] /// Gets the reference axis stride. -pub fn ref_stride(locals: &LocalArgs, axis: u32) -> u32 { +pub fn ref_stride(locals: &LocalArgs, axis: usize) -> usize { locals.ref_strides[axis] } #[cube] /// Gets the reference line size. -pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(u32) { +pub fn ref_line_size(locals: &LocalArgs) -> comptime_type!(LineSize) { comptime![locals.ref_line_size] } #[cube] /// Gets the given tensor axis shape. -pub fn global_shape(global: &GlobalArgs, axis: u32, #[comptime] pos: u32) -> u32 { +pub fn global_shape(global: &GlobalArgs, axis: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.shape(axis) } #[cube] /// Gets the given tensor axis stride. 
-pub fn global_stride(global: &GlobalArgs, dim: u32, #[comptime] pos: u32) -> u32 { +pub fn global_stride(global: &GlobalArgs, dim: usize, #[comptime] pos: usize) -> usize { let tensor = global.tensors.index(pos); tensor.tensor.stride(dim) } @@ -626,11 +626,11 @@ fn index_offset_with_layout( inputs: &GlobalArgs, tensor: &GlobalTensor, locals: &LocalArgs, - index: u32, - #[comptime] range: Option<(u32, u32)>, - #[comptime] rank: u32, + index: usize, + #[comptime] range: Option<(usize, usize)>, + #[comptime] rank: usize, #[comptime] transform: Option, -) -> u32 { +) -> usize { match comptime![transform.clone()] { Some(Transform::Reshape(shape)) => { comptime![assert!( @@ -645,15 +645,15 @@ fn index_offset_with_layout( Some(Transform::SwapDims(dim1, dim2)) => { let (start, end) = comptime! {match range { Some(range) => range, - None => (0u32, rank), + None => (0, rank), }}; let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { - let index = comptime![swap_dims_transform(&i, (dim1, dim2))]; + let index = comptime![swap_dims_transform(i, (dim1, dim2))]; let ogwl = offset_ref / locals.ref_strides[i]; offset += ogwl % tensor.tensor.shape(index) * tensor.tensor.stride(index); } @@ -663,11 +663,11 @@ fn index_offset_with_layout( None => { let (start, end) = comptime! {match range { Some(range) => range, - None => (0u32, rank), + None => (0, rank), }}; let offset_ref = index * locals.ref_line_size; - let mut offset = 0u32; + let mut offset = 0; #[unroll] for i in start..end { @@ -680,10 +680,7 @@ fn index_offset_with_layout( } } -pub(crate) fn swap_dims_transform(i: &I, dims: (u32, u32)) -> u32 { - let i_cloned: I = i.clone(); - let i = i_cloned.value().as_const().unwrap().as_u32(); - +pub(crate) fn swap_dims_transform(i: usize, dims: (usize, usize)) -> usize { if i == dims.0 { dims.1 } else if i == dims.1 { @@ -699,18 +696,18 @@ pub(crate) fn swap_dims_transform(i: &I, dims: (u32, u32)) -> fn reshaped_index( inputs: &GlobalArgs, locals: &LocalArgs, - index: u32, - #[comptime] rank: u32, - #[comptime] shape: Sequence, -) -> u32 { - let mut offset = 0u32; - let mut stride_curr = 1u32; + index: usize, + #[comptime] rank: usize, + #[comptime] shape: Vec, +) -> usize { + let mut offset = 0; + let mut stride_curr = 1; #[unroll] for r in 0..rank { - let i = reverse_index(rank, r); - let arg = comptime![shape.index(i)]; - let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); + let i = reverse_index(rank, r).comptime(); + let arg = shape[i].clone(); + let shape_i = read_scalar_shape(inputs, arg); let ogwl = index / locals.ref_strides[i]; offset += ogwl % shape_i * stride_curr; @@ -726,9 +723,9 @@ fn reshaped_index( #[allow(clippy::clone_on_copy)] fn reshaped_index_to_original_index( original: &Tensor>, - index_reshaped: u32, - #[comptime] rank: u32, -) -> u32 { + index_reshaped: usize, + #[comptime] rank: usize, +) -> usize { let mut remaining = index_reshaped; let mut offset = 0; @@ -749,20 +746,18 @@ fn reshaped_index_to_original_index( #[cube] #[allow(unused_variables)] -pub(crate) fn reverse_index(#[comptime] rank: u32, iter: u32) -> comptime_type!(u32) { - intrinsic!(|_| { - let elem = iter.constant().map(|cons| cons.as_u32()).unwrap(); - rank - elem - 1 - }) +pub(crate) fn reverse_index( + #[comptime] rank: usize, + #[comptime] iter: usize, +) -> comptime_type!(usize) { + rank - iter - 1 } /// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. 
 /// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion.
 #[allow(unused_variables)]
 #[cube]
-fn from_const_int<C: CubePrimitive>(#[comptime] value: u32) -> C {
+fn from_const_int<C: CubePrimitive>(#[comptime] value: usize) -> C {
     intrinsic!(|scope| {
-        let constant: ExpandElement = value.into();
-        let constant_c = constant.as_const().unwrap().cast_to(C::as_type(scope));
         ExpandElement::Plain(Variable::constant(value.into(), C::as_type(scope))).into()
     })
 }
diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs
index 318d01f93c..d2be1e33da 100644
--- a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs
+++ b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs
@@ -3,7 +3,6 @@ use burn_std::quantization::{QuantScheme, QuantStore, QuantValue};
 use burn_std::{bf16, f16};
 use cubecl::ir::{ElemType, FloatKind, IntKind, StorageType, UIntKind};
 use cubecl::prelude::*;
-use cubecl::std::scalar::InputScalar;
 use serde::{Deserialize, Serialize};
 
 use super::tensor::GlobalTensor;
@@ -12,30 +11,30 @@ use super::tensor::GlobalTensor;
 /// Argument to a [fuse operation](FuseOp).
 pub enum FuseArg {
     /// A readonly input tensor.
-    Input(u32, FuseType, LayoutInfo),
+    Input(usize, FuseType, LayoutInfo),
     /// A temporary local variable.
-    Local(u32, FuseType),
+    Local(usize, FuseType),
     /// A readwrite output tensor.
-    Output(u32, FuseType, LayoutInfo),
+    Output(usize, FuseType, LayoutInfo),
     /// A global scalar.
-    Scalar(u32, FuseType),
+    Scalar(usize, FuseType),
     /// A global scalar used in a reshape operation.
     ///
     /// This is not a scalar defined by a user for computation, but a scalar defined as part of
     /// a reshape operation.
-    ScalarShape(u32),
+    ScalarShape(usize),
     /// Only constant that can be encoded into an u32 can be used as literal.
-    Literal(u32, FuseType),
+    Literal(usize, FuseType),
     /// A readonly input tensor that is reshaped.
     InputReshaped {
         original: Box<FuseArg>,
-        shape: Sequence<FuseArg>,
+        shape: Vec<FuseArg>,
         broadcasted: bool,
     },
     /// A readonly input tensor with swapped dimensions.
     InputSwapDims {
         original: Box<FuseArg>,
-        dims: (u32, u32),
+        dims: (usize, usize),
         broadcasted: bool,
     },
 }
@@ -128,13 +127,13 @@ pub enum FuseOp {
         input: FuseArg,
         indices: FuseArg,
         output: FuseArg,
-        dim: u32,
+        dim: usize,
     },
     Select {
         input: FuseArg,
         indices: FuseArg,
         output: FuseArg,
-        dim: u32,
+        dim: usize,
     },
     Dequantize {
         values: FuseArg,
@@ -197,7 +196,7 @@ impl FuseOp {
 pub struct GlobalArgs {
     pub tensors: Sequence<GlobalTensor>,
     pub scalars: Sequence<InputScalar>,
-    pub reshapes: Sequence<u32>,
+    pub reshapes: Sequence<usize>,
 }
 
 impl<R: Runtime> Default for GlobalArgsLaunch<'_, R> {
@@ -238,15 +237,15 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
             RefLayout::Virtual(layout) => match layout {
                 VirtualLayout::SwapDims(original, dims) => {
                     let mut shape = self.shape(original);
-                    shape.swap(dims.0 as usize, dims.1 as usize);
+                    shape.swap(dims.0, dims.1);
                     shape
                 }
                 VirtualLayout::Reshaped { reshape_pos, .. } => {
-                    let start = *reshape_pos as usize * rank;
+                    let start = *reshape_pos * rank;
                     let end = start + rank;
                     self.reshapes.values[start..end]
                         .iter()
-                        .map(|s| s.elem as usize)
+                        .map(|s| s.elem)
                         .collect()
                 }
                 VirtualLayout::Shape(original, _) => self.shape(original),
@@ -290,7 +289,7 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
     /// # Panics
     ///
     /// If the argument doesn't have an handle.
-    pub fn line_size(&self, arg: &FuseArg) -> u8 {
+    pub fn line_size(&self, arg: &FuseArg) -> LineSize {
         match self.resolve_arg(arg) {
             TensorArg::Handle { line_size, .. } => *line_size,
             TensorArg::Alias { .. } => panic!("Unsupported yet"),
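The `Reshaped` branch above reads a block's shape out of one flattened buffer of reshape scalars: block `reshape_pos` owns the `rank` entries starting at `reshape_pos * rank`. A small plain-Rust sketch of that lookup under the same assumed flattened layout (hypothetical helper, not the crate's API):

```rust
// All reshape shapes are stored back to back in a single flat buffer.
fn reshape_shape(reshapes: &[usize], reshape_pos: usize, rank: usize) -> &[usize] {
    let start = reshape_pos * rank;
    &reshapes[start..start + rank]
}

fn main() {
    // Two rank-3 reshape shapes stored contiguously.
    let reshapes = [2, 3, 4, 1, 6, 4];
    assert_eq!(reshape_shape(&reshapes, 0, 3), &[2, 3, 4][..]);
    assert_eq!(reshape_shape(&reshapes, 1, 3), &[1, 6, 4][..]);
}
```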
@@ -304,8 +303,8 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
     /// If the argument isn't a global input or output tensor.
     pub fn resolve_arg(&self, arg: &FuseArg) -> &TensorArg<'_, R> {
         match arg {
-            FuseArg::Input(pos, _, _) => &self.tensors.values[*pos as usize].tensor,
-            FuseArg::Output(pos, _, _) => &self.tensors.values[*pos as usize].tensor,
+            FuseArg::Input(pos, _, _) => &self.tensors.values[*pos].tensor,
+            FuseArg::Output(pos, _, _) => &self.tensors.values[*pos].tensor,
             other => panic!("Arg not found: {other:?}"),
         }
     }
@@ -315,47 +314,47 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
 /// Keep track of all local variables that are used as argument in fused
 /// [element wise operations](ElemwiseOp).
 pub struct LocalArgs {
-    pub l_f64: Registry<u32, Line<f64>>,
-    pub l_f32: Registry<u32, Line<f32>>,
-    pub l_f16: Registry<u32, Line<f16>>,
-    pub l_bf16: Registry<u32, Line<bf16>>,
-    pub l_i64: Registry<u32, Line<i64>>,
-    pub l_i32: Registry<u32, Line<i32>>,
-    pub l_i16: Registry<u32, Line<i16>>,
-    pub l_i8: Registry<u32, Line<i8>>,
-    pub l_u64: Registry<u32, Line<u64>>,
-    pub l_u32: Registry<u32, Line<u32>>,
-    pub l_u16: Registry<u32, Line<u16>>,
-    pub l_u8: Registry<u32, Line<u8>>,
-    pub l_bool: Registry<u32, Line<bool>>,
-    pub ref_shape: Slice<u32>,
-    pub ref_strides: Slice<u32>,
+    pub l_f64: Registry<usize, Line<f64>>,
+    pub l_f32: Registry<usize, Line<f32>>,
+    pub l_f16: Registry<usize, Line<f16>>,
+    pub l_bf16: Registry<usize, Line<bf16>>,
+    pub l_i64: Registry<usize, Line<i64>>,
+    pub l_i32: Registry<usize, Line<i32>>,
+    pub l_i16: Registry<usize, Line<i16>>,
+    pub l_i8: Registry<usize, Line<i8>>,
+    pub l_u64: Registry<usize, Line<u64>>,
+    pub l_u32: Registry<usize, Line<u32>>,
+    pub l_u16: Registry<usize, Line<u16>>,
+    pub l_u8: Registry<usize, Line<u8>>,
+    pub l_bool: Registry<usize, Line<bool>>,
+    pub ref_shape: Slice<usize>,
+    pub ref_strides: Slice<usize>,
     #[cube(comptime)]
-    pub ref_line_size: u32,
+    pub ref_line_size: LineSize,
 }
 
 #[cube]
 impl LocalArgs {
     /// Creates a new [LocalArgs] container.
     pub fn new(
-        ref_shape: Slice<u32>,
-        ref_strides: Slice<u32>,
-        #[comptime] ref_line_size: u32,
+        ref_shape: Slice<usize>,
+        ref_strides: Slice<usize>,
+        #[comptime] ref_line_size: LineSize,
     ) -> LocalArgs {
         LocalArgs {
-            l_f64: Registry::<u32, Line<f64>>::new(),
-            l_f32: Registry::<u32, Line<f32>>::new(),
-            l_f16: Registry::<u32, Line<f16>>::new(),
-            l_bf16: Registry::<u32, Line<bf16>>::new(),
-            l_i64: Registry::<u32, Line<i64>>::new(),
-            l_i32: Registry::<u32, Line<i32>>::new(),
-            l_i16: Registry::<u32, Line<i16>>::new(),
-            l_i8: Registry::<u32, Line<i8>>::new(),
-            l_u64: Registry::<u32, Line<u64>>::new(),
-            l_u32: Registry::<u32, Line<u32>>::new(),
-            l_u16: Registry::<u32, Line<u16>>::new(),
-            l_u8: Registry::<u32, Line<u8>>::new(),
-            l_bool: Registry::<u32, Line<bool>>::new(),
+            l_f64: Registry::<usize, Line<f64>>::new(),
+            l_f32: Registry::<usize, Line<f32>>::new(),
+            l_f16: Registry::<usize, Line<f16>>::new(),
+            l_bf16: Registry::<usize, Line<bf16>>::new(),
+            l_i64: Registry::<usize, Line<i64>>::new(),
+            l_i32: Registry::<usize, Line<i32>>::new(),
+            l_i16: Registry::<usize, Line<i16>>::new(),
+            l_i8: Registry::<usize, Line<i8>>::new(),
+            l_u64: Registry::<usize, Line<u64>>::new(),
+            l_u32: Registry::<usize, Line<u32>>::new(),
+            l_u16: Registry::<usize, Line<u16>>::new(),
+            l_u8: Registry::<usize, Line<u8>>::new(),
+            l_bool: Registry::<usize, Line<bool>>::new(),
             ref_shape,
             ref_strides,
             ref_line_size,
@@ -405,10 +404,10 @@ pub enum FuseType {
 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 /// Configuration that encapsulates all comptime information necessary for element wise fusion.
 pub struct FuseBlockConfig {
-    pub rank: u32,
+    pub rank: usize,
     pub ref_layout: RefLayout,
-    pub ops: Sequence<FuseOp>,
-    pub width: u8,
+    pub ops: Vec<FuseOp>,
+    pub width: LineSize,
 }
 
 #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
@@ -425,12 +424,15 @@ pub enum RefLayout {
 /// tensor with swap dimensions.
 pub enum VirtualLayout {
     /// Virtual tensor with the provided shape id and contiguous strides.
-    Reshaped { reshape_pos: u32, line_size: u32 },
+    Reshaped {
+        reshape_pos: usize,
+        line_size: LineSize,
+    },
     /// Virtual tensor with the same shape as the given input, but with swap dims and contiguous
     /// strides.
-    SwapDims(FuseArg, (u32, u32)),
+    SwapDims(FuseArg, (usize, usize)),
     /// Virtual tensor with the same shape as the given input, but with contiguous strides.
-    Shape(FuseArg, u32),
+    Shape(FuseArg, usize),
 }
 
 impl FuseArg {
diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs b/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs
index 7724083d13..e0cba172de 100644
--- a/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs
+++ b/crates/burn-cubecl-fusion/src/engine/codegen/kernel.rs
@@ -26,16 +26,16 @@ pub fn fuse_on_write<E: CubePrimitive>(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
+    write_pos: usize,
     write_values: Registry<FuseArg, Line<E>>,
-    #[comptime] write_args: Sequence<FuseArg>,
+    #[comptime] write_args: Vec<FuseArg>,
     #[comptime] config: &FuseBlockConfig,
 ) {
     // Write the values given as arguments.
     #[unroll]
-    for i in 0..write_args.len() {
-        let arg = comptime![write_args.index(i).clone()];
-        let val = write_values.find(comptime![arg.clone()]);
+    for _ in 0..write_args.len() {
+        let arg = comptime![write_args[0].clone()];
+        let val = write_values.find(arg.clone());
         write::<E>(inputs, outputs, locals, write_pos, val, arg, config);
     }
 
@@ -62,7 +62,7 @@ pub fn fuse_on_read<E: CubePrimitive>(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    read_pos: u32,
+    read_pos: usize,
     #[comptime] read_args: Sequence<FuseArg>,
     #[comptime] config: &FuseBlockConfig,
 ) -> Sequence<Line<E>> {
@@ -76,11 +76,11 @@ pub fn fuse_on_read<E: CubePrimitive>(
         let value = read::<E>(inputs, outputs, locals, read_pos, arg, config);
         let value_line_size = value.line_size();
-        let output_line_size = comptime!(config.width as u32);
+        let output_line_size = config.width;
 
         // We currently don't support broadcasting __across__ blocks.
         if comptime!(value_line_size != output_line_size) {
-            let mut tmp = Line::<E>::empty(comptime!(config.width as u32));
+            let mut tmp = Line::<E>::empty(config.width);
             comptime!(
                 assert_eq!(value_line_size, 1, "The input line_size must be 1 or the same as the config width.");
             );
 
             let val = value[0];
 
             #[unroll]
-            for i in 0..comptime!(config.width as u32) {
+            for i in 0..config.width {
                 tmp[i] = val;
             }
 
@@ -116,7 +116,7 @@ pub fn init_locals(
     let mut ref_shape = Array::new(config.rank);
     let mut ref_strides = Array::new(config.rank);
 
-    match comptime![config.ref_layout.clone()] {
+    match config.ref_layout.clone() {
         RefLayout::Concrete(arg) => match comptime![arg] {
             FuseArg::Input(index, ..) => {
                 let layout = inputs.tensors.index(index);
@@ -152,23 +152,23 @@ pub fn init_locals(
         },
         RefLayout::Virtual(layout) => match layout {
             VirtualLayout::SwapDims(original, dims) => {
-                let layout = match comptime![original.clone()] {
+                let layout = match original.clone() {
                     FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
                    FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
                     _ => comptime![panic!("Unsupported")],
                 };
 
-                let mut stride_curr = 1u32;
+                let mut stride_curr = 1;
 
                 #[unroll]
                 #[allow(clippy::clone_on_copy)]
                 for i in 0..config.rank {
                     let reverse = reverse_index(config.rank, i);
-                    let swap = comptime![swap_dims_transform(comptime![&reverse], dims)];
-                    let shape = layout.tensor.shape(comptime![swap.clone()]);
+                    let swap = comptime![swap_dims_transform(reverse, dims)];
+                    let shape = layout.tensor.shape(swap.clone());
 
-                    ref_shape[comptime![reverse.clone()]] = shape;
-                    ref_strides[comptime![reverse.clone()]] = stride_curr;
+                    ref_shape[reverse] = shape;
+                    ref_strides[reverse] = stride_curr;
 
                     stride_curr *= ref_shape[comptime![reverse]];
                 }
@@ -183,7 +183,7 @@ pub fn init_locals(
                 reshape_pos,
                 line_size,
             } => {
-                let mut stride_curr = 1u32;
+                let mut stride_curr = 1;
                 let start = reshape_pos * config.rank;
 
                 #[unroll]
@@ -191,7 +191,7 @@ pub fn init_locals(
                 for i in 0..config.rank {
                     let reverse = reverse_index(config.rank, i);
                     let arg = comptime![FuseArg::ScalarShape(start + reverse)];
-                    let shape = read_scalar_shape(inputs, comptime![arg.clone()]);
+                    let shape = read_scalar_shape(inputs, arg.clone());
 
                     ref_shape[comptime![reverse]] = shape;
                     ref_strides[comptime![reverse]] = stride_curr;
@@ -202,12 +202,12 @@ pub fn init_locals(
                 LocalArgs::new(ref_shape.to_slice(), ref_strides.to_slice(), line_size)
             }
             VirtualLayout::Shape(original, line_size) => {
-                let layout = match comptime![original.clone()] {
+                let layout = match original.clone() {
                     FuseArg::Input(pos, ..) => inputs.tensors.index(pos),
                     FuseArg::Output(pos, ..) => outputs.tensors.index(pos),
                     _ => comptime![panic!("Unsupported")],
                 };
-                let mut stride_curr = 1u32;
+                let mut stride_curr = 1;
 
                 #[unroll]
                 #[allow(clippy::clone_on_copy)]
@@ -233,13 +233,13 @@ fn fuse<BT: CubePrimitive>(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    pos: u32,
+    pos: usize,
     #[comptime] config: &FuseBlockConfig,
 ) {
     #[unroll]
     for index in 0..config.ops.len() {
-        let op = comptime! { config.ops.index(index).clone() };
-        set_polyfill::<NumericExpand<DYN_ELEM_ID>>(comptime![op.cmp_type()]);
+        let op = config.ops[index].clone();
+        set_polyfill::<NumericExpand<DYN_ELEM_ID>>(op.cmp_type());
 
         match op {
             FuseOp::Add(op) => {
@@ -367,7 +367,7 @@ macro_rules! binary_op {
             inputs: &GlobalArgs,
             outputs: &mut GlobalArgs,
             locals: &mut LocalArgs,
-            write_pos: u32,
+            write_pos: usize,
             #[comptime] op: BinaryFuseArgs,
             #[comptime] config: &FuseBlockConfig,
         ) {
@@ -387,7 +387,7 @@ macro_rules! binary_func {
             inputs: &GlobalArgs,
             outputs: &mut GlobalArgs,
             locals: &mut LocalArgs,
-            write_pos: u32,
+            write_pos: usize,
             #[comptime] op: BinaryFuseArgs,
             #[comptime] config: &FuseBlockConfig,
         ) {
@@ -407,7 +407,7 @@ macro_rules! comparison_op {
             inputs: &GlobalArgs,
             outputs: &mut GlobalArgs,
             locals: &mut LocalArgs,
-            write_pos: u32,
+            write_pos: usize,
             #[comptime] op: BinaryFuseArgs,
             #[comptime] config: &FuseBlockConfig,
         ) {
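All three virtual-layout arms of `init_locals` above derive contiguous strides the same way: walk the axes innermost first and keep a running product, as in `stride_curr *= ref_shape[reverse]`. A standalone plain-Rust sketch of that construction (not the crate's API):

```rust
// Contiguous (row-major) strides: each stride is the product of the
// shapes of all axes to its right.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![0; rank];
    let mut stride_curr = 1;
    for i in 0..rank {
        let reverse = rank - i - 1; // same arithmetic as reverse_index(rank, i)
        strides[reverse] = stride_curr;
        stride_curr *= shape[reverse];
    }
    strides
}

fn main() {
    assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
}
```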
@@ -427,7 +427,7 @@ macro_rules! unary_func {
             inputs: &GlobalArgs,
             outputs: &mut GlobalArgs,
             locals: &mut LocalArgs,
-            write_pos: u32,
+            write_pos: usize,
             #[comptime] op: UnaryFuseArgs,
             #[comptime] config: &FuseBlockConfig,
         ) {
@@ -444,7 +444,7 @@ fn assign(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
+    write_pos: usize,
     #[comptime] op: UnaryFuseArgs,
     #[comptime] config: &FuseBlockConfig,
 ) {
@@ -458,8 +458,8 @@ fn gather(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
-    #[comptime] dim: u32,
+    write_pos: usize,
+    #[comptime] dim: usize,
     #[comptime] input: FuseArg,
     #[comptime] indices: FuseArg,
     #[comptime] output: FuseArg,
@@ -482,7 +482,7 @@ fn gather(
     let stride_input_dim = global_stride(inputs, dim, pos_input);
 
-    let mut index = 0u32;
+    let mut index = 0;
     let mut result = Line::empty(line_size);
 
     if comptime![dim > 0] {
@@ -491,8 +491,8 @@ fn gather(
             outputs,
             locals,
             write_pos,
-            comptime!(input.clone()),
-            comptime![Some((0u32, dim))],
+            input.clone(),
+            comptime![Some((0, dim))],
             config,
         );
         index += index_before;
@@ -517,7 +517,7 @@ fn gather(
             locals,
             write_pos,
             indices,
-            comptime![Some((0u32, config.rank))],
+            comptime![Some((0, config.rank))],
             config,
         );
 
@@ -539,7 +539,7 @@ fn gather(
                 inputs,
                 locals,
                 pos_input,
-                index + (offset[0] * stride_input_dim),
+                index + (offset[0] as usize * stride_input_dim),
                 LayoutInfo::IsRef,
                 config,
                 None,
@@ -549,7 +549,7 @@ fn gather(
         }
     } else {
         // Shared index for whole line
-        let stride_input_line = global_stride(inputs, comptime!(config.rank - 1), pos_input);
+        let stride_input_line = global_stride(inputs, config.rank - 1, pos_input);
 
         let offset = read_input::<u32>(
             inputs,
@@ -561,7 +561,7 @@ fn gather(
             None,
         );
 
-        index += offset[0] * stride_input_dim;
+        index += offset[0] as usize * stride_input_dim;
 
         #[unroll]
         for i in 0..line_size {
@@ -587,8 +587,8 @@ fn select_indices(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
-    #[comptime] dim: u32,
+    write_pos: usize,
+    #[comptime] dim: usize,
     #[comptime] input: FuseArg,
     #[comptime] indices: FuseArg,
     #[comptime] output: FuseArg,
@@ -613,7 +613,7 @@ fn select_indices(
     let stride_input_dim = global_stride(inputs, dim, pos_input);
 
-    let mut index = 0u32;
+    let mut index = 0;
     let mut result = Line::empty(line_size_ref);
 
     if comptime![dim != config.rank - 1] {
@@ -627,8 +627,8 @@ fn select_indices(
             outputs,
             locals,
             write_pos,
-            comptime!(input.clone()),
-            comptime![Some((0u32, dim))],
+            input.clone(),
+            comptime![Some((0, dim))],
             config,
         );
         index += index_before;
@@ -640,7 +640,7 @@ fn select_indices(
             outputs,
             locals,
             write_pos,
-            comptime!(input.clone()),
+            input.clone(),
             comptime![Some((dim + 1, config.rank))],
             config,
         );
@@ -660,7 +660,7 @@ fn select_indices(
             None,
         );
 
-        index += offset_dim[0] * stride_input_dim;
+        index += offset_dim[0] as usize * stride_input_dim;
 
         #[unroll]
         for i in 0..line_size_ref {
@@ -687,8 +687,8 @@ fn select_indices(
             outputs,
             locals,
             write_pos,
-            comptime!(input.clone()),
-            comptime![Some((0u32, dim))],
+            input.clone(),
+            comptime![Some((0, dim))],
             config,
         );
         index += index_before;
@@ -726,7 +726,7 @@ fn select_indices(
                 inputs,
                 locals,
                 pos_input,
-                index + (offset_dim[0] * stride_input_dim),
+                index + (offset_dim[0] as usize * stride_input_dim),
                 LayoutInfo::IsRef,
                 config,
                 None,
@@ -743,7 +743,7 @@ fn conditional_assign(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
+    write_pos: usize,
     #[comptime] cond: FuseArg,
     #[comptime] lhs: FuseArg,
     #[comptime] rhs: FuseArg,
@@ -763,7 +763,7 @@ fn clamp(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
+    write_pos: usize,
     #[comptime] input: FuseArg,
     #[comptime] min: FuseArg,
     #[comptime] max: FuseArg,
@@ -784,7 +784,7 @@ fn dequantize(
     inputs: &GlobalArgs,
     outputs: &mut GlobalArgs,
     locals: &mut LocalArgs,
-    write_pos: u32,
+    write_pos: usize,
     #[comptime] input: FuseArg,
     #[comptime] scales: FuseArg,
     #[comptime] output: FuseArg,
@@ -834,7 +834,7 @@ fn dequantize(
     );
 
     let line_size = input.line_size();
-    let num_quants = comptime!(scheme.num_quants() as u32);
+    let num_quants = scheme.num_quants();
 
     let scales = input_as_scales_view::<Line<f32>>(
         inputs,
@@ -856,22 +856,15 @@ fn dequantize(
     } else {
         let mut line = Line::empty(line_size_result);
 
-        // We have to do all index work as comptime because higher line sizes removes the
-        // possibility to index dynamically on lines.
-        let mut i = comptime!(0);
-        #[unroll]
-        for _ in 0..line_size {
-            let mut j = comptime!(0);
+        for i in 0..line_size {
             let value = result[i];
 
             #[unroll]
-            for _ in 0..num_quants {
-                let index = comptime!(i * num_quants + j);
+            for j in 0..num_quants {
+                let index = i * num_quants + j;
                 line[index] = value[j];
-                comptime!(j += 1);
             }
-            comptime!(i += 1);
         }
 
         line
diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/view.rs b/crates/burn-cubecl-fusion/src/engine/codegen/view.rs
index 4fc2143c8a..4d6da2c31a 100644
--- a/crates/burn-cubecl-fusion/src/engine/codegen/view.rs
+++ b/crates/burn-cubecl-fusion/src/engine/codegen/view.rs
@@ -24,7 +24,7 @@ pub struct GlobalInput<E: CubePrimitive> {
     inputs: GlobalArgs,
     locals: LocalArgs,
     #[cube(comptime)]
-    pos: u32,
+    pos: usize,
     #[cube(comptime)]
     ty: StorageType,
     #[cube(comptime)]
@@ -67,7 +67,7 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_read_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
     ) -> <E as CubeType>::ExpandType {
         ViewOperationsExpand::<E, Coords1d>::__expand_read_unchecked_method(self, scope, pos)
     }
@@ -76,7 +76,7 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_read_checked_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
     ) -> <E as CubeType>::ExpandType {
         let zero = E::__expand_cast_from(scope, 0.into());
         ViewOperationsExpand::<E, Coords1d>::__expand_read_masked_method(self, scope, pos, zero)
@@ -86,7 +86,7 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_read_masked_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
         value: <E as CubeType>::ExpandType,
     ) -> <E as CubeType>::ExpandType {
         let in_bounds = ViewOperationsExpand::<E, Coords1d>::__expand_is_in_bounds_method(
@@ -103,7 +103,7 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_read_unchecked_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
     ) -> <E as CubeType>::ExpandType {
         let value = read_input::expand::<E>(
             scope,
@@ -122,8 +122,8 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_to_linear_slice_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
-        end: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
+        end: ExpandElementTyped<usize>,
     ) -> SliceExpand<E, ReadOnly> {
         scope.register_type::<NumericExpand<DYN_ELEM_ID>>(self.ty);
         let end = add::expand(scope, end.clone(), 1.into());
@@ -136,13 +136,13 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
         _scope: &mut Scope,
         _barrier: BarrierExpand,
         _shared_memory: SliceExpand<E, ReadWrite>,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) {
         panic!("Not a tensor map")
     }
 
     #[allow(clippy::too_many_arguments)]
-    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
+    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
         global_buffer_len::expand(scope, self.inputs.clone(), self.pos)
     }
 
@@ -150,7 +150,7 @@ impl<E: CubePrimitive> ViewOperationsExpand<E, Coords1d> for GlobalInputExpand<E> {
     fn __expand_is_in_bounds_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
     ) -> ExpandElementTyped<bool> {
         let buffer_len = global_buffer_len::expand(scope, self.inputs.clone(), self.pos);
         lt::expand(scope, pos, buffer_len)
@@ -159,7 +159,7 @@ impl<E: CubePrimitive> Lined for GlobalInput<E> {}
 impl<E: CubePrimitive> LinedExpand for GlobalInputExpand<E> {
-    fn line_size(&self) -> u32 {
+    fn line_size(&self) -> LineSize {
         let mut temp_scope = Scope::root(false);
         global_line_size::expand(&mut temp_scope, self.inputs.clone(), self.pos)
     }
@@ -201,7 +201,7 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_read_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) -> <Line<E> as CubeType>::ExpandType {
         todo!()
     }
@@ -210,7 +210,7 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_read_checked_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) -> <Line<E> as CubeType>::ExpandType {
         todo!()
     }
@@ -219,7 +219,7 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_read_masked_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
         _value: <Line<E> as CubeType>::ExpandType,
     ) -> <Line<E> as CubeType>::ExpandType {
         todo!()
@@ -229,7 +229,7 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_read_unchecked_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) -> <Line<E> as CubeType>::ExpandType {
         todo!()
     }
@@ -238,8 +238,8 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_to_linear_slice_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
-        _size: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
+        _size: ExpandElementTyped<usize>,
     ) -> SliceExpand<Line<E>, ReadOnly> {
         todo!()
     }
@@ -250,13 +250,13 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
         _scope: &mut Scope,
         _barrier: BarrierExpand,
         _shared_memory: SliceExpand<Line<E>, ReadWrite>,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) {
         panic!("Not a tensor map")
     }
 
     #[allow(clippy::too_many_arguments)]
-    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<u32> {
+    fn __expand_shape_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
         ref_len::expand(
             scope,
             self.inputs.clone(),
@@ -270,7 +270,7 @@ impl<E: Numeric> ViewOperationsExpand<Line<E>, Coords1d> for FusedOutputEx
     fn __expand_is_in_bounds_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
     ) -> ExpandElementTyped<bool> {
         let buffer_len = ref_buffer_len::expand(
             scope,
@@ -289,11 +289,11 @@ impl<E: Numeric> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutpu
     fn __expand_write_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
         value: <Line<E> as CubeType>::ExpandType,
     ) {
         let values = Registry::<FuseArg, Line<E>>::__expand_new(scope);
-        let mut args = comptime![Sequence::<FuseArg>::new()];
+        let mut args = comptime![Vec::<FuseArg>::new()];
 
         values
             .clone()
@@ -316,7 +316,7 @@ impl<E: Numeric> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutpu
     fn __expand_write_checked_method(
         &self,
         scope: &mut Scope,
-        pos: ExpandElementTyped<u32>,
+        pos: ExpandElementTyped<usize>,
         value: <Line<E> as CubeType>::ExpandType,
     ) {
         let in_bounds = ViewOperationsExpand::<Line<E>, Coords1d>::__expand_is_in_bounds_method(
@@ -333,8 +333,8 @@ impl<E: Numeric> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutpu
     fn __expand_to_linear_slice_mut_method(
         &self,
         _scope: &mut Scope,
-        _pos: ExpandElementTyped<u32>,
-        _size: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
+        _size: ExpandElementTyped<usize>,
     ) -> SliceExpand<Line<E>, ReadWrite> {
         todo!("Not yet supported")
     }
@@ -344,7 +344,7 @@ impl<E: Numeric> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutpu
         &self,
         _scope: &mut Scope,
         _shared_memory: SliceExpand<Line<E>, ReadOnly>,
-        _pos: ExpandElementTyped<u32>,
+        _pos: ExpandElementTyped<usize>,
     ) {
         panic!("Not a tensor map")
     }
@@ -352,7 +352,7 @@ impl<E: Numeric> ViewOperationsMutExpand<Line<E>, Coords1d> for FusedOutpu
 impl<E: Numeric> Lined for FusedOutput<E> {}
 impl<E: Numeric> LinedExpand for FusedOutputExpand<E> {
-    fn line_size(&self) -> u32 {
+    fn line_size(&self) -> LineSize {
         self.locals.ref_line_size
     }
 }
diff --git a/crates/burn-cubecl-fusion/src/engine/fuser.rs b/crates/burn-cubecl-fusion/src/engine/fuser.rs
index 1e18d1f2f1..2f117dde30 100644
--- a/crates/burn-cubecl-fusion/src/engine/fuser.rs
+++ b/crates/burn-cubecl-fusion/src/engine/fuser.rs
@@ -276,11 +276,7 @@ impl TraceOperationFuser {
         }
 
         if self.fuser.fuse(|fuser| {
-            fuser.input_swap_dims(
-                &desc.input,
-                &desc.out,
-                (desc.dim1 as u32, desc.dim2 as u32),
-            )?;
+            fuser.input_swap_dims(&desc.input, &desc.out, (desc.dim1, desc.dim2))?;
 
             Some(())
         }) {
@@ -364,7 +360,7 @@ impl TraceOperationFuser {
                 input,
                 indices,
                 output,
-                dim: desc.dim as u32,
+                dim: desc.dim,
             });
 
             Some(())
@@ -384,7 +380,7 @@ impl TraceOperationFuser {
                 input,
                 indices,
                 output,
-                dim: desc.dim as u32,
+                dim: desc.dim,
             });
 
             Some(())
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs
index 850f1cdace..c7d30e922a 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs
@@ -16,8 +16,7 @@ use burn_std::DType;
 use cubecl::{
     CubeElement, Runtime,
     client::ComputeClient,
-    prelude::{ScalarArg, Sequence, TensorArg},
-    std::scalar::InputScalar,
+    prelude::{InputScalar, ScalarArg, TensorArg},
 };
 use std::marker::PhantomData;
 
@@ -87,15 +86,15 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
             let reference = match block_plan.reference {
                 ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout),
                ReferenceSelection::VirtualShape { original, .. } => {
-                    RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width as u32))
+                    RefLayout::Virtual(VirtualLayout::Shape(original, block_plan.width))
                 }
                 ReferenceSelection::SwapDims { original, dims } => {
                     RefLayout::Virtual(VirtualLayout::SwapDims(original, dims))
                 }
                 ReferenceSelection::Reshaped { reshape_pos } => {
                     RefLayout::Virtual(VirtualLayout::Reshaped {
-                        reshape_pos: reshape_pos as u32,
-                        line_size: block_plan.width as u32,
+                        reshape_pos,
+                        line_size: block_plan.width,
                     })
                 }
                 ReferenceSelection::NotFound | ReferenceSelection::Searching => {
@@ -107,7 +106,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
                 }
             };
 
-            let mut ops = Sequence::<FuseOp>::new();
+            let mut ops = Vec::<FuseOp>::new();
 
             for read_ops in block_plan.reads.into_values() {
                 for op in read_ops {
@@ -124,7 +123,7 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
             }
 
             let config = FuseBlockConfig {
-                rank: plan.rank as u32,
+                rank: plan.rank,
                 ref_layout: reference,
                 ops,
                 width: block_plan.width,
@@ -153,7 +152,7 @@ fn register_inputs<'h, R: Runtime>(
             HandleInput::Normal(hi) => {
                 let arg = hi
                     .handle
-                    .as_tensor_arg(&hi.global_ir.shape.dims, hi.vectorization);
+                    .as_tensor_arg(&hi.global_ir.shape.dims, hi.line_size);
                 inputs.tensors.push(GlobalTensorArg::new(
                     arg,
                     hi.precision.into_elem(),
@@ -163,7 +162,7 @@ fn register_inputs<'h, R: Runtime>(
             HandleInput::QuantValues(hi) => {
                 let arg = hi
                     .handle
-                    .as_tensor_arg(&hi.global_ir.shape.dims, hi.vectorization);
+                    .as_tensor_arg(&hi.global_ir.shape.dims, hi.line_size);
                 inputs
                     .tensors
                    .push(GlobalTensorArg::new(arg, hi.precision.into_elem(), false));
@@ -209,12 +208,12 @@ fn register_outputs<'s, BT: CubeElement, R: Runtime>(
                 precision,
                 handle,
                 global_shape,
-                vectorization,
+                vectorization: line_size,
                 #[cfg(feature = "autotune-checks")]
                 relative_id,
                 ..
            } => {
-                let arg = handle.as_tensor_arg(global_shape, *vectorization);
+                let arg = handle.as_tensor_arg(global_shape, *line_size);
 
                 let elem = match precision {
                     FuseType::Bool => match elem_dtype::<BT>() {
@@ -269,7 +268,7 @@ fn register_scalars<'h, R: Runtime>(
             let global = context.tensors.get(reshaped).unwrap();
 
             for shape in global.shape.iter() {
-                inputs.reshapes.push(ScalarArg::new(*shape as u32));
+                inputs.reshapes.push(ScalarArg::new(*shape));
             }
         }
     }
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/input.rs b/crates/burn-cubecl-fusion/src/engine/launch/input.rs
index f712d95f6f..549e554ed6 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/input.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/input.rs
@@ -84,7 +84,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> {
                 global_ir: tensor_global,
                 precision,
                 handle,
-                vectorization: 1,
+                line_size: 1,
             }));
 
             plan.handle_inputs
@@ -208,7 +208,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> {
                     && shape_relative == &block.shape_ref
                 {
                     block_plan.potential_reference_input = Some(InputReference::Reshaped {
-                        reshape_pos: *reshape_pos as usize,
+                        reshape_pos: *reshape_pos,
                     });
                 }
                 return true;
@@ -225,11 +225,7 @@ impl<'a, R: Runtime> InputPlanner<'a, R> {
             }
 
             if original == &tensor_relative.id {
-                let shape = tensor_relative
-                    .shape
-                    .clone()
-                    .swap(dims.0 as usize, dims.1 as usize)
-                    .unwrap();
+                let shape = tensor_relative.shape.clone().swap(dims.0, dims.1).unwrap();
 
                 if block_plan.potential_reference_input.is_none()
                     && shape.dims == block.shape_ref
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/output.rs b/crates/burn-cubecl-fusion/src/engine/launch/output.rs
index deefeaa1a5..312c4a716e 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/output.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/output.rs
@@ -198,7 +198,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
         let set_ref_as_concrete = |block: &mut BlockPlan<'_>| {
             block.reference = ReferenceSelection::Concrete {
                 layout: FuseArg::Input(
-                    input_pos as u32,
+                    input_pos,
                     reference.precision,
                     LayoutInfo::IsRef,
                 ),
@@ -210,7 +210,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
         let set_ref_as_virtual = |block: &mut BlockPlan<'_>| {
             block.reference = ReferenceSelection::VirtualShape {
                 original: FuseArg::Input(
-                    input_pos as u32,
+                    input_pos,
                     reference.precision,
                     LayoutInfo::Unknown,
                 ),
@@ -243,7 +243,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
                     .expect("Quant can't be used in swap dims operation");
                 block.reference = ReferenceSelection::SwapDims {
                     original: FuseArg::Input(
-                        original_pos as u32,
+                        original_pos,
                         reference.precision,
                         LayoutInfo::Unknown,
                     ),
@@ -407,11 +407,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
             && self.blocks[block_idx].shape_ref == output.tensor_relative.shape.dims
         {
             block.reference = ReferenceSelection::Concrete {
-                layout: FuseArg::Output(
-                    output.pos_original as u32,
-                    output.precision,
-                    LayoutInfo::IsRef,
-                ),
+                layout: FuseArg::Output(output.pos_original, output.precision, LayoutInfo::IsRef),
                 shape: tensor_global.shape.dims.clone(),
                 strides: strides.clone(),
             };
@@ -553,7 +549,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
         output: OutputSorted,
         tensor_global: TensorIr,
         original: TensorId,
-        dims: (u32, u32),
+        dims: (usize, usize),
         block_idx: usize,
     ) {
         let block = &mut plan.blocks[block_idx];
@@ -578,7 +574,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> {
             dtype,
             qparams: original_handle.handle.qparams.clone(),
         };
-        handle.strides.swap(dims.0 as usize, dims.1 as usize);
+        handle.strides.swap(dims.0, dims.1);
 
         context
             .handles
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs
index aae753e130..7f359be83b 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs
@@ -7,7 +7,7 @@ use crate::{
     },
 };
 use burn_ir::{TensorId, TensorIr};
-use cubecl::Runtime;
+use cubecl::{Runtime, ir::LineSize};
 use std::collections::BTreeMap;
 
 /// The plan is responsible to keep runtime information related to the launch of a fused kernel
@@ -30,7 +30,7 @@ pub struct BlockPlan<'a> {
     pub reference: ReferenceSelection,
     pub reads: BTreeMap<TensorId, Vec<FuseOp>>,
     pub writes: BTreeMap<TensorId, FuseOp>,
-    pub width: u8,
+    pub width: LineSize,
 }
 
 #[derive(Debug)]
@@ -40,7 +40,7 @@ pub enum InputReference {
     },
     SwapDims {
         original_pos: usize,
-        dims: (u32, u32),
+        dims: (usize, usize),
     },
     Reshaped {
         reshape_pos: usize,
@@ -59,7 +59,7 @@ pub enum ReferenceSelection {
     },
     SwapDims {
         original: FuseArg,
-        dims: (u32, u32),
+        dims: (usize, usize),
     },
     Reshaped {
         reshape_pos: usize,
@@ -137,7 +137,7 @@ pub enum HandleOutput {
         precision: FuseType,
         handle: CubeFusionHandle<R>,
         global_shape: Vec<usize>,
-        vectorization: u8,
+        vectorization: LineSize,
     },
 }
 
@@ -147,7 +147,7 @@ pub struct NormalHandleInput {
     pub global_ir: TensorIr,
     pub precision: FuseType,
     pub handle: CubeFusionHandle<R>,
-    pub vectorization: u8,
+    pub line_size: LineSize,
     pub broadcated: bool,
     // Strides can be modified during plan execution, but need to be restored on rollback
     pub orig_strides: Vec<usize>,
@@ -159,7 +159,7 @@ pub struct QuantValuesHandleInput {
     pub global_ir: TensorIr,
     pub precision: FuseType,
     pub handle: CubeFusionHandle<R>,
-    pub vectorization: u8,
+    pub line_size: LineSize,
 }
 
 #[derive(Debug)]
@@ -200,7 +200,7 @@ impl NormalHandleInput {
             handle,
             relative_id: tensor_relative.id,
             global_ir: tensor_global,
-            vectorization: 1,
+            line_size: 1,
             broadcated: false,
             orig_strides: strides,
         }
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/runner.rs b/crates/burn-cubecl-fusion/src/engine/launch/runner.rs
index 0f53dd8970..e08d243ef1 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/runner.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/runner.rs
@@ -76,9 +76,9 @@ pub trait Vectorization<R: Runtime> {
         inputs: impl Iterator<Item = &'a CubeFusionHandle<R>>,
         outputs: impl Iterator<Item = &'a TensorIr>,
         reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
-        swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (u32, u32))>,
-        line_sizes: &[u8],
-        max: u8,
+        swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
+        line_sizes: &[LineSize],
+        max: LineSize,
         axis: VectorizationAxis,
     ) {
         vectorization_default(
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs
index da33e5598e..c0a45d949a 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/base.rs
@@ -4,18 +4,18 @@ use crate::{
 };
 use burn_fusion::stream::Context;
 use burn_ir::{TensorId, TensorIr};
-use cubecl::Runtime;
+use cubecl::{Runtime, ir::LineSize};
 use serde::{Deserialize, Serialize};
 use std::collections::BTreeMap;
 
 #[derive(Debug, Clone, Copy)]
 pub enum Vect {
     Broadcasted,
-    Aligned(u8),
+    Aligned(LineSize),
 }
 
 impl Vect {
-    pub fn line_size(&self) -> u8 {
+    pub fn line_size(&self) -> LineSize {
         match self {
             Vect::Broadcasted => 1,
             Vect::Aligned(val) => *val,
@@ -29,13 +29,13 @@ impl Vect {
 
 #[derive(Default, Clone, Serialize, Deserialize, Debug)]
 pub struct LineSizeOverrides {
-    state: Option<BTreeMap<TensorId, Vec<u8>>>,
-    default: Option<Vec<u8>>,
+    state: Option<BTreeMap<TensorId, Vec<LineSize>>>,
+    default: Option<Vec<LineSize>>,
 }
 
 #[allow(unused)]
 impl LineSizeOverrides {
-    pub fn overrides(&mut self, tensor_id: &TensorId, line_sizes: Vec<u8>) {
+    pub fn overrides(&mut self, tensor_id: &TensorId, line_sizes: Vec<LineSize>) {
         let map = match &mut self.state {
             Some(val) => val,
             None => {
@@ -46,7 +46,7 @@ impl LineSizeOverrides {
         map.insert(*tensor_id, line_sizes);
     }
 
-    pub fn overrides_default(&mut self, line_sizes: Vec<u8>) {
+    pub fn overrides_default(&mut self, line_sizes: Vec<LineSize>) {
         self.default = Some(line_sizes);
     }
 
@@ -72,7 +72,7 @@ impl LineSizeOverrides {
         }
     }
 
-    pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec<u8>> {
+    pub fn tensor(&self, tensor_id: &TensorId) -> Option<&Vec<LineSize>> {
         let map = match &self.state {
             Some(val) => val,
             None => match &self.default {
@@ -97,10 +97,10 @@ pub(crate) fn vectorization_default<'a, R: Runtime>(
     inputs: impl Iterator<Item = &'a CubeFusionHandle<R>>,
     outputs: impl Iterator<Item = &'a TensorIr>,
     reshaped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool)>,
-    swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (u32, u32))>,
-    line_sizes: &[u8],
+    swapped: impl Iterator<Item = (&'a TensorIr, &'a TensorIr, bool, &'a (usize, usize))>,
+    line_sizes: &[LineSize],
     overrides: &LineSizeOverrides,
-    max: u8,
+    max: LineSize,
     axis: &VectorizationAxis,
 ) {
     let swapped: Vec<_> = swapped.collect();
@@ -148,7 +148,7 @@ pub(crate) fn vectorization_default<'a, R: Runtime>(
                     overrides.tensor(&tensor_ir.id),
                 );
                 let num_quants = match tensor_ir.dtype {
-                    burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants() as u8,
+                    burn_std::DType::QFloat(quant_scheme) => quant_scheme.num_quants(),
                     _ => panic!(""),
                 };
                 let val = match val {
@@ -214,8 +214,8 @@ fn vectorization_input(
     handle: &CubeFusionHandle<R>,
     desc: &TensorIr,
     axis: &VectorizationAxis,
-    line_sizes: &[u8],
-    overrides: Option<&Vec<u8>>,
+    line_sizes: &[LineSize],
+    overrides: Option<&Vec<LineSize>>,
 ) -> Vect {
     let axis = axis.get(desc.id, || handle.strides.len() - 1);
     let shape_axis = desc.shape[axis];
@@ -229,9 +229,9 @@ fn vectorization_input(
         return Vect::Aligned(1);
     }
 
-    let inner = |s: u8| {
+    let inner = |s: LineSize| {
         // The last dimension should be a multiple of the vector size or broadcated.
-        if shape_axis.is_multiple_of(s as usize) {
+        if shape_axis.is_multiple_of(s) {
             return Some(Vect::Aligned(s));
         }
         None
@@ -260,15 +260,15 @@ fn vectorization_output(
     desc: &TensorIr,
     axis: &VectorizationAxis,
-    line_sizes: &[u8],
-    max: u8,
-    overrides: Option<&Vec<u8>>,
+    line_sizes: &[LineSize],
+    max: LineSize,
+    overrides: Option<&Vec<LineSize>>,
 ) -> Vect {
     let axis = axis.get(desc.id, || desc.shape.rank() - 1);
 
-    let inner = |s: u8| {
+    let inner = |s: LineSize| {
         // The dimension should be a multiple of the vector size.
-        if desc.shape[axis].is_multiple_of(s as usize) && s <= max {
+        if desc.shape[axis].is_multiple_of(s) && s <= max {
             return Some(Vect::Aligned(s));
         }
 
@@ -299,9 +299,9 @@ fn vectorization_reshape(
     original: &TensorIr,
     multi_reads: bool,
     axis: &VectorizationAxis,
-    line_sizes: &[u8],
-    max: u8,
-    overrides: Option<&Vec<u8>>,
+    line_sizes: &[LineSize],
+    max: LineSize,
+    overrides: Option<&Vec<LineSize>>,
 ) -> Vect {
     let axis = axis.get(reshaped.id, || reshaped.shape.rank() - 1);
     let reshape_shape_axis = reshaped.shape[axis];
@@ -321,10 +321,10 @@ fn vectorization_reshape(
         return Vect::Aligned(1);
     }
 
-    let inner = |s: u8| {
+    let inner = |s: LineSize| {
         if !multi_reads {
             // The last dimension should be a multiple of the vector size or broadcated.
-            if reshape_shape_axis.is_multiple_of(s as usize) && s <= max {
+            if reshape_shape_axis.is_multiple_of(s) && s <= max {
                 Some(Vect::Aligned(s))
             } else {
                 None
             }
         } else {
             // Since the original tensor must share the same vectorization factor as the
             // reshaped tensor, they must have compatible shapes when both are access
             // independently.
-            if reshape_shape_axis.is_multiple_of(s as usize)
-                && original_shape_axis.is_multiple_of(s as usize)
+            if reshape_shape_axis.is_multiple_of(s)
+                && original_shape_axis.is_multiple_of(s)
                 && s <= max
             {
                 Some(Vect::Aligned(s))
@@ -370,11 +370,11 @@ fn vectorization_swapped(
     swapped: &TensorIr,
     original: &TensorIr,
     multi_reads: bool,
-    dims: &(u32, u32),
-    max: u8,
+    dims: &(usize, usize),
+    max: LineSize,
     axis: &VectorizationAxis,
-    line_sizes: &[u8],
-    overrides: Option<&Vec<u8>>,
+    line_sizes: &[LineSize],
+    overrides: Option<&Vec<LineSize>>,
 ) -> Vect {
     let axis = axis.get(swapped.id, || swapped.shape.rank() - 1);
 
@@ -382,10 +382,10 @@ fn vectorization_swapped(
     let shape_axis = original.shape[axis];
     let axis_index = axis;
 
-    let dim_index = if dims.0 as usize == axis_index {
-        dims.1 as usize
-    } else if dims.1 as usize == axis_index {
-        dims.0 as usize
+    let dim_index = if dims.0 == axis_index {
+        dims.1
+    } else if dims.1 == axis_index {
+        dims.0
     } else {
         axis_index
     };
@@ -406,16 +406,13 @@ fn vectorization_swapped(
         return Vect::Broadcasted;
     }
 
-    let inner = |s: u8| {
+    let inner = |s: LineSize| {
         // The last dimension should be a multiple of the vector size or broadcated.
         if multi_reads {
-            if swapped_axis.is_multiple_of(s as usize) && s <= max {
+            if swapped_axis.is_multiple_of(s) && s <= max {
                 return Some(Vect::Aligned(s));
             }
-        } else if swapped_axis.is_multiple_of(s as usize)
-            && shape_axis.is_multiple_of(s as usize)
-            && s <= max
-        {
+        } else if swapped_axis.is_multiple_of(s) && shape_axis.is_multiple_of(s) && s <= max {
             return Some(Vect::Aligned(s));
         }
         None
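The `inner` closures above all apply one alignment rule: a candidate line size is usable only if it divides the relevant axis length and respects the cap. A standalone sketch of that selection, assuming `LineSize` aliases `usize` as the `u8::MAX as usize` cap below suggests (illustrative only, not the crate's API):

```rust
type LineSize = usize; // assumption: cubecl's LineSize is a usize alias

// Candidates are typically ordered from widest to narrowest; take the
// first that divides the axis length and stays under the cap.
fn pick_line_size(shape_axis: usize, candidates: &[LineSize], max: LineSize) -> LineSize {
    candidates
        .iter()
        .copied()
        .find(|&s| shape_axis % s == 0 && s <= max)
        .unwrap_or(1)
}

fn main() {
    assert_eq!(pick_line_size(24, &[8, 4, 2, 1], 8), 8);
    assert_eq!(pick_line_size(6, &[8, 4, 2, 1], 8), 2);
}
```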
diff --git a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs
index 6e9ebe120c..8d4476cf75 100644
--- a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs
+++ b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs
@@ -15,12 +15,15 @@ use crate::{
 };
 use burn_fusion::stream::Context;
 use burn_ir::TensorId;
-use cubecl::quant::scheme::{QuantScheme, QuantStore, QuantValue};
 use cubecl::{
     Runtime,
     client::ComputeClient,
     ir::{ElemType, StorageType, UIntKind},
 };
+use cubecl::{
+    ir::LineSize,
+    quant::scheme::{QuantScheme, QuantStore, QuantValue},
+};
 use std::marker::PhantomData;
 
 /// Select the best vectorization factor for each tensor handle.
@@ -78,7 +81,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
         });
 
         let mut ref_elem = (ElemType::UInt(UIntKind::U64).into(), 8);
-        let mut quants_line_sizes: Option<Vec<u8>> = None;
+        let mut quants_line_sizes: Option<Vec<LineSize>> = None;
 
         for input in plan.handle_inputs.iter() {
             let elem: StorageType = match input {
@@ -124,7 +127,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
             Some(line_sizes) => line_sizes,
             None => client
                 .io_optimized_line_sizes_unchecked(ref_elem.0.size())
-                .collect::<Vec<u8>>(),
+                .collect::<Vec<LineSize>>(),
         };
 
         let vectorization_axis = runner.axis(plan);
@@ -154,7 +157,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
             tensors_reshaped,
             tensors_swapped,
             &line_sizes,
-            u8::MAX,
+            u8::MAX as usize,
             vectorization_axis,
         );
 
@@ -228,13 +231,13 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
             match plan.vectorizations.get(&input_global.id).unwrap() {
                 Vect::Aligned(vect) => {
-                    let handle = &mut plan.handle_inputs[pos as usize];
+                    let handle = &mut plan.handle_inputs[pos];
                     match handle {
                         HandleInput::Normal(handle) => {
-                            handle.vectorization = *vect;
+                            handle.line_size = *vect;
                         }
                         HandleInput::QuantValues(handle) => {
-                            handle.vectorization = *vect;
+                            handle.line_size = *vect;
                         }
                         HandleInput::QuantParams(_) => {}
                     }
@@ -255,7 +258,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> {
                     &mut plan.handle_inputs,
                     &mut plan.handle_outputs,
                     block_plan,
-                    u8::MAX,
+                    u8::MAX as usize,
                 );
             }
             VectorizationSetting::SmallerOrEqualThanPreviousBlock => {
@@ -291,7 +294,7 @@ enum VectorizationAction {
 #[derive(Debug)]
 struct BlockVectorization {
     action: VectorizationAction,
-    potential: u8,
+    potential: LineSize,
     broadcasted: bool,
 }
 
@@ -300,7 +303,7 @@ fn apply_vectorization_block(
     inputs: &mut [HandleInput],
     outputs: &mut [HandleOutput],
     block_plan: &mut BlockPlan,
-    max: u8,
+    max: LineSize,
 ) {
     for item in block_vectorization {
         match item.action {
@@ -313,11 +316,11 @@ fn apply_vectorization_block(
                 match &mut inputs[pos] {
                     HandleInput::Normal(input) => {
-                        input.vectorization = vect;
+                        input.line_size = vect;
                         input.broadcated = br;
                     }
                     HandleInput::QuantValues(input) => {
-                        input.vectorization = vect;
+                        input.line_size = vect;
                     }
                     HandleInput::QuantParams(_) => {
                         // Not vectorized
@@ -348,7 +351,7 @@ fn apply_vectorization_block(
 
 fn line_sizes_quants<R: Runtime>(
     client: &ComputeClient<R::Server, R::Channel>,
-    quants_line_sizes: &mut Option<Vec<u8>>,
+    quants_line_sizes: &mut Option<Vec<LineSize>>,
     scheme: QuantScheme,
 ) {
     match scheme.store {
@@ -361,7 +364,7 @@ fn line_sizes_quants<R: Runtime>(
             | QuantValue::E2M1 => {
                 let line_sizes = client
                     .io_optimized_line_sizes_unchecked(size_of::<u8>())
-                    .collect::<Vec<u8>>();
+                    .collect::<Vec<LineSize>>();
 
                 match &quants_line_sizes {
                     Some(sizes) => {
@@ -381,9 +384,9 @@ fn line_sizes_quants<R: Runtime>(
         QuantStore::U32 => {
             let mut line_sizes = client
                 .io_optimized_line_sizes_unchecked(size_of::<u32>())
-                .collect::<Vec<u8>>();
+                .collect::<Vec<LineSize>>();
             for val in line_sizes.iter_mut() {
-                *val *= scheme.num_quants() as u8;
+                *val *= scheme.num_quants();
             }
 
             match &quants_line_sizes {
diff --git a/crates/burn-cubecl-fusion/src/engine/trace/base.rs b/crates/burn-cubecl-fusion/src/engine/trace/base.rs
index da0fbe11d0..573c48cbdf 100644
--- a/crates/burn-cubecl-fusion/src/engine/trace/base.rs
+++ b/crates/burn-cubecl-fusion/src/engine/trace/base.rs
@@ -169,13 +169,13 @@ pub enum TensorView {
     Reshape {
         reshaped: TensorId,
         original: TensorId,
-        reshape_pos: u32,
+        reshape_pos: usize,
         shape_relative: Vec<usize>,
     },
     SwapDims {
         swapped: TensorId,
         original: TensorId,
-        dims: (u32, u32),
+        dims: (usize, usize),
     },
 }
 
@@ -227,7 +227,7 @@ impl RegisteredTensors {
     }
 
     /// Doesn't return quantized tensor.
-    pub fn get_index(&self, tensor_id: TensorId) -> Option<u32> {
+    pub fn get_index(&self, tensor_id: TensorId) -> Option<usize> {
         self.tensors
             .iter()
             .enumerate()
             .find(|(_, tensor)| match tensor {
                 RegisterTensor::Normal(tensor_ir, _) => tensor_ir.id == tensor_id,
                 RegisterTensor::QuantValues(_) => false,
                 RegisterTensor::QuantParams(_) => false,
             })
-            .map(|(pos, _)| pos as u32)
+            .map(|(pos, _)| pos)
     }
 
     /// Get the index of a quantized tensor.
-    pub fn get_index_quant(&self, tensor_id: TensorId) -> Option<u32> {
+    pub fn get_index_quant(&self, tensor_id: TensorId) -> Option<usize> {
         self.tensors
             .iter()
             .enumerate()
             .find(|(_, tensor)| match tensor {
                 RegisterTensor::Normal(_, _) => false,
                 RegisterTensor::QuantValues(tensor_ir) => tensor_ir.id == tensor_id,
                 RegisterTensor::QuantParams(_) => false,
             })
-            .map(|(pos, _)| pos as u32)
+            .map(|(pos, _)| pos)
     }
 
     /// Doesn't return quantized tensor.
@@ -273,12 +273,12 @@ impl RegisteredTensors {
     /// Insert a quantized tensor.
     ///
     /// It will return the positions for both the value tensor and param tensor.
-    pub fn insert_quant(&mut self, tensor: TensorIr) -> (u32, u32) {
+    pub fn insert_quant(&mut self, tensor: TensorIr) -> (usize, usize) {
         if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
             RegisterTensor::QuantValues(tensor_ir) => tensor_ir == &tensor,
             _ => false,
         }) {
-            let values = old.0 as u32;
+            let values = old.0;
             let params = values + 1;
             return (values, params);
         }
@@ -291,16 +291,16 @@ impl RegisteredTensors {
         let pos_params = self.len();
         self.tensors.push(params);
 
-        (pos_values as u32, pos_params as u32)
+        (pos_values, pos_params)
     }
 
     /// Insert a normal tensor with the given [precision](FusePrecision) in the current block.
-    pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> u32 {
+    pub fn insert(&mut self, precision: FuseType, tensor: TensorIr) -> usize {
         if let Some(old) = self.tensors.iter().enumerate().find(|(_, val)| match &val {
             RegisterTensor::Normal(tensor_ir, _) => tensor_ir == &tensor,
             _ => false,
         }) {
-            return old.0 as u32;
+            return old.0;
         }
 
         let value = RegisterTensor::Normal(tensor, precision);
         let pos = self.len();
 
         self.tensors.push(value);
 
-        pos as u32
+        pos
     }
 
     /// Update the already registered tensor with the given [tensor ir](TensorIr).
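`insert` and `insert_quant` above follow a dedup-append pattern whose returned positions are now plain `usize`. A minimal standalone analogue (hypothetical generic container, not the crate's `RegisteredTensors`):

```rust
// Return the existing position when the item is already registered,
// otherwise append it and return the new index.
fn insert<T: PartialEq>(tensors: &mut Vec<T>, tensor: T) -> usize {
    if let Some(pos) = tensors.iter().position(|t| t == &tensor) {
        return pos;
    }
    tensors.push(tensor);
    tensors.len() - 1
}

fn main() {
    let mut tensors = vec!["a", "b"];
    assert_eq!(insert(&mut tensors, "b"), 1); // already registered
    assert_eq!(insert(&mut tensors, "c"), 2); // appended at the end
}
```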
diff --git a/crates/burn-cubecl-fusion/src/engine/trace/block.rs b/crates/burn-cubecl-fusion/src/engine/trace/block.rs
index 6cd295dfd1..764a50810d 100644
--- a/crates/burn-cubecl-fusion/src/engine/trace/block.rs
+++ b/crates/burn-cubecl-fusion/src/engine/trace/block.rs
@@ -4,7 +4,6 @@ use crate::engine::{
 };
 use burn_ir::{TensorId, TensorIr, TensorStatus};
 use burn_std::{DType, quantization::QuantParam};
-use cubecl::prelude::Sequence;
 use serde::{Deserialize, Serialize};
 use std::collections::{BTreeMap, btree_map::Entry};
 
@@ -210,7 +209,7 @@ impl FuseBlockBuilder {
         &mut self,
         tensor: &TensorIr,
         output: &TensorIr,
-        dims: (u32, u32),
+        dims: (usize, usize),
         resources: &mut FuseResources,
     ) -> Option<FuseArg> {
         if matches!(tensor.dtype, DType::QFloat(..)) {
@@ -317,7 +316,7 @@ impl FuseBlockBuilder {
         let out = self.output(output, resources)?;
         let original = FuseArg::Input(input_index, precision_input, LayoutInfo::Unknown);
 
-        let mut shape = Sequence::new();
+        let mut shape = Vec::new();
 
         let index = resources.num_reshaped;
         resources.num_reshaped += 1;
@@ -326,13 +325,13 @@ impl FuseBlockBuilder {
 
         for i in 0..output.shape.rank() {
             let id = index * rank + i;
-            shape.push(FuseArg::ScalarShape(id as u32));
+            shape.push(FuseArg::ScalarShape(id));
         }
 
         resources.views.push(TensorView::Reshape {
             reshaped: output.id,
             original: tensor.id,
-            reshape_pos: index as u32,
+            reshape_pos: index,
             shape_relative: output.shape.dims.clone(),
         });
 
@@ -381,11 +380,7 @@ impl FuseBlockBuilder {
                     tensor.id,
                     FuseOp::Assign(UnaryFuseArgs {
                         input: local,
-                        out: FuseArg::Output(
-                            out_index + offset as u32,
-                            *precision,
-                            LayoutInfo::Unknown,
-                        ),
+                        out: FuseArg::Output(out_index + offset, *precision, LayoutInfo::Unknown),
                     }),
                 );
             }
@@ -657,7 +652,7 @@ impl FuseBlockBuilder {
 
 #[derive(Default, Clone, Debug)]
 struct LocalVariablePool {
-    values: BTreeMap<FuseType, BTreeMap<TensorId, u32>>,
+    values: BTreeMap<FuseType, BTreeMap<TensorId, usize>>,
 }
 
 impl LocalVariablePool {
@@ -681,7 +676,7 @@ impl LocalVariablePool {
         None
     }
 
-    fn find_tensor_id(&self, precision: FuseType, position: u32) -> Option<TensorId> {
+    fn find_tensor_id(&self, precision: FuseType, position: usize) -> Option<TensorId> {
         if let Some(indexes) = self.values.get(&precision) {
             indexes
                 .iter()
@@ -694,7 +689,7 @@ impl LocalVariablePool {
 
     fn create(&mut self, precision: FuseType, tensor_id: TensorId) -> FuseArg {
         if let Some(indexes) = self.values.get_mut(&precision) {
-            let new_index = indexes.len() as u32;
+            let new_index = indexes.len();
             indexes.insert(tensor_id, new_index);
             return FuseArg::Local(new_index, precision);
         }
diff --git a/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs b/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs
index a49a4dbca2..00c62c43e4 100644
--- a/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs
+++ b/crates/burn-cubecl-fusion/src/engine/trace/fuser.rs
@@ -205,7 +205,7 @@ impl TraceFuser {
         &mut self,
         tensor: &TensorIr,
         output: &TensorIr,
-        dims: (u32, u32),
+        dims: (usize, usize),
     ) -> Option<FuseArg> {
         if matches!(tensor.dtype, DType::QFloat(_)) {
             return None;
@@ -238,7 +238,7 @@ impl TraceFuser {
             FuseType::Bool => self.bool_precision,
             _ => precision,
         };
-        let new_index = self.resources.scalars.len() as u32;
+        let new_index = self.resources.scalars.len();
         self.resources.scalars.push((precision, id.value));
 
         FuseArg::Scalar(new_index, precision)
diff --git a/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs b/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs
index ab6dff4893..e6b895a914 100644
--- a/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs
+++ b/crates/burn-cubecl-fusion/src/optim/elemwise/optimization.rs
@@ -86,13 +86,13 @@ impl<R: Runtime> TraceRunner<R> for ElemwiseRunner {
         let config = &configs[0];
         let shape = match &config.ref_layout {
             RefLayout::Concrete(arg) => match arg {
-                FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank as usize),
-                FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank as usize),
+                FuseArg::Input(..) => inputs.shape_ref(&config.ref_layout, config.rank),
+                FuseArg::Output(..) => outputs.shape_ref(&config.ref_layout, config.rank),
                 _ => panic!("Invalid concreate ref layout"),
             },
-            RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank as usize),
+            RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout, config.rank),
         };
-        let working_units = shape.iter().product::<usize>() / config.width as usize;
+        let working_units = shape.iter().product::<usize>() / config.width;
         let cube_dim = CubeDim::new(client, working_units);
         let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
 
@@ -119,7 +119,7 @@ fn elemwise_fuse(
 ) {
     // We write no values for this fusion.
     let values = Registry::<FuseArg, Line<f32>>::new();
-    let args = comptime![Sequence::<FuseArg>::new()];
+    let args = comptime![Vec::<FuseArg>::new()];
 
     let pos = ABSOLUTE_POS;
     let mut locals = init_locals(inputs, outputs, config);
diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
index 538c0ac9f5..ccdfa4e3c8 100644
--- a/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
+++ b/crates/burn-cubecl-fusion/src/optim/matmul/args.rs
@@ -16,7 +16,7 @@ use cubecl::{
     },
     tensor::{
         View, ViewExpand,
-        layout::{Coords1d, Coords2d, Coords3d, VirtualLayout},
+        layout::{Coords1d, Coords2d, VirtualLayout},
     },
 },
};
@@ -26,7 +26,7 @@ use cubek::matmul::{
         GlobalScaleLayout, GlobalScaleLayoutExpand, NoopLayout,
     },
     definition::MatrixLayout,
-    launch::MatmulArgs,
+    launch::{BatchedCoords, MatmulArgs},
 };
 use serde::{Deserialize, Serialize};
 use std::marker::PhantomData;
@@ -70,7 +70,7 @@ impl MatmulArgs for FusedMatmulArgs {
 
         #[unroll]
         for i in 0..rank - 2 {
-            batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i]));
+            batch_shape.push(FastDivmod::new_Fallback(locals.ref_shape[i] as u32));
             batch_strides_out.push(locals.ref_strides[i]);
         }
 
@@ -115,27 +115,27 @@ impl MatmulArgs for FusedMatmulArgs {
     fn view_lhs(
         state: &Self::State,
-    ) -> View<Line<EI>, Coords3d> {
+    ) -> View<Line<EI>, BatchedCoords> {
         global_view(
             &state.inputs,
             &state.locals,
             &state.batch_shape,
             comptime![state.a.clone()],
             comptime![state.config.clone()],
-            comptime![state.lhs_layout_config],
+            state.lhs_layout_config,
         )
     }
 
     fn batch_lhs(
         state: &Self::State,
-        batch: u32,
-    ) -> u32 {
+        batch: usize,
+    ) -> usize {
         state.a_batch.to_source_pos(batch)
     }
 
     fn view_rhs(
         state: &Self::State,
-    ) -> View<Line<EI>, Coords3d> {
+    ) -> View<Line<EI>, BatchedCoords> {
         global_view(
             &state.inputs,
             &state.locals,
@@ -148,14 +148,14 @@ impl MatmulArgs for FusedMatmulArgs {
 
     fn batch_rhs(
         state: &Self::State,
-        batch: u32,
-    ) -> u32 {
+        batch: usize,
+    ) -> usize {
         state.b_batch.to_source_pos(batch)
     }
 
     fn view_acc(
         state: &Self::State,
-    ) -> CubeOption<View<Line<EO>, Coords3d>> {
+    ) -> CubeOption<View<Line<EO>, BatchedCoords>> {
         match comptime![state.c.clone()] {
             Option::Some(c) => {
                 let view = global_view(
@@ -174,8 +174,8 @@ impl MatmulArgs for FusedMatmulArgs {
 
     fn batch_acc(
         state: &Self::State,
-        batch: u32,
-    ) -> u32 {
+        batch: usize,
+    ) -> usize {
         match state.c_batch {
             CubeOption::Some(c_batch) => c_batch.to_source_pos(batch),
             CubeOption::None => batch,
@@ -184,11 +184,11 @@ impl MatmulArgs for FusedMatmulArgs {
     fn view_out(
         state: &mut Self::State,
-    ) -> View<Line<EO>, Coords3d, ReadWrite> {
+    ) -> View<Line<EO>, BatchedCoords, ReadWrite> {
         let rank = comptime![state.config.rank];
 
-        let shape_row = state.locals.ref_shape[rank - 2];
-        let shape_col = state.locals.ref_shape[rank - 1];
+        let shape_row = state.locals.ref_shape[rank - 2] as u32;
+        let shape_col = state.locals.ref_shape[rank - 1] as u32;
         let stride_row = state.locals.ref_strides[rank - 2];
         let stride_col = state.locals.ref_strides[rank - 1];
 
@@ -215,8 +215,8 @@ impl MatmulArgs for FusedMatmulArgs {
     fn batch_out(
         state: &Self::State,
-        batch: u32,
-    ) -> u32 {
+        batch: usize,
+    ) -> usize {
         state.out_batch.to_source_pos(batch)
     }
 }
@@ -225,11 +225,11 @@ impl MatmulArgs for FusedMatmulArgs {
 fn global_view(
     inputs: &GlobalArgs,
     locals: &LocalArgs,
-    batch_shape: &Sequence<FastDivmod>,
+    batch_shape: &Sequence<FastDivmod<u32>>,
     #[comptime] arg: MatmulArg,
     #[comptime] config: FuseBlockConfig,
     #[comptime] layout_config: GlobalLayoutConfig,
-) -> View<Line<E>, Coords3d> {
+) -> View<Line<E>, BatchedCoords> {
     let rank = comptime![config.rank];
     let data = comptime![arg.data().clone()];
     let data_tensor = match comptime![data.clone()] {
         _ => panic!("Input must be concrete"),
     };
 
-    let mut shape_row = data_tensor.tensor.shape(rank - 2);
-    let mut shape_col = data_tensor.tensor.shape(rank - 1);
-    let mut packing = comptime![1u32];
+    let mut shape_row = data_tensor.tensor.shape(rank - 2) as u32;
+    let mut shape_col = data_tensor.tensor.shape(rank - 1) as u32;
+    let mut packing = comptime![1];
 
-    if comptime![arg.scheme().is_some()] {
-        let scheme = comptime![arg.scheme().unwrap()];
-        let num_quants = comptime![scheme.num_quants() as u32];
+    if arg.scheme().is_some() {
+        let scheme = arg.scheme().unwrap();
+        let num_quants = scheme.num_quants() as u32;
         comptime![packing = num_quants];
         match comptime![layout_config.matrix_layout] {
             MatrixLayout::RowMajor => shape_col *= num_quants,
@@ -266,8 +266,8 @@ fn global_view(
         inputs,
         shape,
         batch_layout,
-        comptime![arg.data().clone()],
-        comptime![config.clone()],
+        arg.data().clone(),
+        config.clone(),
         data_tensor.tensor.line_size(),
         layout_config,
         packing,
@@ -296,7 +296,7 @@ fn global_view(
         batch_layout,
         comptime![scales.clone()],
         comptime![config.clone()],
-        1u32,
+        1usize,
         layout_config,
         1u32,
     );
@@ -316,10 +316,10 @@ fn input_batch_layout(
     inputs: &GlobalArgs,
-    batch_shape: &Sequence<FastDivmod>,
+    batch_shape: &Sequence<FastDivmod<u32>>,
     #[comptime] arg: MatmulArg,
     #[comptime] config: FuseBlockConfig,
-) -> VirtualLayout<u32> {
+) -> VirtualLayout<usize> {
     let rank = comptime![config.rank];
     match comptime![arg.clone()] {
         MatmulArg::Normal(arg) => {
@@ -346,10 +346,10 @@ fn global_layout(
     inputs: &GlobalArgs,
     shape: Coords2d,
-    batch_layout: VirtualLayout<u32>,
+    batch_layout: VirtualLayout<usize>,
     #[comptime] arg: FuseArg,
     #[comptime] config: FuseBlockConfig,
-    #[comptime] line_size: u32,
+    #[comptime] line_size: LineSize,
     #[comptime] layout_config: GlobalLayoutConfig,
     #[comptime] packing: u32,
 ) -> GlobalLayout {
@@ -387,7 +387,7 @@ struct CreateQuantView<'a, E: Numeric> {
 }
 
 impl<'a, E: Numeric> RunWithQuantType for CreateQuantView<'a, E> {
-    type Output = ViewExpand<Line<E>, Coords3d>;
+    type Output = ViewExpand<Line<E>, BatchedCoords>;
 
     fn execute(self) -> Self::Output {
         create_quant_view::expand::<E>(
@@ -409,7 +409,7 @@ fn create_quant_view_dynamic(
     scales_buf: GlobalInput<f32>,
     scales_layout: GlobalScaleLayout,
     #[comptime] scheme: QuantScheme,
-) -> View<Line<E>, Coords3d> {
+) -> View<Line<E>, BatchedCoords> {
     intrinsic!(|scope| {
        let func = CreateQuantView {
             scope,
@@ -431,10 +431,10 @@ fn create_quant_view(
     scales_buf: GlobalInput<f32>,
     scales_layout: GlobalScaleLayout,
     #[comptime] scheme: QuantScheme,
-) -> View<Line<E>, Coords3d> {
-    let data_view: View<Line<E>, Coords3d> =
+) -> View<Line<E>, BatchedCoords> {
+    let data_view: View<Line<E>, BatchedCoords> =
         View::new::(&data_buf, data_layout);
-    let scales_view: View<f32, Coords3d> =
+    let scales_view: View<f32, BatchedCoords> =
         View::new::(&scales_buf, scales_layout);
     QuantizedView::new(data_view, scales_view, scheme).view()
 }
@@ -464,7 +464,7 @@ pub struct FusedMatmulState {
     rhs_layout_config: GlobalLayoutConfig,
     #[cube(comptime)]
     out_layout_config: GlobalLayoutConfig,
-    batch_shape: Sequence<FastDivmod>,
+    batch_shape: Sequence<FastDivmod<u32>>,
 }
 
 #[cube]
@@ -474,11 +474,11 @@ impl FusedMatmulState {
         inputs: &FusedMatmulInput,
         outputs: &mut GlobalArgs,
         locals: &mut LocalArgs,
-        a_batch: VirtualLayout<u32>,
-        b_batch: VirtualLayout<u32>,
-        c_batch: CubeOption<VirtualLayout<u32>>,
-        out_batch: VirtualLayout<u32>,
-        batch_shape: Sequence<FastDivmod>,
+        a_batch: VirtualLayout<usize>,
+        b_batch: VirtualLayout<usize>,
+        c_batch: CubeOption<VirtualLayout<usize>>,
+        out_batch: VirtualLayout<usize>,
+        batch_shape: Sequence<FastDivmod<u32>>,
         #[comptime] config: &FuseBlockConfig,
         #[comptime] lhs_layout_config: GlobalLayoutConfig,
         #[comptime] rhs_layout_config: GlobalLayoutConfig,
diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs b/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs
index f5f4dbad16..9f10272b34 100644
--- a/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs
+++ b/crates/burn-cubecl-fusion/src/optim/matmul/optimization.rs
@@ -26,8 +26,8 @@ use cubecl::{
 use cubek::matmul::{
     components::tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul},
     definition::{
-        MatmulElemType, MatmulElems, MatmulGlobalElems, MatmulLineSizes, MatmulProblem,
-        MatmulSetupError, MatrixLayout,
+        MatmulElems, MatmulGlobalElems, MatmulLineSizes, MatmulProblem, MatmulSetupError,
+        MatrixLayout,
     },
     launch::launch_kernel_virtual,
     routines::{
@@ -359,18 +359,9 @@ impl TraceRunner for FusedMatmulLaunch<'_> {
         configs: &'a [FuseBlockConfig],
     ) -> Result<(), FusedMatmulError> {
         let global_elems = MatmulGlobalElems {
-            lhs: MatmulElemType {
-                dtype: self.matmul.lhs.precision().into_type(),
-                quantized: false,
-            },
-            rhs: MatmulElemType {
-                dtype: self.matmul.rhs.precision().into_type(),
-                quantized: false,
-            },
-            out: MatmulElemType {
-                dtype: self.matmul.out.precision().into_type(),
-                quantized: false,
-            },
+            lhs: self.matmul.lhs.precision().into_type(),
+            rhs: self.matmul.rhs.precision().into_type(),
+            out: self.matmul.out.precision().into_type(),
         };
         let dtypes = MatmulElems::from_globals(&global_elems);
         self.matmul_fused(client, inputs, outputs, &configs[0], dtypes)
@@ -411,7 +402,7 @@ impl FusedMatmulLaunch<'_> {
     ) -> Result<(), FusedMatmulError> {
         let lhs_shape = inputs.shape(self.matmul.lhs.data());
         let rhs_shape = inputs.shape(self.matmul.rhs.data());
-        let out_shape = outputs.shape_ref(&config.ref_layout, config.rank as usize);
+        let out_shape = outputs.shape_ref(&config.ref_layout, config.rank);
 
         let lhs_strides = inputs.strides(self.matmul.lhs.data());
         let rhs_strides = inputs.strides(self.matmul.rhs.data());
@@ -447,10 +438,10 @@ impl FusedMatmulLaunch<'_> {
         }
 
         if let MatmulArg::Quantized { scheme, .. } = self.matmul.lhs {
-            line_sizes.lhs *= scheme.num_quants() as u8;
+            line_sizes.lhs *= scheme.num_quants();
         }
} = self.matmul.rhs { - line_sizes.rhs *= scheme.num_quants() as u8; + line_sizes.rhs *= scheme.num_quants(); } let out_strides = MatrixLayout::RowMajor.to_strides(&out_shape); @@ -485,7 +476,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &BlueprintStrategy::Inferred(SimpleArgs { multi_rows }), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -510,7 +500,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &BlueprintStrategy::Inferred(DoubleBufferingArgs { specialized }), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -542,7 +531,6 @@ impl FusedMatmulLaunch<'_> { rows_per_plane: Some(2), partition_k: Some(2), }), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -563,7 +551,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &Default::default(), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -584,7 +571,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &Default::default(), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -605,7 +591,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &Default::default(), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -626,7 +611,6 @@ impl FusedMatmulLaunch<'_> { problem, line_sizes, &Default::default(), - dtypes, ) { Ok(_) => Ok(()), Err(err) => Err(FusedMatmulError::LaunchError(err)), @@ -643,27 +627,13 @@ fn launch_inner_fix_dtype<'a, R: Runtime, A: Routine>( problem: MatmulProblem, line_sizes: MatmulLineSizes, blueprint_strategy: &BlueprintStrategy, - mut dtypes: MatmulElems, ) -> Result<(), MatmulSetupError> { - let fix_plane_dim = |plane_dim: u32| { - // Sometimes the GPU doesn't support plane instructions and doesn't report the - // plane size, but we can still execute algorithms that don't use plane instructions. - // - // In this case, we set a plane size for the selector to work, defaulting to 32 as it - // is a common plane size. 
- if plane_dim == 0 { 32 } else { plane_dim } - }; - - let plane_size = fix_plane_dim(A::select_plane_dim(client)); - launch_kernel_virtual::( client, input, output, problem, line_sizes, - plane_size, blueprint_strategy, - &mut dtypes, ) } diff --git a/crates/burn-cubecl-fusion/src/optim/matmul/tune.rs b/crates/burn-cubecl-fusion/src/optim/matmul/tune.rs index 96a24131b2..f1cf0d166b 100644 --- a/crates/burn-cubecl-fusion/src/optim/matmul/tune.rs +++ b/crates/burn-cubecl-fusion/src/optim/matmul/tune.rs @@ -6,13 +6,12 @@ use crate::{ tune::{TuneContext, TuneInput}, }; use burn_fusion::stream::Context; -use burn_std::DType; use cubecl::{ AutotuneKey, CubeElement, CubeTuneId, Runtime, tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner}, }; use cubek::matmul::{ - definition::{MatmulElemType, MatmulKind}, + definition::MatmulKind, launch::{MatmulAutotuneKey, MatmulGlobalScale, should_tune_double_buffering}, }; use serde::{Deserialize, Serialize}; @@ -217,18 +216,9 @@ pub(crate) fn create_key( &rhs.shape.dims, &lhs_strides, &rhs_strides, - MatmulElemType { - dtype: lhs.dtype.into(), - quantized: matches!(lhs.dtype, DType::QFloat(_)), - }, - MatmulElemType { - dtype: rhs.dtype.into(), - quantized: matches!(rhs.dtype, DType::QFloat(_)), - }, - MatmulElemType { - dtype: out.dtype.into(), - quantized: matches!(out.dtype, DType::QFloat(_)), - }, + lhs.dtype.into(), + rhs.dtype.into(), + out.dtype.into(), ); FusedMatmulAutotuneKey::new(key, opt.info.num_output_buffers(), opt.info.num_ops_fused()) } diff --git a/crates/burn-cubecl-fusion/src/optim/reduce/args.rs b/crates/burn-cubecl-fusion/src/optim/reduce/args.rs index e56fcac60c..acf901494b 100644 --- a/crates/burn-cubecl-fusion/src/optim/reduce/args.rs +++ b/crates/burn-cubecl-fusion/src/optim/reduce/args.rs @@ -65,8 +65,8 @@ impl ReduceArgs for FusedReduceArgs { FusedReduceState::new(input, output, &mut locals_read, &mut locals_write) } - fn read_input(state: &Self::State
, index: u32) -> Line { - *fuse_on_read::( + fn read_input(state: &Self::State
, index: usize) -> Line { + fuse_on_read::( unsafe { &(*state.inputs) }, unsafe { &mut (*state.outputs) }, unsafe { &mut (*state.locals_on_read) }, @@ -77,17 +77,16 @@ impl ReduceArgs for FusedReduceArgs { sequence }, &state.config_on_read, - ) - .index(0) + )[0] } - fn read_output(_state: &Self::State
, _index: u32) -> Line { - Line::empty(1_u32) + fn read_output(_state: &Self::State
, _index: usize) -> Line { + Line::empty(1usize) } - fn write_output(state: &mut Self::State
, index: u32, value: Line) { + fn write_output(state: &mut Self::State
, index: usize, value: Line) { let mut values = Registry::>::new(); - let mut args = comptime![Sequence::::new()]; + let mut args = comptime![Vec::::new()]; values.insert(comptime![state.out.clone()], value); comptime![args.push(state.out.clone())]; @@ -103,7 +102,7 @@ impl ReduceArgs for FusedReduceArgs { ); } - fn len_input(state: &Self::State
) -> u32 { + fn len_input(state: &Self::State
) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -112,7 +111,7 @@ impl ReduceArgs for FusedReduceArgs { ) } - fn len_output(state: &Self::State
) -> u32 { + fn len_output(state: &Self::State
) -> usize { ref_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -121,7 +120,7 @@ impl ReduceArgs for FusedReduceArgs { ) } - fn buffer_len_input(state: &Self::State
) -> u32 { + fn buffer_len_input(state: &Self::State
) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -130,7 +129,7 @@ impl ReduceArgs for FusedReduceArgs { ) } - fn buffer_len_output(state: &Self::State
) -> u32 { + fn buffer_len_output(state: &Self::State
) -> usize { ref_buffer_len( unsafe { &(*state.inputs) }, unsafe { &(*state.outputs) }, @@ -139,35 +138,35 @@ impl ReduceArgs for FusedReduceArgs { ) } - fn rank_input(state: &Self::State
) -> u32 { + fn rank_input(state: &Self::State
) -> usize { state.config_on_read.rank.runtime() } - fn rank_output(state: &Self::State
) -> u32 { + fn rank_output(state: &Self::State
) -> usize { state.config_on_write.rank.runtime() } - fn shape_input(state: &Self::State
, dim: u32) -> u32 { + fn shape_input(state: &Self::State
, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_read) }, dim) } - fn shape_output(state: &Self::State
, dim: u32) -> u32 { + fn shape_output(state: &Self::State
, dim: usize) -> usize { ref_shape(unsafe { &(*state.locals_on_write) }, dim) } - fn stride_input(state: &Self::State
, dim: u32) -> u32 { + fn stride_input(state: &Self::State
, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_read) }, dim) } - fn stride_output(state: &Self::State
, dim: u32) -> u32 { + fn stride_output(state: &Self::State
, dim: usize) -> usize { ref_stride(unsafe { &(*state.locals_on_write) }, dim) } - fn line_size_input(state: &Self::State
) -> comptime_type!(u32) { + fn line_size_input(state: &Self::State
) -> comptime_type!(LineSize) { ref_line_size(unsafe { &(*state.locals_on_read) }) } - fn line_size_output(state: &Self::State
) -> comptime_type!(u32) { + fn line_size_output(state: &Self::State
) -> comptime_type!(LineSize) { ref_line_size(unsafe { &(*state.locals_on_write) }) } } diff --git a/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs b/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs index 5a86858475..8081588e42 100644 --- a/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs +++ b/crates/burn-cubecl-fusion/src/optim/reduce/optimization.rs @@ -279,17 +279,17 @@ impl TraceRunner for FusedReduceLaunch<'_> { let [config_read, config_write] = [&configs[0], &configs[1]]; let shape = match &config_read.ref_layout { RefLayout::Concrete(FuseArg::Output(..)) => { - outputs.shape_ref(&config_read.ref_layout, config_read.rank as usize) + outputs.shape_ref(&config_read.ref_layout, config_read.rank) } - _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank as usize), + _ => inputs.shape_ref(&config_read.ref_layout, config_read.rank), }; - let reduce_count: u32 = shape + let reduce_count: usize = shape .iter() .enumerate() - .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s as u32 }) + .map(|(i, s)| if i == self.reduce.axis { 1 } else { *s }) .product(); - let line_mode = match self.reduce.axis == config_read.rank as usize - 1 { + let line_mode = match self.reduce.axis == config_read.rank - 1 { true => LineMode::Parallel, false => LineMode::Perpendicular, }; @@ -300,9 +300,9 @@ impl TraceRunner for FusedReduceLaunch<'_> { line_size_output: config_write.width, }; let problem = ReduceProblem { - vector_size: shape[self.reduce.axis] as u32, + vector_size: shape[self.reduce.axis], vector_count: reduce_count, - axis: self.reduce.axis as u32, + axis: self.reduce.axis, dtypes: ReduceDtypes { input: self.reduce.op.input.dtype.into(), output: self.reduce.op.out.dtype.into(), @@ -329,7 +329,7 @@ impl TraceRunner for FusedReduceLaunch<'_> { client, inputs, outputs, - axis: self.reduce.axis as u32, + axis: self.reduce.axis, config_fuse_read: config_read.clone(), config_fuse_write: config_write.clone(), input: self.reduce.input.clone(), @@ -356,7 +356,7 @@ struct ReduceKwArgs<'a, 'b, Run: Runtime> { client: &'b ComputeClient, inputs: GlobalArgsLaunch<'a, Run>, outputs: GlobalArgsLaunch<'a, Run>, - axis: u32, + axis: usize, blueprint: ReduceBlueprint, settings: ReduceLaunchSettings, config_fuse_read: FuseBlockConfig, @@ -413,7 +413,7 @@ fn launch_reduce( pub fn reduce_kernel( input: &FusedReduceInput, output: &mut FusedReduceOutput, - axis_reduce: u32, + axis_reduce: usize, #[comptime] blueprint: ReduceBlueprint, #[comptime] config: ReduceOperationConfig, #[define(In)] _input_dtype: StorageType, diff --git a/crates/burn-cubecl/src/kernel/binary.rs b/crates/burn-cubecl/src/kernel/binary.rs index 97267ee379..0ea46ff4bc 100644 --- a/crates/burn-cubecl/src/kernel/binary.rs +++ b/crates/burn-cubecl/src/kernel/binary.rs @@ -6,9 +6,7 @@ use crate::{ }; use burn_backend::{bf16, f16}; use cubecl::{ - calculate_cube_count_elemwise, intrinsic, - prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, + calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView, }; pub(crate) trait BinaryOpFamily: Send + Sync + 'static { diff --git a/crates/burn-cubecl/src/kernel/binary_int.rs b/crates/burn-cubecl/src/kernel/binary_int.rs index 315b271512..07a6310bb1 100644 --- a/crates/burn-cubecl/src/kernel/binary_int.rs +++ b/crates/burn-cubecl/src/kernel/binary_int.rs @@ -4,11 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; -use cubecl::{ - calculate_cube_count_elemwise, - 
prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, -}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { type BinaryOp: BinaryOpInt; diff --git a/crates/burn-cubecl/src/kernel/clamp.rs b/crates/burn-cubecl/src/kernel/clamp.rs index 589188d941..46059aaf52 100644 --- a/crates/burn-cubecl/src/kernel/clamp.rs +++ b/crates/burn-cubecl/src/kernel/clamp.rs @@ -1,4 +1,4 @@ -use cubecl::{prelude::*, std::scalar::InputScalar}; +use cubecl::prelude::*; use crate::{ CubeRuntime, diff --git a/crates/burn-cubecl/src/kernel/comparison.rs b/crates/burn-cubecl/src/kernel/comparison.rs index 557b97af99..646c68489e 100644 --- a/crates/burn-cubecl/src/kernel/comparison.rs +++ b/crates/burn-cubecl/src/kernel/comparison.rs @@ -5,11 +5,7 @@ use crate::{ tensor::CubeTensor, }; use burn_backend::DType; -use cubecl::{ - calculate_cube_count_elemwise, - prelude::*, - std::{scalar::InputScalar, tensor::layout::linear::LinearView}, -}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; #[cube] pub(crate) trait ComparisonOpFamily: 'static + Send + Sync { diff --git a/crates/burn-cubecl/src/kernel/contiguous.rs b/crates/burn-cubecl/src/kernel/contiguous.rs index ccc1b384d0..95f7e3e69f 100644 --- a/crates/burn-cubecl/src/kernel/contiguous.rs +++ b/crates/burn-cubecl/src/kernel/contiguous.rs @@ -83,7 +83,7 @@ fn into_contiguous_quantized( &values.as_handle_ref(), &out_values.as_handle_ref(), &tensor.shape, - scheme.num_quants() as u32, + scheme.num_quants(), DType::U32.into(), ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs b/crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs index def3395bbd..9c1284e691 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_data/implicit_gemm/launch.rs @@ -6,7 +6,7 @@ use cubek::{ components::ConvSetupError, }, matmul::{ - definition::{MatmulElemType, MatmulElems, MatmulGlobalElems}, + definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputHandleRef, }, }; @@ -107,18 +107,9 @@ pub fn launch_backwards_data( let client = out_grad.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { - lhs: MatmulElemType { - dtype: out_grad.dtype.into(), - quantized: false, - }, - rhs: MatmulElemType { - dtype: weights.dtype.into(), - quantized: false, - }, - out: MatmulElemType { - dtype: out_dtype.into(), - quantized: false, - }, + lhs: out_grad.dtype.into(), + rhs: weights.dtype.into(), + out: out_dtype.into(), }); let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into()); let weights = MatmulInputHandleRef::new(weights.as_handle_ref(), weights.dtype.into()); diff --git a/crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs b/crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs index 71a69393a4..1c0f38a974 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_weight/implicit_gemm/launch.rs @@ -6,7 +6,7 @@ use cubek::{ components::ConvSetupError, }, matmul::{ - definition::{MatmulElemType, MatmulElems, MatmulGlobalElems}, + definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputHandleRef, }, }; @@ -107,18 +107,9 @@ pub fn launch_backwards_weight( let 
client = input.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { - lhs: MatmulElemType { - dtype: input.dtype.into(), - quantized: false, - }, - rhs: MatmulElemType { - dtype: out_grad.dtype.into(), - quantized: false, - }, - out: MatmulElemType { - dtype: out_dtype.into(), - quantized: false, - }, + lhs: input.dtype.into(), + rhs: out_grad.dtype.into(), + out: out_dtype.into(), }); let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into()); let out_grad = MatmulInputHandleRef::new(out_grad.as_handle_ref(), out_grad.dtype.into()); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs index 0818532706..c0bf4db27f 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs @@ -209,17 +209,17 @@ fn col2im( bias.as_tensor_arg(vectorization), out.as_tensor_arg(vectorization), Col2ImArgsLaunch::new( - ScalarArg::new(out_h as u32), - ScalarArg::new(out_w as u32), - ScalarArg::new(kernel_h as u32), - ScalarArg::new(kernel_w as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(col_size_1 as u32), + ScalarArg::new(out_h), + ScalarArg::new(out_w), + ScalarArg::new(kernel_h), + ScalarArg::new(kernel_w), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(col_size_1), ), has_bias, dtype.into(), @@ -229,20 +229,20 @@ fn col2im( #[derive(CubeLaunch, CubeType)] struct Col2ImArgs { - out_h: u32, - out_w: u32, + out_h: usize, + out_w: usize, - kernel_h: u32, - kernel_w: u32, + kernel_h: usize, + kernel_w: usize, - pad_h: u32, - pad_w: u32, - dilation_h: u32, - dilation_w: u32, - stride_h: u32, - stride_w: u32, + pad_h: usize, + pad_w: usize, + dilation_h: usize, + dilation_w: usize, + stride_h: usize, + stride_w: usize, - col_size_1: u32, + col_size_1: usize, } #[cube(launch_unchecked)] @@ -271,13 +271,13 @@ fn col2im_kernel( let x_col_start = if im_x >= kernel_extent_w { (im_x - kernel_extent_w) / args.stride_w + 1 } else { - 0u32.runtime() + 0usize.runtime() }; let x_col_end = Min::min(im_x / args.stride_w + 1, args.out_w); let y_col_start = if im_y >= kernel_extent_h { (im_y - kernel_extent_h) / args.stride_h + 1 } else { - 0u32.runtime() + 0usize.runtime() }; let y_col_end = Min::min(im_y / args.stride_h + 1, args.out_h); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs index d0abd6e967..3d61210108 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs @@ -13,13 +13,13 @@ use cubek::convolution::components::ConvSetupError; #[derive(CubeLaunch, CubeType)] struct ConvArgs { - conv_stride_0: u32, - conv_stride_1: u32, - dilation_0: u32, - dilation_1: u32, - padding_0: u32, - padding_1: u32, - groups: u32, + conv_stride_0: usize, + conv_stride_1: usize, + dilation_0: usize, + dilation_1: usize, + padding_0: usize, + padding_1: usize, + groups: usize, } 
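Taken together, the hunks above apply one mechanical change: CubeLaunch argument structs and kernel indices move from u32 to usize, so the host-side "as u32" casts around ScalarArg::new disappear. The following is a minimal sketch of that pattern in isolation, assuming only the cubecl items already visible in this diff (the cubecl prelude exporting Tensor, ScalarArg, ABSOLUTE_POS, and the CubeLaunch/CubeType derives); the struct, kernel, and field names are hypothetical illustrations, not part of this PR:

use cubecl::prelude::*;

// Launch arguments use plain usize, matching what Tensor::shape/stride now return.
#[derive(CubeLaunch, CubeType)]
struct StridedCopyArgs {
    stride: usize,
    offset: usize,
}

#[cube(launch)]
fn strided_copy_kernel(input: &Tensor<f32>, output: &mut Tensor<f32>, args: &StridedCopyArgs) {
    // usize indices flow straight into tensor indexing with no intermediate casts.
    if ABSOLUTE_POS < output.len() {
        output[ABSOLUTE_POS] = input[ABSOLUTE_POS * args.stride + args.offset];
    }
}

At the launch site the casts go away as well — e.g. StridedCopyArgsLaunch::new(ScalarArg::new(stride), ScalarArg::new(offset)) with stride and offset already usize on the host, mirroring the ConvArgsLaunch::new changes above.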
#[cube(launch)] @@ -61,10 +61,10 @@ fn conv_transpose2d_direct_kernel( let y_start = ((out_y + args.padding_0) as i32 - kms_h) / stride_0_i; let x_start = ((out_x + args.padding_1) as i32 - kms_w) / stride_1_i; - let y_end = Min::min(Max::max(kms_h + y_start + 1, 0) as u32, input.shape(2)); - let x_end = Min::min(Max::max(kms_w + x_start + 1, 0) as u32, input.shape(3)); - let y_start = Max::max(y_start, 0) as u32; - let x_start = Max::max(x_start, 0) as u32; + let y_end = Min::min(Max::max(kms_h + y_start + 1, 0) as usize, input.shape(2)); + let x_end = Min::min(Max::max(kms_w + x_start + 1, 0) as usize, input.shape(3)); + let y_start = Max::max(y_start, 0) as usize; + let x_start = Max::max(x_start, 0) as usize; let idx_input_batch = batch * input.stride(0); let idx_weight_oc = out_c * weight.stride(1); @@ -183,13 +183,13 @@ pub fn conv_transpose2d_direct( bias.as_tensor_arg(1), output.as_tensor_arg(1), ConvArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.groups as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.groups), ), input.dtype.into(), )?; diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs index 329ede2383..96fdfac9fa 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs @@ -13,16 +13,16 @@ use burn_backend::{Shape, ops::ConvTransposeOptions}; #[derive(CubeLaunch, CubeType)] struct ConvArgs { - conv_stride_0: u32, - conv_stride_1: u32, - conv_stride_2: u32, - dilation_0: u32, - dilation_1: u32, - dilation_2: u32, - padding_0: u32, - padding_1: u32, - padding_2: u32, - groups: u32, + conv_stride_0: usize, + conv_stride_1: usize, + conv_stride_2: usize, + dilation_0: usize, + dilation_1: usize, + dilation_2: usize, + padding_0: usize, + padding_1: usize, + padding_2: usize, + groups: usize, } #[cube(launch)] @@ -68,13 +68,13 @@ fn conv_transpose3d_kernel( let y_start = ((out_y + args.padding_1) as i32 - kernel_h) / stride_1_i; let x_start = ((out_x + args.padding_2) as i32 - kernel_w) / stride_2_i; - let z_end = Min::min(Max::max(kernel_d + z_start + 1, 0) as u32, input.shape(2)); - let y_end = Min::min(Max::max(kernel_h + y_start + 1, 0) as u32, input.shape(3)); - let x_end = Min::min(Max::max(kernel_w + x_start + 1, 0) as u32, input.shape(4)); + let z_end = Min::min(Max::max(kernel_d + z_start + 1, 0) as usize, input.shape(2)); + let y_end = Min::min(Max::max(kernel_h + y_start + 1, 0) as usize, input.shape(3)); + let x_end = Min::min(Max::max(kernel_w + x_start + 1, 0) as usize, input.shape(4)); - let z_start = Max::max(z_start, 0) as u32; - let y_start = Max::max(y_start, 0) as u32; - let x_start = Max::max(x_start, 0) as u32; + let z_start = Max::max(z_start, 0) as usize; + let y_start = Max::max(y_start, 0) as usize; + let x_start = Max::max(x_start, 0) as usize; let index_input_batch = batch * input.stride(0); let index_weight_out_c = out_channel * weight.stride(1); @@ -218,16 +218,16 @@ pub(crate) fn conv_transpose3d( bias.as_tensor_arg(1), 
output.as_tensor_arg(1), ConvArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.stride[2] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.dilation[2] as u32), - ScalarArg::new(options.padding[0] as u32), - ScalarArg::new(options.padding[1] as u32), - ScalarArg::new(options.padding[2] as u32), - ScalarArg::new(options.groups as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.stride[2]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), + ScalarArg::new(options.dilation[2]), + ScalarArg::new(options.padding[0]), + ScalarArg::new(options.padding[1]), + ScalarArg::new(options.padding[2]), + ScalarArg::new(options.groups), ), input.dtype.into(), )?; diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs index 8a2b1059b8..28ca0baf42 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs @@ -1,4 +1,4 @@ -use cubecl::{calculate_cube_count_elemwise, prelude::*, std::scalar::InputScalar}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubek::convolution::components::ConvSetupError; use burn_backend::{ @@ -21,20 +21,20 @@ use crate::{ #[derive(CubeLaunch, CubeType)] struct DeformConv2dArgs { - conv_stride_h: u32, - conv_stride_w: u32, - dilation_h: u32, - dilation_w: u32, + conv_stride_h: usize, + conv_stride_w: usize, + dilation_h: usize, + dilation_w: usize, padding_h: InputScalar, padding_w: InputScalar, - offset_groups: u32, + offset_groups: usize, - kernel_height: u32, - kernel_width: u32, - out_h: u32, - out_w: u32, + kernel_height: usize, + kernel_width: usize, + out_h: usize, + out_w: usize, - col_stride_0: u32, + col_stride_0: usize, } #[cube(launch)] @@ -44,8 +44,8 @@ fn deform_im2col_kernel( mask: &Tensor, columns: &mut Tensor, args: &DeformConv2dArgs, - #[comptime] kernel_h_unroll: Option, - #[comptime] kernel_w_unroll: Option, + #[comptime] kernel_h_unroll: Option, + #[comptime] kernel_w_unroll: Option, #[comptime] use_mask: bool, #[define(F)] _dtype: StorageType, ) { @@ -131,11 +131,11 @@ fn deform_im2col_kernel( #[cube] pub(crate) fn bilinear_interpolate( input: &Tensor, - height: u32, - width: u32, + height: usize, + width: usize, y: F, x: F, - offset: u32, + offset: usize, ) -> F { // To simplify code let y = f32::cast_from(y); @@ -143,26 +143,26 @@ pub(crate) fn bilinear_interpolate( let mut result = F::new(0.0); if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { - let in_w = u32::cast_from(width); + let in_w = width; let y_low = f32::floor(y); let x_low = f32::floor(x); - let y_high = (y_low + 1.) as u32; - let x_high = (x_low + 1.) as u32; + let y_high = (y_low + 1.) as usize; + let x_high = (x_low + 1.) as usize; let zero = F::new(0.0); let v1: F = if y_low >= 0. && x_low >= 0. { - input[offset + y_low as u32 * in_w + x_low as u32] + input[offset + y_low as usize * in_w + x_low as usize] } else { zero }; let v2: F = if y_low >= 0. && x_high < width { - input[offset + y_low as u32 * in_w + x_high] + input[offset + y_low as usize * in_w + x_high] } else { zero }; let v3: F = if y_high < height && x_low >= 0. 
{ - input[offset + y_high * in_w + x_low as u32] + input[offset + y_high * in_w + x_low as usize] } else { zero }; @@ -237,10 +237,10 @@ pub(crate) fn deform_im2col( mask.as_handle_ref().as_tensor_arg(1), output.as_handle_ref().as_tensor_arg(1), DeformConv2dArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), { let val = options.padding[0] as f32; InputScalar::new(val, dtype) @@ -249,15 +249,15 @@ pub(crate) fn deform_im2col( let val = options.padding[1] as f32; InputScalar::new(val, dtype) }, - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), - ScalarArg::new(out_height as u32), - ScalarArg::new(out_width as u32), - ScalarArg::new(output.strides[0] as u32), + ScalarArg::new(options.offset_groups), + ScalarArg::new(kernel_height), + ScalarArg::new(kernel_width), + ScalarArg::new(out_height), + ScalarArg::new(out_width), + ScalarArg::new(output.strides[0]), ), - Some(kernel_height as u32), - Some(kernel_width as u32), + Some(kernel_height), + Some(kernel_width), use_mask, dtype.into(), )?; diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs index 94f040a33e..9787dfe53a 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs @@ -16,7 +16,6 @@ use crate::{ use burn_backend::{DType, Shape, ops::DeformConvOptions}; use cubecl::{ CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube, features::TypeUsage, prelude::*, - std::scalar::InputScalar, }; use cubek::{ convolution::components::ConvSetupError, @@ -273,15 +272,15 @@ fn compute_offset_and_mask_gradient( grad_offset.as_handle_ref().as_tensor_arg(1), grad_mask.as_handle_ref().as_tensor_arg(1), DeformConv2dCol2ImgCoordArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), InputScalar::new(options.padding[0] as f32, dtype.elem_type()), InputScalar::new(options.padding[1] as f32, dtype.elem_type()), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), + ScalarArg::new(options.offset_groups), + ScalarArg::new(kernel_height), + ScalarArg::new(kernel_width), ), use_mask, dtype, @@ -294,15 +293,15 @@ fn compute_offset_and_mask_gradient( #[derive(CubeLaunch, CubeType)] struct DeformConv2dCol2ImgCoordArgs { - stride_h: u32, - stride_w: u32, - dilation_h: u32, - dilation_w: u32, + stride_h: usize, + stride_w: usize, + dilation_h: usize, + dilation_w: usize, pad_h: InputScalar, pad_w: InputScalar, - offset_groups: u32, - kernel_height: u32, - kernel_width: u32, + offset_groups: usize, + kernel_height: usize, + kernel_width: usize, } #[expect(clippy::collapsible_if)] @@ -431,8 +430,8 @@ fn deform_col2img_coord_kernel( #[cube] fn get_coordinate_weight( input: &Slice, - height: u32, - width: u32, + height: usize, + width: usize, y: F, x: F, is_y_direction: 
bool, @@ -453,22 +452,22 @@ fn get_coordinate_weight( let valid_x_high = x_high >= 0. && x_high < width as f32; let bottom_left = if valid_y_low && valid_x_low { - input[y_low as u32 * stride_y + x_low as u32] + input[y_low as usize * stride_y + x_low as usize] } else { F::new(0.0) }; let bottom_right = if valid_y_low && valid_x_high { - input[y_low as u32 * stride_y + x_high as u32] + input[y_low as usize * stride_y + x_high as usize] } else { F::new(0.0) }; let top_left = if valid_y_high && valid_x_low { - input[y_high as u32 * stride_y + x_low as u32] + input[y_high as usize * stride_y + x_low as usize] } else { F::new(0.0) }; let top_right = if valid_y_high && valid_x_high { - input[y_high as u32 * stride_y + x_high as u32] + input[y_high as usize * stride_y + x_high as usize] } else { F::new(0.0) }; @@ -549,19 +548,19 @@ fn compute_input_grad( columns.as_tensor_arg(1), grad_arg, DeformConv2dCol2ImgArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.stride[0]), + ScalarArg::new(options.stride[1]), + ScalarArg::new(options.dilation[0]), + ScalarArg::new(options.dilation[1]), InputScalar::new(options.padding[0] as f32, dtypes[0].elem_type()), InputScalar::new(options.padding[1] as f32, dtypes[0].elem_type()), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(batch_size as u32), - ScalarArg::new(in_channels as u32), - ScalarArg::new(height as u32), - ScalarArg::new(width as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), + ScalarArg::new(options.offset_groups), + ScalarArg::new(batch_size), + ScalarArg::new(in_channels), + ScalarArg::new(height), + ScalarArg::new(width), + ScalarArg::new(kernel_height), + ScalarArg::new(kernel_width), ), use_mask, dtypes, @@ -577,19 +576,19 @@ fn compute_input_grad( #[derive(CubeLaunch, CubeType)] struct DeformConv2dCol2ImgArgs { - stride_h: u32, - stride_w: u32, - dilation_h: u32, - dilation_w: u32, + stride_h: usize, + stride_w: usize, + dilation_h: usize, + dilation_w: usize, pad_h: InputScalar, pad_w: InputScalar, - offset_groups: u32, - batch_size: u32, - in_channels: u32, - height: u32, - width: u32, - kernel_height: u32, - kernel_width: u32, + offset_groups: usize, + batch_size: usize, + in_channels: usize, + height: usize, + width: usize, + kernel_height: usize, + kernel_width: usize, } #[cube(launch_unchecked)] @@ -669,8 +668,8 @@ fn deform_col2img_kernel( && F::abs(x - xp) < F::new(1.0) { let gradient_pos = - ((batch * n_in_channels + in_channel) * height + u32::cast_from(yp)) * width - + u32::cast_from(xp); + ((batch * n_in_channels + in_channel) * height + usize::cast_from(yp)) * width + + usize::cast_from(xp); let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp)); diff --git a/crates/burn-cubecl/src/kernel/conv/direct.rs b/crates/burn-cubecl/src/kernel/conv/direct.rs index 1395483ad6..e4a268a80d 100644 --- a/crates/burn-cubecl/src/kernel/conv/direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/direct.rs @@ -1,15 +1,12 @@ use crate::ops::numeric::empty_device_optimized_dtype; use crate::{ CubeRuntime, - kernel::{ - into_contiguous_aligned, - utils::{linear_view, shape_divmod}, - }, + kernel::{into_contiguous_aligned, utils::linear_view}, ops::max_line_size, tensor::CubeTensor, }; use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes}; -use cubecl::std::{CubeOption, CubeOptionExpand, 
FastDivmod}; +use cubecl::std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs}; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView, tensor_line_size_parallel, @@ -36,8 +33,8 @@ fn direct_conv2d_kernel( bias: CubeOption>>, output: &mut LinearView, ReadWrite>, args: Conv2dArgs, - shape_out: Sequence, - shape_out_c: FastDivmod, + shape_out: Sequence>, + shape_out_c: FastDivmod, #[comptime] has_padding: bool, #[define(E)] _dtype: StorageType, ) { @@ -50,20 +47,20 @@ fn direct_conv2d_kernel( let line_size_out = output.line_size(); let pos = ABSOLUTE_POS * line_size_out; - let in_c_per_group = weight.shape(weight.rank() - 1); + let in_c_per_group = weight.shape(weight.rank() - 1) as u32; - let (rem, out_c) = shape_out_c.div_mod(pos); + let (rem, out_c) = shape_out_c.div_mod(pos as u32); let (b, spatial_pos) = div_mod_seq(rem, &shape_out); let g = out_c / args.channels_per_group; let ic_start = in_c_per_group * g; let mut sum = match bias { - CubeOption::Some(bias) => bias[out_c / line_size_out], + CubeOption::Some(bias) => bias[out_c as usize / line_size_out], CubeOption::None => Line::empty(line_size_out).fill(E::from_int(0)), }; - let in_offs = b * input.stride(0) + ic_start; + let in_offs = b as usize * input.stride(0) + ic_start as usize; let stride_oc = weight.stride(0); @@ -74,13 +71,13 @@ fn direct_conv2d_kernel( #[unroll] for i in 0..n_spatial { - in_shape.push(input.shape(i + 1)); + in_shape.push(input.shape(i + 1) as u32); in_strides.push(input.stride(i + 1)); - kernel_shape.push(weight.shape(i + 1)); + kernel_shape.push(weight.shape(i + 1) as u32); kernel_strides.push(weight.stride(i + 1)); } - let weight_offs = out_c * stride_oc; + let weight_offs = out_c as usize * stride_oc; let loop_params = LoopParams { out_pos: spatial_pos, @@ -101,7 +98,7 @@ fn direct_conv2d_kernel( true, weight_offs, &loop_params, - 0u32, + 0usize, has_padding, ); @@ -112,13 +109,13 @@ fn direct_conv2d_kernel( struct LoopParams { out_pos: Sequence, in_shape: Sequence, - in_strides: Sequence, + in_strides: Sequence, kernel_shape: Sequence, - kernel_strides: Sequence, + kernel_strides: Sequence, conv_params: Sequence, in_c_per_group: u32, - stride_oc: u32, + stride_oc: usize, } #[cube] @@ -126,11 +123,11 @@ fn kernel_loop( input: &Tensor>, weight: &Tensor>, sum: &mut Line, - in_offs: u32, + in_offs: usize, in_bounds: bool, - weight_offs: u32, + weight_offs: usize, params: &LoopParams, - #[comptime] kernel_dim: u32, + #[comptime] kernel_dim: usize, #[comptime] has_padding: bool, ) { if comptime![kernel_dim < params.kernel_shape.len()] { @@ -142,8 +139,8 @@ fn kernel_loop( for pos in 0..*params.kernel_shape.index(kernel_dim) { let in_pos = (out_idx * conv.stride + pos * conv.dilation) as i32 - conv.padding; - let in_offs = in_offs + in_pos as u32 * stride; - let weight_offs = weight_offs + pos * k_stride; + let in_offs = in_offs + in_pos as usize * stride; + let weight_offs = weight_offs + pos as usize * k_stride; let mut in_bounds = in_bounds; if has_padding { @@ -181,19 +178,19 @@ fn kernel_loop_inner( input: &Tensor>, weight: &Tensor>, sum: &mut Line, - in_offs: u32, + in_offs: usize, in_bounds: bool, - weight_offs: u32, + weight_offs: usize, in_c_per_group: u32, - stride_oc: u32, + stride_oc: usize, ) { let line_size_in = input.line_size(); let line_size_out = sum.size(); if in_bounds { - for in_c in range_stepped(0, in_c_per_group, line_size_in) { - let in_pos = in_offs + in_c; - let mut weight_pos = weight_offs + in_c; + for in_c in range_stepped(0, 
in_c_per_group, line_size_in as u32) { + let in_pos = in_offs + in_c as usize; + let mut weight_pos = weight_offs + in_c as usize; let val = input[in_pos / line_size_in]; @@ -225,6 +222,7 @@ pub fn conv_direct( bias: Option>, options: ConvOptions, ) -> Result, ConvSetupError> { + let client = input.client.clone(); let out_dtype = input.dtype; let rank = input.shape.num_dims(); let dim_c = rank - 1; @@ -276,9 +274,11 @@ pub fn conv_direct( // Use channels_per_group instead of in_channels to avoid issues here let line_size_in = max_line_size(&weight); - let mut shape_out = shape_divmod(&output); - shape_out.values.remove(0); - let shape_out_c = shape_out.values.pop().unwrap(); + let shape_out = output.shape[1..dim_c] + .iter() + .map(|s| FastDivmodArgs::::new(&client, *s as u32)) + .collect(); + let shape_out_c = FastDivmodArgs::::new(&client, out_channels as u32); let mut conv_params = SequenceArg::new(); @@ -292,7 +292,7 @@ pub fn conv_direct( let bias = bias.as_ref().map(|b| b.as_tensor_arg(line_size_out)); - let working_units = output.shape.num_elements() / line_size_out as usize; + let working_units = output.shape.num_elements() / line_size_out; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); @@ -317,7 +317,7 @@ pub fn conv_direct( } #[cube] -pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence) -> (u32, Sequence) { +pub(crate) fn div_mod_seq(pos: u32, shape: &Sequence>) -> (u32, Sequence) { let rank = comptime![shape.len()]; let mut offs = pos; let mut out = Sequence::new(); diff --git a/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs b/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs index 98bd72bfba..a4d7b98b72 100644 --- a/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs +++ b/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs @@ -6,7 +6,7 @@ use cubek::{ components::ConvSetupError, forward, }, matmul::{ - definition::{MatmulElemType, MatmulElems, MatmulGlobalElems}, + definition::{MatmulElems, MatmulGlobalElems}, launch::MatmulInputHandleRef, }, }; @@ -141,18 +141,9 @@ pub fn launch_convolution_forward( let client = input.client.clone(); let dtypes = MatmulElems::from_globals(&MatmulGlobalElems { - lhs: MatmulElemType { - dtype: input.dtype.into(), - quantized: false, - }, - rhs: MatmulElemType { - dtype: weight.dtype.into(), - quantized: false, - }, - out: MatmulElemType { - dtype: out_dtype.into(), - quantized: false, - }, + lhs: input.dtype.into(), + rhs: weight.dtype.into(), + out: out_dtype.into(), }); let input = MatmulInputHandleRef::new(input.as_handle_ref(), input.dtype.into()); let weight = MatmulInputHandleRef::new(weight.as_handle_ref(), weight.dtype.into()); diff --git a/crates/burn-cubecl/src/kernel/conv/im2col.rs b/crates/burn-cubecl/src/kernel/conv/im2col.rs index 815bd295b0..08be204461 100644 --- a/crates/burn-cubecl/src/kernel/conv/im2col.rs +++ b/crates/burn-cubecl/src/kernel/conv/im2col.rs @@ -30,7 +30,7 @@ pub(crate) fn batches_per_run( let cube_count_per_batch = out_shape.div_ceil(plane_size); let max_cube_count = u16::MAX as usize; - let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); + let max_simultaneous = Ord::min(max_cube_count / cube_count_per_batch, batch_size); if max_simultaneous == 0 { return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( cube_count_per_batch as u32, diff --git a/crates/burn-cubecl/src/kernel/grid_sample/base.rs 
b/crates/burn-cubecl/src/kernel/grid_sample/base.rs index 4f7d0e15ba..7bfdabdbb9 100644 --- a/crates/burn-cubecl/src/kernel/grid_sample/base.rs +++ b/crates/burn-cubecl/src/kernel/grid_sample/base.rs @@ -45,9 +45,9 @@ impl From for PaddingMode { #[cube] pub(crate) fn fetch_value( input: &Tensor, - base: u32, - stride_h: u32, - stride_w: u32, + base: usize, + stride_h: usize, + stride_w: usize, y: i32, x: i32, h: i32, @@ -67,17 +67,17 @@ pub(crate) fn fetch_value( #[cube] pub(crate) fn fetch_with_zeros( input: &Tensor, - base: u32, - stride_h: u32, - stride_w: u32, + base: usize, + stride_h: usize, + stride_w: usize, y: i32, x: i32, h: i32, w: i32, ) -> F { let in_bounds = x >= 0 && x < w && y >= 0 && y < h; - let x_clamped = Min::min(Max::max(x, 0), w - 1) as u32; - let y_clamped = Min::min(Max::max(y, 0), h - 1) as u32; + let x_clamped = Min::min(Max::max(x, 0), w - 1) as usize; + let y_clamped = Min::min(Max::max(y, 0), h - 1) as usize; let idx = base + y_clamped * stride_h + x_clamped * stride_w; select(in_bounds, input[idx], F::new(0.0)) } @@ -86,16 +86,16 @@ pub(crate) fn fetch_with_zeros( #[cube] pub(crate) fn fetch_with_border( input: &Tensor, - base: u32, - stride_h: u32, - stride_w: u32, + base: usize, + stride_h: usize, + stride_w: usize, y: i32, x: i32, h: i32, w: i32, ) -> F { - let x_clamped = Min::min(Max::max(x, 0), w - 1) as u32; - let y_clamped = Min::min(Max::max(y, 0), h - 1) as u32; + let x_clamped = Min::min(Max::max(x, 0), w - 1) as usize; + let y_clamped = Min::min(Max::max(y, 0), h - 1) as usize; let idx = base + y_clamped * stride_h + x_clamped * stride_w; input[idx] } @@ -105,9 +105,9 @@ pub(crate) fn fetch_with_border( #[cube] pub(crate) fn fetch_with_reflection( input: &Tensor, - base: u32, - stride_h: u32, - stride_w: u32, + base: usize, + stride_h: usize, + stride_w: usize, y: i32, x: i32, h: i32, @@ -122,7 +122,7 @@ pub(crate) fn fetch_with_reflection( /// Reflect an integer index that may be out of bounds. /// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear). #[cube] -fn reflect_coord_bounded(idx: i32, size: i32) -> u32 { +fn reflect_coord_bounded(idx: i32, size: i32) -> usize { let max_idx = size - 1; let neg_reflected = -idx - 1; let pos_reflected = 2 * max_idx + 1 - idx; @@ -131,7 +131,7 @@ fn reflect_coord_bounded(idx: i32, size: i32) -> u32 { neg_reflected, select(idx > max_idx, pos_reflected, idx), ); - Min::min(Max::max(result, 0), max_idx) as u32 + Min::min(Max::max(result, 0), max_idx) as usize } /// Reflect a float coordinate into the valid sampling range. diff --git a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs index 7e92301fc5..5ec1e2a0e4 100644 --- a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs +++ b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs @@ -14,10 +14,10 @@ use super::base::{PaddingMode, fetch_value, reflect_coord}; /// 3. 
For each channel: fetch 4 corner values, interpolate, and write output #[cube(launch)] fn grid_sample_bilinear_kernel( - input: &Tensor, // [N, C, H_in, W_in] - grid: &Tensor, // [N, H_out, W_out, 2] - output: &mut Tensor, // [N, C, H_out, W_out] - shape_spatial: Sequence, // [N, H_out, W_out] for thread decomposition + input: &Tensor, // [N, C, H_in, W_in] + grid: &Tensor, // [N, H_out, W_out, 2] + output: &mut Tensor, // [N, C, H_out, W_out] + shape_spatial: Sequence>, // [N, H_out, W_out] for thread decomposition #[comptime] align_corners: bool, #[comptime] pad_mode: PaddingMode, #[define(F)] _dtype: StorageType, @@ -30,15 +30,17 @@ fn grid_sample_bilinear_kernel( } // Decompose spatial index into (n, h_out, w_out) - let (rem, w_out) = shape_spatial.index(2).div_mod(spatial_idx); - let (n, h_out) = shape_spatial.index(1).div_mod(rem); + let (rem, w_out) = shape_spatial[2].div_mod(spatial_idx as u32); + let (n, h_out) = shape_spatial[1].div_mod(rem); - let channels = input.shape(1); - let h_in = input.shape(2); - let w_in = input.shape(3); + let channels = input.shape(1) as u32; + let h_in = input.shape(2) as u32; + let w_in = input.shape(3) as u32; // Read grid coordinates once per spatial position - let grid_offset = n * grid.stride(0) + h_out * grid.stride(1) + w_out * grid.stride(2); + let grid_offset = n as usize * grid.stride(0) + + h_out as usize * grid.stride(1) + + w_out as usize * grid.stride(2); let gx = grid[grid_offset]; // x coordinate in [-1, 1] let gy = grid[grid_offset + 1]; // y coordinate in [-1, 1] @@ -95,12 +97,13 @@ fn grid_sample_bilinear_kernel( let out_stride_w = output.stride(3); // Base offsets for this spatial position - let in_base_n = n * stride_n; - let out_base_spatial = n * out_stride_n + h_out * out_stride_h + w_out * out_stride_w; + let in_base_n = n as usize * stride_n; + let out_base_spatial = + n as usize * out_stride_n + h_out as usize * out_stride_h + w_out as usize * out_stride_w; // Loop over all channels - grid coords and weights are reused for c in 0..channels { - let in_base = in_base_n + c * stride_c; + let in_base = in_base_n + c as usize * stride_c; let v00 = fetch_value( input, in_base, stride_h, stride_w, y0, x0, h_in, w_in, pad_mode, @@ -118,7 +121,7 @@ fn grid_sample_bilinear_kernel( // Bilinear interpolation let result = wx_ * wy_ * v00 + wx_ * wy * v01 + wx * wy_ * v10 + wx * wy * v11; - let out_idx = out_base_spatial + c * out_stride_c; + let out_idx = out_base_spatial + c as usize * out_stride_c; output[out_idx] = result; } } diff --git a/crates/burn-cubecl/src/kernel/index/flip.rs b/crates/burn-cubecl/src/kernel/index/flip.rs index 3a9daa8a87..85648a7a9a 100644 --- a/crates/burn-cubecl/src/kernel/index/flip.rs +++ b/crates/burn-cubecl/src/kernel/index/flip.rs @@ -2,7 +2,6 @@ use crate::{ CubeRuntime, kernel::into_contiguous, ops::numeric::empty_device_dtype, tensor::CubeTensor, }; use burn_backend::DType; -use cubecl::std::scalar::InputScalar; use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] @@ -10,7 +9,7 @@ fn flip_kernel( input: &Tensor, output: &mut Tensor, indices: Sequence, - #[comptime] rank: u32, + #[comptime] rank: usize, #[define(E, Bool)] _dtypes: [StorageType; 2], ) { if ABSOLUTE_POS >= output.len() { @@ -80,7 +79,7 @@ pub(crate) fn flip_on_output( tensor.as_tensor_arg(1), output.as_tensor_arg(1), indices_sequence, - ndims as u32, + ndims, [dtype_input.into(), dtype_bool.into()], ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/index/gather.rs 
b/crates/burn-cubecl/src/kernel/index/gather.rs index e83307223d..e78f38918e 100644 --- a/crates/burn-cubecl/src/kernel/index/gather.rs +++ b/crates/burn-cubecl/src/kernel/index/gather.rs @@ -21,7 +21,7 @@ fn gather_kernel( indices: &LinearView>, output: &mut Tensor>, out_layout: LinearLayout, - dim: &u32, + dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { if !indices.is_in_bounds(ABSOLUTE_POS) { @@ -31,17 +31,17 @@ fn gather_kernel( let index = indices[ABSOLUTE_POS]; let out_pos = out_layout.to_source_pos(ABSOLUTE_POS); - let stride = input.stride(*dim); - let mut offset = u32::cast_from(index); + let stride = input.stride(dim); + let mut offset = usize::cast_from(index); offset *= stride; - if *dim > 0 { - let offset_before = index_offset_with_layout(input, output, out_pos, 0, *dim, false); + if dim > 0 { + let offset_before = index_offset_with_layout(input, output, out_pos, 0, dim, false); offset += offset_before; } let offset_after = - index_offset_with_layout(input, output, out_pos, *dim + 1, input.rank(), false); + index_offset_with_layout(input, output, out_pos, dim + 1, input.rank(), false); offset += offset_after; output[out_pos] = input[offset]; } @@ -72,7 +72,7 @@ pub(crate) fn gather( linear_view(&indices, 1), output.as_tensor_arg(1), linear_layout(&output, 1), - ScalarArg::new(dim as u32), + ScalarArg::new(dim), [tensor.dtype.into(), indices.dtype.into()], ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs index 768973c4fc..a75cbd7f54 100644 --- a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs +++ b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs @@ -5,7 +5,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; fn repeat_dim_kernel( input: &Tensor, output: &mut Tensor, - dim: u32, + dim: usize, #[define(E)] _dtype: StorageType, ) { if ABSOLUTE_POS >= output.len() { @@ -56,7 +56,7 @@ pub(crate) fn repeat_dim( cube_dim, input.as_tensor_arg(1), output.as_tensor_arg(1), - ScalarArg::new(dim as u32), + ScalarArg::new(dim), output.dtype.into(), ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/index/scatter.rs b/crates/burn-cubecl/src/kernel/index/scatter.rs index 8f6cc62d9c..c8dea8cbe6 100644 --- a/crates/burn-cubecl/src/kernel/index/scatter.rs +++ b/crates/burn-cubecl/src/kernel/index/scatter.rs @@ -11,18 +11,18 @@ fn scatter_kernel( input: &mut Tensor, indices: &Tensor, value: &Tensor, - dim: &u32, + dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { - let stride_input = input.stride(*dim); - let shape_value = value.shape(*dim); + let stride_input = input.stride(dim); + let shape_value = value.shape(dim); let mut offset_input = 0; let mut offset_value = 0; let mut num_elems = 1; for i in 0..value.rank() { - let shouldnt_skip = i != *dim; + let shouldnt_skip = i != dim; if shouldnt_skip { let shape_input_loop = input.shape(i); let shape_value_loop = value.shape(i); @@ -54,7 +54,7 @@ fn scatter_kernel( idx += offset_value; let result_value = value[idx]; - let result_indices = u32::cast_from(indices[idx]); + let result_indices = usize::cast_from(indices[idx]); let mut index_input = stride_input * result_indices; index_input += offset_input; @@ -121,7 +121,7 @@ pub(crate) fn scatter( tensor.as_tensor_arg(1), indices.as_tensor_arg(1), value.as_tensor_arg(1), - ScalarArg::new(dim as u32), + ScalarArg::new(dim), [tensor.dtype.into(), indices.dtype.into()], ) .expect("Kernel to never fail"); diff --git 
a/crates/burn-cubecl/src/kernel/index/select.rs b/crates/burn-cubecl/src/kernel/index/select.rs index 64347c7787..42483ac4d1 100644 --- a/crates/burn-cubecl/src/kernel/index/select.rs +++ b/crates/burn-cubecl/src/kernel/index/select.rs @@ -9,7 +9,7 @@ fn select_kernel( input: &Tensor, indices: &Tensor, output: &mut Tensor, - dim: u32, + dim: usize, #[define(T, I)] _dtypes: [StorageType; 2], ) { if ABSOLUTE_POS >= output.len() { @@ -22,7 +22,7 @@ fn select_kernel( let mut offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i); if i == dim { - offset_local = u32::cast_from(indices[offset_local]); + offset_local = usize::cast_from(indices[offset_local]); } offset_input += offset_local * input.stride(i); @@ -69,7 +69,7 @@ pub(crate) fn select( indices.dtype.size(), ), output.as_tensor_arg(1), - ScalarArg::new(dim as u32), + ScalarArg::new(dim), [tensor.dtype.into(), indices.dtype.into()], ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/index/select_assign.rs b/crates/burn-cubecl/src/kernel/index/select_assign.rs index 3ad68fb054..31e756927e 100644 --- a/crates/burn-cubecl/src/kernel/index/select_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/select_assign.rs @@ -8,13 +8,12 @@ fn select_assign_kernel( tensor: &mut Tensor, indices: &Tensor, value: &Tensor, - dim: &u32, + dim: usize, #[define(F, I)] _dtypes: [StorageType; 2], ) { - let dim = *dim; - let mut offset_tensor = 0u32; - let mut offset_value = 0u32; - let mut num_elems = 1u32; + let mut offset_tensor = 0; + let mut offset_value = 0; + let mut num_elems = 1; // Calculate offsets and num_elems for i in 0..tensor.rank() { @@ -39,7 +38,7 @@ fn select_assign_kernel( // Main operation for i in 0..value.shape(dim) { - let index_tensor = u32::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; + let index_tensor = usize::cast_from(indices[i]) * strides_tensor_dim + offset_tensor; let index_value = i * strides_value_dim + offset_value; let value = Op::BinaryOp::::execute( @@ -104,7 +103,7 @@ pub(crate) fn select_assign( indices.dtype.size(), ), value.as_tensor_arg(1), - ScalarArg::new(dim as u32), + ScalarArg::new(dim), [tensor.dtype.into(), indices.dtype.into()], ) .expect("Kernel to never fail"); diff --git a/crates/burn-cubecl/src/kernel/index/slice.rs b/crates/burn-cubecl/src/kernel/index/slice.rs index 9041ad1642..a872478713 100644 --- a/crates/burn-cubecl/src/kernel/index/slice.rs +++ b/crates/burn-cubecl/src/kernel/index/slice.rs @@ -58,8 +58,8 @@ pub fn slice(tensor: CubeTensor, indices: &[Range]) -> fn slice_kernel( input: &Tensor, output: &mut LinearView, - out_shape: Sequence, - indices: Sequence, + out_shape: Sequence>, + indices: Sequence, #[define(E)] _dtype: StorageType, ) { if !output.is_in_bounds(ABSOLUTE_POS) { @@ -73,11 +73,10 @@ fn slice_kernel( #[unroll] for i in 0..rank { // Iterate in reverse to use divmod - let i = unwrap(i); - let dim = comptime![rank - i - 1]; + let dim = rank - i - 1; - let range_start = *indices.index(dim); - let (rem, offset_local) = out_shape.index(dim).div_mod(offset_output); + let range_start = indices[dim]; + let (rem, offset_local) = out_shape[dim].div_mod(offset_output); offset_output = rem; let offset_local = offset_local + range_start; @@ -94,11 +93,11 @@ pub(crate) fn slice_on_output( indices: &[Range], ) -> CubeTensor { let ndims = tensor.shape.num_dims(); - let mut indices_sequence = SequenceArg::::new(); + let mut indices_sequence = SequenceArg::::new(); for i in 0..ndims { let start = indices.get(i).map(|index| 
index.start).unwrap_or(0); - indices_sequence.push(ScalarArg::new(start as u32)); + indices_sequence.push(ScalarArg::new(start)); } let working_units = output.shape.num_elements(); @@ -127,9 +126,9 @@ pub(crate) fn slice_on_output( fn slice_with_steps_kernel( input: &Tensor, output: &mut LinearView, - out_shape: Sequence, - starts: Sequence, - ends: Sequence, + out_shape: Sequence>, + starts: Sequence, + ends: Sequence, steps: Sequence, #[define(E)] _dtype: StorageType, ) { @@ -145,21 +144,20 @@ fn slice_with_steps_kernel( #[unroll] for i in 0..rank { // Iterate in reverse to use divmod - let i = unwrap(i); - let dim = comptime![rank - i - 1]; - let start = *starts.index(dim); - let end = *ends.index(dim); - let step = *steps.index(dim); + let dim = rank - i - 1; + let start = starts[dim]; + let end = ends[dim]; + let step = steps[dim]; - let (rem, output_idx) = out_shape.index(dim).div_mod(output_offset); + let (rem, output_idx) = out_shape[dim].div_mod(output_offset); output_offset = rem; let input_idx = if step > 0 { // Forward stepping - start + output_idx * (step as u32) + start + output_idx * (step as usize) } else { // Backward stepping - start from end-1 - let abs_step = (-step) as u32; + let abs_step = (-step) as usize; let end_minus_1 = end - 1; end_minus_1 - output_idx * abs_step }; @@ -197,21 +195,21 @@ pub fn slice_with_steps(tensor: CubeTensor, slices: &[Slice]) ); // Prepare three separate sequences for kernel - let mut starts = SequenceArg::::new(); - let mut ends = SequenceArg::::new(); + let mut starts = SequenceArg::::new(); + let mut ends = SequenceArg::::new(); let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { let range = slice.to_range(tensor.shape[dim]); - starts.push(ScalarArg::new(range.start as u32)); - ends.push(ScalarArg::new(range.end as u32)); + starts.push(ScalarArg::new(range.start)); + ends.push(ScalarArg::new(range.end)); steps.push(ScalarArg::new(slice.step as i32)); } // Pad with default values if needed to match tensor dimensions for dim in slices.len()..tensor.shape.num_dims() { starts.push(ScalarArg::new(0)); - ends.push(ScalarArg::new(tensor.shape[dim] as u32)); + ends.push(ScalarArg::new(tensor.shape[dim])); steps.push(ScalarArg::new(1)); } diff --git a/crates/burn-cubecl/src/kernel/index/slice_assign.rs b/crates/burn-cubecl/src/kernel/index/slice_assign.rs index 26b11172c2..527011c41c 100644 --- a/crates/burn-cubecl/src/kernel/index/slice_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/slice_assign.rs @@ -13,8 +13,8 @@ use cubecl::{ fn slice_assign_kernel( input: &mut Tensor>, value: &LinearView>, - slice_shape: Sequence, - slice_offsets: Sequence, + slice_shape: Sequence>, + slice_offsets: Sequence, #[define(E)] _dtype: StorageType, ) { if !value.is_in_bounds(ABSOLUTE_POS) { @@ -27,21 +27,17 @@ fn slice_assign_kernel( let mut offset_remainder = ABSOLUTE_POS * line_size; let mut offset_input = 0; - let mut i = comptime![0]; - #[allow(clippy::explicit_counter_loop)] #[unroll] - for _ in 0..rank { - let dim = comptime![rank - i - 1]; - let (rem, offset_local) = slice_shape.index(dim).div_mod(offset_remainder); + for i in 0..rank { + let dim = rank - i - 1; + let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder); - let range_start = *slice_offsets.index(dim); + let range_start = slice_offsets[dim]; let offset_local_input = offset_local + range_start; offset_input += offset_local_input * input.stride(dim); offset_remainder = rem; - - comptime![i += 1;] } // Value tensor is accessed linearly since 
it's a LinearView @@ -53,9 +49,9 @@ fn slice_assign_kernel( fn slice_assign_with_steps_kernel( input: &mut Tensor, value: &LinearView, - value_shape: Sequence, - starts: Sequence, - ends: Sequence, + value_shape: Sequence>, + starts: Sequence, + ends: Sequence, steps: Sequence, #[define(E)] _dtype: StorageType, ) { @@ -71,22 +67,21 @@ fn slice_assign_with_steps_kernel( #[unroll] for i in 0..rank { // Iterate in reverse to use divmod - let i = unwrap(i); - let dim = comptime![rank - i - 1]; - let start = *starts.index(dim); - let end = *ends.index(dim); - let step = *steps.index(dim); + let dim = rank - i - 1; + let start = starts[dim]; + let end = ends[dim]; + let step = steps[dim]; - let (rem, value_idx) = value_shape.index(dim).div_mod(value_offset); + let (rem, value_idx) = value_shape[dim].div_mod(value_offset); value_offset = rem; let input_idx = if step > 0 { // Forward stepping - start + value_idx * (step as u32) + start + value_idx * (step as usize) } else if step < 0 { // Backward stepping - start from end-1 // For negative steps, we iterate backwards through the selected indices - let abs_step = (-step) as u32; + let abs_step = (-step) as usize; let end_minus_1 = end - 1; end_minus_1 - value_idx * abs_step } else { @@ -135,7 +130,7 @@ pub(crate) fn slice_assign( *R::supported_line_sizes() .iter() .filter(|it| { - let it = **it as usize; + let it = **it; shape.is_multiple_of(it) && strides_compatible(&tensor.strides, it) && strides_compatible(&value.strides, it) @@ -147,8 +142,8 @@ pub(crate) fn slice_assign( 1 }; - let mut shape = SequenceArg::::new(); - let mut offsets = SequenceArg::::new(); + let mut shape = SequenceArg::>::new(); + let mut offsets = SequenceArg::::new(); for i in 0..ndims { let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice { @@ -160,11 +155,11 @@ pub(crate) fn slice_assign( let end = slice.end.unwrap_or(tensor.shape[i] as isize); let length = (end - slice.start) as usize; - shape.push(FastDivmodArgs::new(&client, length as u32)); - offsets.push(ScalarArg::new(start as u32)); + shape.push(FastDivmodArgs::::new(&client, length)); + offsets.push(ScalarArg::new(start)); } - let working_units = value.shape.num_elements() / line_size as usize; + let working_units = value.shape.num_elements() / line_size; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); @@ -205,21 +200,21 @@ pub(crate) fn slice_assign_with_steps( }; // Prepare sequences for kernel - let mut starts = SequenceArg::::new(); - let mut ends = SequenceArg::::new(); + let mut starts = SequenceArg::::new(); + let mut ends = SequenceArg::::new(); let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { let range = slice.to_range(tensor.shape[dim]); - starts.push(ScalarArg::new(range.start as u32)); - ends.push(ScalarArg::new(range.end as u32)); + starts.push(ScalarArg::new(range.start)); + ends.push(ScalarArg::new(range.end)); steps.push(ScalarArg::new(slice.step as i32)); } // Pad with default values if needed to match tensor dimensions for dim in slices.len()..tensor.shape.num_dims() { starts.push(ScalarArg::new(0)); - ends.push(ScalarArg::new(tensor.shape[dim] as u32)); + ends.push(ScalarArg::new(tensor.shape[dim])); steps.push(ScalarArg::new(1)); } diff --git a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs index 98e5624698..84e71d8f9c 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs 
diff --git a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
index 98e5624698..84e71d8f9c 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_bicubic_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -26,9 +26,9 @@ fn interpolate_bicubic_kernel(
     let line_size = input.line_size();
     let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);

-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);

     let input_height = input.shape(1) - 1;
     let output_height = f32::cast_from(Max::max(output.shape(1) - 1, 1));
@@ -36,7 +36,7 @@ fn interpolate_bicubic_kernel(
     let frac = f32::cast_from(numerator / output_height);
     let y_in_f = Floor::floor(frac);
-    let y_in = u32::cast_from(y_in_f);
+    let y_in = usize::cast_from(y_in_f);
     let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f));

     let y0 = select(y_in != 0, y_in - 1, 0);
@@ -49,7 +49,7 @@ fn interpolate_bicubic_kernel(
     let numerator = f32::cast_from(x * input_width);
     let frac = numerator / output_width;
     let x_in_f = Floor::floor(frac);
-    let x_in = u32::cast_from(x_in_f);
+    let x_in = usize::cast_from(x_in_f);
     let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f));

     let x0 = select(x_in != 0, x_in - 1, 0);

diff --git a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
index 50cf80ac29..216817a5dc 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_bilinear_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -26,9 +26,9 @@ fn interpolate_bilinear_kernel(
     let line_size = input.line_size();
     let out_idx = out_layout.to_source_pos(ABSOLUTE_POS);

-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);

     let numerator = (input.shape(1) - 1) as f32;
     let denominator = Max::max(output.shape(1) - 1, 1) as f32;
@@ -42,8 +42,8 @@ fn interpolate_bilinear_kernel(
     let yw_ = Line::empty(line_size).fill(F::new(1.0) - yw);
     let yw = Line::empty(line_size).fill(yw);
     let y0_ok = v0 >= 0.0;
-    let y0 = v0 as u32;
-    let y1 = v1 as u32;
+    let y0 = v0 as usize;
+    let y1 = v1 as usize;

     let numerator = f32::cast_from(input.shape(2) - 1);
     let denominator = f32::cast_from(Max::max(output.shape(2) - 1, 1));
@@ -55,8 +55,8 @@ fn interpolate_bilinear_kernel(
     let xw_ = Line::empty(line_size).fill(F::new(1.0) - xw);
     let xw = Line::empty(line_size).fill(xw);
     let x0_ok = v0 >= 0.0;
-    let x0 = v0 as u32;
-    let x1 = v1 as u32;
+    let x0 = v0 as usize;
+    let x1 = v1 as usize;

     let index_base = b * input.stride(0) + c * input.stride(3);
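The bilinear kernel computes, per axis, a pair of neighboring source indices plus complementary weights (`yw`/`yw_`, `xw`/`xw_`). A plain-Rust sketch of the 1-D version of that computation, assuming the align-corners-style `(in - 1) / (out - 1)` scaling visible in the diff; function and variable names are illustrative:

```rust
// For an output coordinate, return the two source taps and their weights.
fn linear_taps(out_idx: usize, in_size: usize, out_size: usize) -> ((usize, f32), (usize, f32)) {
    let scale = (in_size - 1) as f32 / (out_size - 1).max(1) as f32;
    let pos = out_idx as f32 * scale;
    let i0 = pos.floor() as usize;
    let i1 = (i0 + 1).min(in_size - 1); // clamp the upper neighbor at the edge
    let w1 = pos - pos.floor();         // fractional part weights the upper tap
    ((i0, 1.0 - w1), (i1, w1))
}

fn main() {
    // Upsampling 4 -> 7: output index 3 falls exactly between inputs 1 and 2.
    let ((i0, w0), (i1, w1)) = linear_taps(3, 4, 7);
    assert_eq!((i0, i1), (1, 2));
    assert!((w0 - 0.5).abs() < 1e-6 && (w1 - 0.5).abs() < 1e-6);
}
```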
diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
index 8ebc2a3f6a..a60f61e743 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_nearest_kernel(
     input: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -31,16 +31,16 @@ fn interpolate_nearest_kernel(
     let (h_in, w_in) = (input.shape(1) as f32, input.shape(2) as f32);
     let (h_out, w_out) = (output.shape(1) as f32, output.shape(2) as f32);

-    let (rem, c) = shape_out.index(3).div_mod(out_pos);
-    let (rem, x) = shape_out.index(2).div_mod(rem);
-    let (b, y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(out_pos);
+    let (rem, x) = shape_out[2].div_mod(rem);
+    let (b, y) = shape_out[1].div_mod(rem);

     let y = y as f32 * (h_in / h_out);
     let x = x as f32 * (w_in / w_out);

     let in_idx = b * input.stride(0)
-        + y as u32 * input.stride(1)
-        + x as u32 * input.stride(2)
+        + y as usize * input.stride(1)
+        + x as usize * input.stride(2)
         + c * input.stride(3);

     output[out_idx] = input[in_idx / line_size];

diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
index 82ab2c6072..b0dab1d6aa 100644
--- a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
+++ b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs
@@ -15,7 +15,7 @@ use crate::{
 fn interpolate_nearest_backward_kernel(
     grad: &Tensor<Line<F>>,
     output: &mut Tensor<Line<F>>,
-    shape_out: Sequence<FastDivmod>,
+    shape_out: Sequence<FastDivmod<usize>>,
     out_layout: LinearLayout,
     #[define(F)] _dtype: StorageType,
 ) {
@@ -31,9 +31,9 @@ fn interpolate_nearest_backward_kernel(
     let grad_h = grad.shape(1);
     let grad_w = grad.shape(2);

-    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
-    let (rem, out_x) = shape_out.index(2).div_mod(rem);
-    let (b, out_y) = shape_out.index(1).div_mod(rem);
+    let (rem, c) = shape_out[3].div_mod(ABSOLUTE_POS * line_size);
+    let (rem, out_x) = shape_out[2].div_mod(rem);
+    let (b, out_y) = shape_out[1].div_mod(rem);

     let grad_y_start = start_index::<F>(out_y, grad_h, out_h);
     let grad_y_end = end_index::<F>(out_y, grad_h, out_h);
@@ -56,18 +56,18 @@ fn interpolate_nearest_backward_kernel(
 }

 #[cube]
-fn start_index(input_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
     let numerator = F::cast_from(input_index * output_size);
     let div: F = Ceil::ceil(numerator / F::cast_from(input_size));

-    u32::cast_from(div)
+    usize::cast_from(div)
 }

 #[cube]
-fn end_index(input_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(input_index: usize, output_size: usize, input_size: usize) -> usize {
     let numerator = F::cast_from((input_index + 1) * output_size);
     let div: F = Ceil::ceil(numerator / F::cast_from(input_size));
-    let index = u32::cast_from(div);
+    let index = usize::cast_from(div);

     Min::min(output_size, index)
 }

diff --git a/crates/burn-cubecl/src/kernel/mask/base.rs b/crates/burn-cubecl/src/kernel/mask/base.rs
index df08bed25f..fbedacd78b 100644
--- a/crates/burn-cubecl/src/kernel/mask/base.rs
+++ b/crates/burn-cubecl/src/kernel/mask/base.rs
@@ -1,5 +1,5 @@
 use burn_backend::DType;
-use cubecl::std::scalar::InputScalar;
+use cubecl::prelude::InputScalar;

 use super::{MaskFillStrategy, mask_where::MaskWhereStrategy};
 use crate::{CubeRuntime, tensor::CubeTensor};
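The `start_index`/`end_index` pair in the nearest-backward kernel answers the inverse question of the forward pass: which grad pixels round to this output position? The half-open range is `[ceil(i*g/o), min(g, ceil((i+1)*g/o)))`, and these ranges partition the grad axis exactly. A plain-Rust sketch with illustrative names:

```rust
// Range of grad indices whose nearest-neighbor source is output index `out_idx`.
fn grad_range(out_idx: usize, grad_size: usize, out_size: usize) -> (usize, usize) {
    let start = (out_idx * grad_size).div_ceil(out_size);
    let end = ((out_idx + 1) * grad_size).div_ceil(out_size).min(grad_size);
    (start, end)
}

fn main() {
    // Grad axis of 6 mapping onto 4 outputs: the ranges tile 0..6 without overlap.
    let ranges: Vec<_> = (0..4).map(|i| grad_range(i, 6, 4)).collect();
    assert_eq!(ranges, vec![(0, 2), (2, 3), (3, 5), (5, 6)]);
}
```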
diff --git a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
index 5c1c61e372..ba5e80c00f 100644
--- a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
+++ b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs
@@ -1,9 +1,5 @@
 use burn_backend::DType;
-use cubecl::{
-    calculate_cube_count_elemwise,
-    prelude::*,
-    std::{scalar::InputScalar, tensor::layout::linear::LinearView},
-};
+use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView};

 use crate::{
     CubeRuntime,

diff --git a/crates/burn-cubecl/src/kernel/matmul/base.rs b/crates/burn-cubecl/src/kernel/matmul/base.rs
index fc79a84e60..1642a0ad1b 100644
--- a/crates/burn-cubecl/src/kernel/matmul/base.rs
+++ b/crates/burn-cubecl/src/kernel/matmul/base.rs
@@ -2,7 +2,7 @@ use super::init_matmul_output;
 use crate::{CubeRuntime, tensor::CubeTensor};
 use burn_backend::{DType, QTensorPrimitive};
 use cubek::matmul::{
-    definition::{MatmulElemType, MatmulElems, MatmulGlobalElems, MatmulSetupError},
+    definition::{MatmulElems, MatmulGlobalElems, MatmulSetupError},
     launch::{MatmulInputHandleRef, Strategy},
 };

@@ -55,8 +55,6 @@ pub(crate) fn launch_matmul(
     out: CubeTensor<R>,
 ) -> Result<(), MatmulSetupError> {
     let client = &lhs.client;
-    let mut lhs_quant = false;
-    let mut rhs_quant = false;

     let lhs_quant_handles = lhs.quantized_handles();
     let out_dtype: DType = out.dtype;
@@ -66,20 +64,17 @@ pub(crate) fn launch_matmul(
             lhs.dtype,
             MatmulInputHandleRef::new(lhs.as_handle_ref(), lhs.dtype.into()),
         ),
-        Some((data, scale)) => {
-            lhs_quant = true;
-            (
-                out_dtype,
-                MatmulInputHandleRef::quantized(
-                    data.as_handle_ref(),
-                    scale.as_handle_ref(),
-                    &lhs.shape.dims,
-                    lhs.scheme(),
-                    data.dtype.into(),
-                    scale.dtype.into(),
-                ),
-            )
-        }
+        Some((data, scale)) => (
+            out_dtype,
+            MatmulInputHandleRef::quantized(
+                data.as_handle_ref(),
+                scale.as_handle_ref(),
+                &lhs.shape.dims,
+                lhs.scheme(),
+                data.dtype.into(),
+                scale.dtype.into(),
+            ),
+        ),
     };

     let rhs_quant_handles = rhs.quantized_handles();
@@ -89,35 +84,23 @@ pub(crate) fn launch_matmul(
             lhs.dtype,
             MatmulInputHandleRef::new(rhs.as_handle_ref(), lhs.dtype.into()),
         ),
-        Some((data, scale)) => {
-            rhs_quant = true;
-            (
-                out_dtype,
-                MatmulInputHandleRef::quantized(
-                    data.as_handle_ref(),
-                    scale.as_handle_ref(),
-                    &rhs.shape.dims,
-                    rhs.scheme(),
-                    data.dtype.into(),
-                    scale.dtype.into(),
-                ),
-            )
-        }
+        Some((data, scale)) => (
+            out_dtype,
+            MatmulInputHandleRef::quantized(
+                data.as_handle_ref(),
+                scale.as_handle_ref(),
+                &rhs.shape.dims,
+                rhs.scheme(),
+                data.dtype.into(),
+                scale.dtype.into(),
+            ),
+        ),
     };

     let mut dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
-        lhs: MatmulElemType {
-            dtype: lhs_dtype.into(),
-            quantized: lhs_quant,
-        },
-        rhs: MatmulElemType {
-            dtype: rhs_dtype.into(),
-            quantized: rhs_quant,
-        },
-        out: MatmulElemType {
-            dtype: out_dtype.into(),
-            quantized: false,
-        },
+        lhs: lhs_dtype.into(),
+        rhs: rhs_dtype.into(),
+        out: out_dtype.into(),
     });

     cubek::matmul::launch::launch_ref(
         strategy,
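For context on the quantized handles flowing through `launch_matmul`: a quantized input is a pair of tensors, integer data plus a float scale, where each real value is approximately `data * scale`. A minimal per-tensor, symmetric-int8 sketch of what that representation means mathematically — purely illustrative, not cubek's API; the real kernels fuse dequantization instead of materializing floats like this:

```rust
// Matmul on quantized inputs: accumulate integer products, apply scales once.
fn dequant_matmul(
    a_q: &[i8], a_scale: f32, // lhs: m x k, value ~= a_q * a_scale
    b_q: &[i8], b_scale: f32, // rhs: k x n, value ~= b_q * b_scale
    m: usize, k: usize, n: usize,
) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            // Accumulate in i32 to avoid overflow, scale once at the end.
            let mut acc = 0i32;
            for p in 0..k {
                acc += a_q[i * k + p] as i32 * b_q[p * n + j] as i32;
            }
            out[i * n + j] = acc as f32 * a_scale * b_scale;
        }
    }
    out
}

fn main() {
    // 1x1 times 1x1: (2 * 0.5) * (3 * 0.25) = 0.75
    assert_eq!(dequant_matmul(&[2], 0.5, &[3], 0.25, 1, 1, 1), vec![0.75]);
}
```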
diff --git a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs
index e45570ab47..3e6635d88e 100644
--- a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs
+++ b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs
@@ -6,7 +6,7 @@ use crate::{
 use burn_backend::DType;
 use cubecl::tune::{LocalTuner, Tunable, TunableSet, TuneGroup, local_tuner};
 use cubek::matmul::{
-    definition::{MatmulElemType, MatmulKind},
+    definition::MatmulKind,
     launch::{MatmulAutotuneKey, MatmulGlobalScale, Strategy, should_tune_double_buffering},
     routines::{
         BlueprintStrategy, TileSizeSelection, double_buffering::DoubleBufferingArgs,
@@ -400,17 +400,8 @@ fn create_key(
         &rhs.shape.dims,
         &lhs.strides,
         &rhs.strides,
-        MatmulElemType {
-            dtype: lhs.dtype.into(),
-            quantized: matches!(lhs.dtype, DType::QFloat(_)),
-        },
-        MatmulElemType {
-            dtype: rhs.dtype.into(),
-            quantized: matches!(rhs.dtype, DType::QFloat(_)),
-        },
-        MatmulElemType {
-            dtype: out.dtype.into(),
-            quantized: matches!(out.dtype, DType::QFloat(_)),
-        },
+        lhs.dtype.into(),
+        rhs.dtype.into(),
+        out.dtype.into(),
     )
 }

diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
index a3fa067cc3..1a7f9a6f3b 100644
--- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs
@@ -63,12 +63,12 @@ fn adaptive_avg_pool2d_direct(
 }

 #[cube]
-fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     (output_size_index * input_size) / output_size
 }

 #[cube]
-fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     let index = (output_size_index + 1) * input_size;
     let index = index.div_ceil(output_size);

diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
index bbf9fa4e5a..4c991f142d 100644
--- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs
@@ -68,12 +68,12 @@ fn adaptive_avg_pool2d_backward_direct(
 }

 #[cube]
-fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     (output_size_index * input_size) / output_size
 }

 #[cube]
-fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
+fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize {
     let index = (output_size_index + 1) * input_size;
     let index = index.div_ceil(output_size);

diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
index 546cf10328..4f9a0c19b6 100644
--- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs
@@ -36,7 +36,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {

     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         let sum = Line::empty(line_size).fill(N::from_int(0));
         // Count will be set dynamically: either by accumulate (count_include_pad=false)
@@ -49,7 +49,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {
     fn accumulate(
         #[comptime] config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        _index: u32,
+        _index: usize,
         result: Line<N>,
     ) {
         let (sum, count) = accumulator;
@@ -78,7 +78,7 @@ impl Pool2dDirectStrategy for AvgPoolStrategy {

     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         _output_indices: &mut (),
         accumulator: Self::Accumulator,
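The `AvgPoolStrategy` accumulator above is a `(sum, count)` pair: `accumulate` adds every in-bounds element, and `store` divides once at the end, which is how `count_include_pad=false` gets its dynamic divisor. A plain-Rust sketch of the same idea, with illustrative names:

```rust
// Average one pooling window: padding never reaches `accumulate`, so the
// count reflects only real elements.
fn avg_pool_window(values: &[f32]) -> f32 {
    let (mut sum, mut count) = (0.0f32, 0u32);
    for &v in values {
        sum += v;   // accumulate
        count += 1; // counted dynamically
    }
    sum / count as f32 // store
}

fn main() {
    assert_eq!(avg_pool_window(&[1.0, 2.0, 3.0, 6.0]), 3.0);
}
```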
diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
index 5f7a429ba5..7494ba0ebf 100644
--- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs
@@ -35,9 +35,9 @@ fn avg_pool2d_backward_kernel(
     let channel_lines = output.shape(3) / line_size;
     let channel = (ABSOLUTE_POS % channel_lines) * output.line_size();
     let pos = ABSOLUTE_POS / channel_lines;
-    let iw = pos % output.shape(2);
+    let iw = pos as u32 % output.shape(2) as u32;
     let pos = pos / output.shape(2);
-    let ih = pos % output.shape(1);
+    let ih = pos as u32 % output.shape(1) as u32;
     let batch = pos / output.shape(1);

     let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0));
@@ -45,8 +45,8 @@ fn avg_pool2d_backward_kernel(
     let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
         ih as i32,
         iw as i32,
-        grad.shape(1),
-        grad.shape(2),
+        grad.shape(1) as u32,
+        grad.shape(2) as u32,
         args,
         kernel_size_0,
         kernel_size_1,
@@ -60,8 +60,8 @@ fn avg_pool2d_backward_kernel(
     let kernel_size_1 = comptime![kernel_size_1 as u32];

     let index_base = batch * grad.stride(0) + channel * grad.stride(3);
-    let border_bottom = output.shape(1) + padding_0;
-    let border_right = output.shape(2) + padding_1;
+    let border_bottom = output.shape(1) as u32 + padding_0;
+    let border_right = output.shape(2) as u32 + padding_1;
     let begin_h = ih + padding_0;
     let begin_w = iw + padding_1;
@@ -72,7 +72,8 @@ fn avg_pool2d_backward_kernel(

         if begin_h >= ih_start && ih < ih_end {
             for ow in ow_start..ow_end {
-                let index = index_base + oh * grad.stride(1) + ow * grad.stride(2);
+                let index =
+                    index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);

                 let iw_start = ow * stride_1;
                 let iw_end = Min::min(iw_start + kernel_size_1, border_right);

diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
index 3ceaf66869..6f54b906d3 100644
--- a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs
@@ -33,7 +33,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {

     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         Line::empty(line_size).fill(N::min_value())
     }
@@ -41,7 +41,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {
     fn accumulate(
         #[comptime] _config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        _index: u32,
+        _index: LineSize,
         result: Line<N>,
     ) {
         *accumulator = Max::max(*accumulator, result);
@@ -57,7 +57,7 @@ impl Pool2dDirectStrategy for MaxPoolStrategy {

     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         _output_indices: &mut (),
         accumulator: Self::Accumulator,
@@ -74,7 +74,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {

     fn initialize(
         #[comptime] _config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator {
         let val = Line::empty(line_size).fill(N::min_value());
         let idx = Line::empty(line_size).fill(0i32);
@@ -84,7 +84,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {
     fn accumulate(
         #[comptime] _config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        index: u32,
+        index: usize,
         result: Line<N>,
     ) {
         let indices = Line::cast_from(index);
@@ -102,7 +102,7 @@ impl Pool2dDirectStrategy for MaxPoolWithIndicesStrategy {

     fn store(
         #[comptime] _config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         output_indices: &mut Tensor<Line<I>>,
         accumulator: Self::Accumulator,

diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
index a639894bb8..a6f272d59d 100644
--- a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
+++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs
@@ -38,8 +38,8 @@ fn max_pool2d_with_indices_backward_kernel(
     let (oh_start, oh_end, ow_start, ow_end) = loop_ranges(
         ih as i32,
         iw as i32,
-        grad.shape(1),
-        grad.shape(2),
+        grad.shape(1) as u32,
+        grad.shape(2) as u32,
         args,
         kernel_size_0,
         kernel_size_1,
@@ -51,7 +51,7 @@ fn max_pool2d_with_indices_backward_kernel(

     for oh in oh_start..oh_end {
         for ow in ow_start..ow_end {
-            let index = index_base + oh * grad.stride(1) + ow * grad.stride(2);
+            let index = index_base + oh as usize * grad.stride(1) + ow as usize * grad.stride(2);

             let index_max = Line::<I>::cast_from(indices[index / line_size]);
             grad_acc += select_many(

diff --git a/crates/burn-cubecl/src/kernel/pool/pool2d.rs b/crates/burn-cubecl/src/kernel/pool/pool2d.rs
index f342b15ba1..a25b3f7cd1 100644
--- a/crates/burn-cubecl/src/kernel/pool/pool2d.rs
+++ b/crates/burn-cubecl/src/kernel/pool/pool2d.rs
@@ -16,13 +16,13 @@ pub(crate) trait Pool2dDirectStrategy: Send + Sync + 'static {

     fn initialize(
         #[comptime] config: &Self::Config,
-        #[comptime] line_size: u32,
+        #[comptime] line_size: LineSize,
     ) -> Self::Accumulator;

     fn accumulate(
         #[comptime] config: &Self::Config,
         accumulator: &mut Self::Accumulator,
-        index: u32,
+        index: usize,
         result: Line<N>,
     );
@@ -38,7 +38,7 @@ pub(crate) trait Pool2dDirectStrategy: Send + Sync + 'static {

     fn store(
         #[comptime] config: &Self::Config,
-        position: u32,
+        position: usize,
         output: &mut Tensor<Line<N>>,
         output_indices: &mut Self::Indices,
         accumulator: Self::Accumulator,
@@ -77,13 +77,13 @@ pub fn pool2d_direct(
         input.stride(2),
         input.stride(3),
     );
-    let (in_h, in_w) = (input.shape(1), input.shape(2));
+    let (in_h, in_w) = (input.shape(1) as u32, input.shape(2) as u32);
     let c = (ABSOLUTE_POS % channel_lines) * input.line_size();

     let pos = ABSOLUTE_POS / channel_lines;
-    let ow = pos % out_w;
+    let ow = pos as u32 % out_w as u32;
     let pos = pos / out_w;
-    let oh = pos % out_h;
+    let oh = pos as u32 % out_h as u32;
     let b = pos / out_h;

     let mut accumulator = S::Pool2d::<N>::initialize(config, input.line_size());
@@ -110,15 +110,15 @@ pub fn pool2d_direct(
             let ih_pad = ih - args.padding_0;
             let iw_pad = iw - args.padding_1;

-            let in_h_off = ih_pad * in_stride_h;
-            let in_w_off = iw_pad * in_stride_w;
+            let in_h_off = ih_pad as usize * in_stride_h;
+            let in_w_off = iw_pad as usize * in_stride_w;

             let index_input = in_b_off + in_c_off + in_h_off + in_w_off;

             S::Pool2d::<N>::accumulate(
                 config,
                 &mut accumulator,
-                ih_pad * in_w + iw_pad,
+                ih_pad as usize * in_w as usize + iw_pad as usize,
                 input[index_input / input.line_size()],
             );
         }
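`pool2d_direct` above unpacks the flat `ABSOLUTE_POS` into `(batch, oh, ow, channel-line)` for an NHWC layout where the channel axis is processed `line_size` elements at a time. A plain-Rust sketch of that decomposition, names illustrative:

```rust
// Flat position -> (b, oh, ow, c) for an NHWC tensor pooled line-by-line.
fn decompose_pool_pos(
    pos: usize, out_h: usize, out_w: usize, channels: usize, line_size: usize,
) -> (usize, usize, usize, usize) {
    let channel_lines = channels / line_size;
    let c = (pos % channel_lines) * line_size; // first channel of this line
    let pos = pos / channel_lines;
    let ow = pos % out_w;
    let pos = pos / out_w;
    let oh = pos % out_h;
    let b = pos / out_h;
    (b, oh, ow, c)
}

fn main() {
    // 2x2 output, 8 channels in lines of 4: position 13 -> b=1, oh=1, ow=0, c=4.
    assert_eq!(decompose_pool_pos(13, 2, 2, 8, 4), (1, 1, 0, 4));
}
```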
diff --git a/crates/burn-cubecl/src/kernel/utils.rs b/crates/burn-cubecl/src/kernel/utils.rs
index 20a5770bf7..67fa3c85f1 100644
--- a/crates/burn-cubecl/src/kernel/utils.rs
+++ b/crates/burn-cubecl/src/kernel/utils.rs
@@ -1,5 +1,6 @@
 use burn_backend::Shape;
 use cubecl::{
+    ir::LineSize,
     prelude::ArrayArg,
     std::{
         FastDivmod, FastDivmodArgs,
@@ -10,17 +11,19 @@ use cubecl::{prelude::SequenceArg, std::tensor::layout::linear::LinearLayout};

 use crate::{CubeRuntime, tensor::CubeTensor};

-pub fn shape_divmod<'a, R: CubeRuntime>(tensor: &CubeTensor<R>) -> SequenceArg<'a, R, FastDivmod> {
+pub fn shape_divmod<'a, R: CubeRuntime>(
+    tensor: &CubeTensor<R>,
+) -> SequenceArg<'a, R, FastDivmod<usize>> {
     let mut arg = SequenceArg::new();

     for dim in tensor.shape.iter() {
-        arg.push(FastDivmodArgs::new(&tensor.client, *dim as u32));
+        arg.push(FastDivmodArgs::<R, usize>::new(&tensor.client, *dim));
     }

     arg
 }

 pub fn linear_layout<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearLayoutArgs<'a, R> {
     LinearLayoutArgs::from_shape_strides(&tensor.client, &tensor.shape, &tensor.strides, line_size)
 }
@@ -28,7 +31,7 @@ pub fn linear_layout<'a, R: CubeRuntime>(
 pub fn linear_layout_ref<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
     reference: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearLayoutArgs<'a, R> {
     LinearLayoutArgs::from_shape_strides_with_reference(
         &tensor.client,
@@ -41,7 +44,7 @@ pub fn linear_layout_ref<'a, R: CubeRuntime>(

 pub fn linear_view<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearViewLaunch<'a, R> {
     let len = tensor.shape.iter().product::<usize>();
     let layout = linear_layout(tensor, line_size);
@@ -54,7 +57,7 @@ pub fn linear_view<'a, R: CubeRuntime>(
 pub fn linear_view_ref<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
     reference: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
 ) -> LinearViewLaunch<'a, R> {
     let len = tensor.shape.iter().product::<usize>();
     let layout = linear_layout_ref(tensor, reference, line_size);
@@ -66,7 +69,7 @@ pub fn linear_view_ref<'a, R: CubeRuntime>(

 pub fn linear_view_alias<'a, R: CubeRuntime>(
     tensor: &'a CubeTensor<R>,
-    line_size: u8,
+    line_size: LineSize,
     pos: usize,
 ) -> LinearViewLaunch<'a, R> {
     let layout = linear_layout(tensor, line_size);

diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs
index e5da9bdabb..ebddccba86 100644
--- a/crates/burn-cubecl/src/ops/base.rs
+++ b/crates/burn-cubecl/src/ops/base.rs
@@ -5,7 +5,7 @@ use burn_backend::{
 };
 use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape};
 use burn_std::tensor::{ReshapeAction, contiguous_strides, reshape_action};
-use cubecl::server::CopyDescriptor;
+use cubecl::{ir::LineSize, server::CopyDescriptor};
 use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel};

 pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
@@ -333,7 +333,7 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub
     tensor
 }

-pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> u8 {
+pub(crate) fn max_line_size<R: CubeRuntime>(tensor: &CubeTensor<R>) -> LineSize {
     tensor_line_size_parallel(
         tensor
             .client
@@ -344,7 +344,10 @@ pub(crate) fn max_line_size(tensor: &CubeTensor) -> u8 {
     )
 }

-pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], axis: usize) -> u8 {
+pub(crate) fn max_line_size_many<R: CubeRuntime>(
+    tensors: &[&CubeTensor<R>],
+    axis: usize,
+) -> LineSize {
     let vec = tensors
         .iter()
         .map(|tensor| {

diff --git a/crates/burn-cubecl/src/ops/bool_tensor.rs b/crates/burn-cubecl/src/ops/bool_tensor.rs
index 8132d9f429..22d39a2b0d 100644
--- a/crates/burn-cubecl/src/ops/bool_tensor.rs
+++ b/crates/burn-cubecl/src/ops/bool_tensor.rs
@@ -9,7 +9,7 @@ use burn_backend::{
     tensor::{BoolTensor, Device, FloatTensor, IntTensor},
 };
 use burn_backend::{Shape, TensorData, tensor::BoolElem};
-use cubecl::std::scalar::InputScalar;
+use cubecl::prelude::InputScalar;
 use std::ops::Range;

 use super::{expand, numeric, permute, unfold};

diff --git a/crates/burn-cubecl/src/ops/int_tensor.rs b/crates/burn-cubecl/src/ops/int_tensor.rs
index 5ad13d86f7..d73790d754 100644
--- a/crates/burn-cubecl/src/ops/int_tensor.rs
+++ b/crates/burn-cubecl/src/ops/int_tensor.rs
@@ -20,8 +20,8 @@ use burn_backend::ExecutionError;
 use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
 use burn_backend::{DType, IntDType, Slice, ops::IntTensorOps};
 use burn_backend::{Distribution, ElementConversion, Shape, TensorData};
+use cubecl::frontend::Numeric;
 use cubecl::prelude::*;
-use cubecl::{frontend::Numeric, std::scalar::InputScalar};
 use cubek::reduce::components::instructions::ReduceOperationConfig;
 use std::ops::Range;

diff --git a/crates/burn-cubecl/src/ops/numeric.rs b/crates/burn-cubecl/src/ops/numeric.rs
index 69889ae8cc..da4acc7641 100644
--- a/crates/burn-cubecl/src/ops/numeric.rs
+++ b/crates/burn-cubecl/src/ops/numeric.rs
@@ -11,7 +11,7 @@ use crate::{
     ops::max_line_size,
 };
 use burn_backend::{DType, Shape};
-use cubecl::std::{FastDivmod, scalar::InputScalar, tensor::layout::linear::LinearView};
+use cubecl::std::{FastDivmod, tensor::layout::linear::LinearView};
 use cubecl::{calculate_cube_count_elemwise, prelude::*};
 use cubecl::{client::ComputeClient, server::Allocation};

@@ -343,8 +343,8 @@ impl CumulativeOp for MinOp {
 fn cumulative_kernel(
     input: &Tensor,
     output: &mut LinearView,
-    shape: Sequence<FastDivmod>,
-    #[comptime] dim: u32,
+    shape: Sequence<FastDivmod<usize>>,
+    #[comptime] dim: usize,
     #[define(C)] _dtype: StorageType,
 ) {
     if !output.is_in_bounds(ABSOLUTE_POS) {
@@ -428,7 +428,7 @@ fn cumulative_op(
         input.as_tensor_arg(1),
         linear_view(&output, 1),
         shape_divmod(&input),
-        dim as u32,
+        dim,
         output.dtype.into(),
     )
     .expect("Kernel to never fail");

diff --git a/crates/burn-cubecl/src/ops/tensor.rs b/crates/burn-cubecl/src/ops/tensor.rs
index e41c571093..8d8921b448 100644
--- a/crates/burn-cubecl/src/ops/tensor.rs
+++ b/crates/burn-cubecl/src/ops/tensor.rs
@@ -16,7 +16,6 @@ use burn_backend::{Backend, ExecutionError};
 use burn_backend::{DType, ElementConversion, FloatDType, Slice};
 use burn_backend::{Distribution, Shape, TensorData, ops::FloatTensorOps};
 use cubecl::prelude::*;
-use cubecl::std::scalar::InputScalar;
 use cubek::reduce::components::instructions::ReduceOperationConfig;
 use std::ops::Range;

diff --git a/crates/burn-cubecl/src/template/base.rs b/crates/burn-cubecl/src/template/base.rs
index 68e61c9f8e..bedb44c643 100644
--- a/crates/burn-cubecl/src/template/base.rs
+++ b/crates/burn-cubecl/src/template/base.rs
@@ -23,6 +23,7 @@ impl CubeTask for SourceKernel {
         _compiler: &mut C,
         _options: &C::CompilationOptions,
         _mode: ExecutionMode,
+        _address_type: StorageType,
     ) -> Result<CompiledKernel<C>, CompilationError> {
         let source_template = self.kernel_source.source();
         let source = source_template.complete();
@@ -42,6 +43,10 @@ impl KernelMetadata for SourceKernel {
     fn id(&self) -> KernelId {
         self.kernel_source.id()
     }
+
+    fn address_type(&self) -> StorageType {
+        u32::as_type_native_unchecked()
+    }
 }

 /// Generates kernel source code by replacing some information using templating.
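`max_line_size` and the `supported_line_sizes` filter seen earlier implement the same policy: choose the widest vectorization width that divides the contiguous dimension and keeps every stride aligned. A plain-Rust sketch of that selection under those assumptions; the real implementation lives in cubecl's `tensor_line_size_parallel`:

```rust
// Pick the widest supported line size compatible with the tensor layout.
fn pick_line_size(supported: &[usize], last_dim: usize, strides: &[usize]) -> usize {
    supported
        .iter()
        .copied()
        .filter(|&w| {
            // every stride except the unit one must stay aligned to the width
            last_dim % w == 0 && strides.iter().all(|&s| s == 1 || s % w == 0)
        })
        .max()
        .unwrap_or(1)
}

fn main() {
    // Contiguous [2, 3, 8] tensor (strides [24, 8, 1]): width 8 is admissible.
    assert_eq!(pick_line_size(&[1, 2, 4, 8], 8, &[24, 8, 1]), 8);
    // A last dim of 6 rules out widths 4 and 8.
    assert_eq!(pick_line_size(&[1, 2, 4, 8], 6, &[18, 6, 1]), 2);
}
```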
diff --git a/crates/burn-cubecl/src/tensor/base.rs b/crates/burn-cubecl/src/tensor/base.rs
index 8562e2eff3..dec4d239d8 100644
--- a/crates/burn-cubecl/src/tensor/base.rs
+++ b/crates/burn-cubecl/src/tensor/base.rs
@@ -206,7 +206,7 @@ where
     }

     /// Return the reference to a tensor argument.
-    pub fn as_tensor_arg<'a>(&'a self, line_size: u8) -> TensorArg<'a, R> {
+    pub fn as_tensor_arg<'a>(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
         let size = self.dtype.size();

         let handle: TensorHandleRef<'a, R> = self.as_handle_ref();
@@ -222,12 +222,12 @@ where
     }

     /// Return the reference to an array argument.
-    pub fn as_array_arg(&self, vectorisation: u8) -> ArrayArg<'_, R> {
+    pub fn as_array_arg(&self, line_size: LineSize) -> ArrayArg<'_, R> {
         unsafe {
             ArrayArg::from_raw_parts::<E>(
                 &self.handle,
                 self.handle.size() as usize / core::mem::size_of::<E>(),
-                vectorisation,
+                line_size,
             )
         }
     }

diff --git a/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs
index fe9876c23d..9aec65c85b 100644
--- a/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs
+++ b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs
@@ -17,18 +17,22 @@ use cubecl::{features::Plane, prelude::*};

 use super::prefix_sum::prefix_sum;

-const BLOCK_H: u32 = 4;
+const BLOCK_H: usize = 4;

 #[cube]
 fn merge(labels: &Tensor<Atomic<I>>, label_1: u32, label_2: u32) {
     let mut label_1 = label_1;
     let mut label_2 = label_2;

-    while label_1 != label_2 && (label_1 != u32::cast_from(Atomic::load(&labels[label_1])) - 1) {
-        label_1 = u32::cast_from(Atomic::load(&labels[label_1])) - 1;
+    while label_1 != label_2
+        && (label_1 != u32::cast_from(Atomic::load(&labels[label_1 as usize])) - 1)
+    {
+        label_1 = u32::cast_from(Atomic::load(&labels[label_1 as usize])) - 1;
     }

-    while label_1 != label_2 && (label_2 != u32::cast_from(Atomic::load(&labels[label_2])) - 1) {
-        label_2 = u32::cast_from(Atomic::load(&labels[label_2])) - 1;
+    while label_1 != label_2
+        && (label_2 != u32::cast_from(Atomic::load(&labels[label_2 as usize])) - 1)
+    {
+        label_2 = u32::cast_from(Atomic::load(&labels[label_2 as usize])) - 1;
     }

     while label_1 != label_2 {
         #[allow(clippy::manual_swap)]
@@ -37,7 +41,10 @@ fn merge(labels: &Tensor<Atomic<I>>, label_1: u32, label_2: u32) {
             label_1 = label_2;
             label_2 = tmp;
         }
-        let label_3 = u32::cast_from(Atomic::min(&labels[label_1], I::cast_from(label_2 + 1))) - 1;
+        let label_3 = u32::cast_from(Atomic::min(
+            &labels[label_1 as usize],
+            I::cast_from(label_2 + 1),
+        )) - 1;
         if label_1 == label_3 {
             label_1 = label_2;
         } else {
@@ -60,7 +67,7 @@ fn end_distance(pixels: u32, tx: u32) -> u32 {
 #[allow(unconditional_panic, reason = "clippy thinks PLANE_DIM is always 2")]
 fn ballot_dyn(y: u32, pred: bool) -> u32 {
     let index = y % (PLANE_DIM / 32);
-    plane_ballot(pred)[index]
+    plane_ballot(pred)[index as usize]
 }

 #[cube(launch_unchecked)]
 fn strip_labeling(
@@ -72,23 +79,23 @@ fn strip_labeling(
     let mut shared_pixels = SharedMemory::<u32>::new(BLOCK_H);

     let y = ABSOLUTE_POS_Y;
-    let rows = labels.shape(0);
-    let cols = labels.shape(1);
+    let rows = labels.shape(0) as u32;
+    let cols = labels.shape(1) as u32;

     if y >= rows {
         terminate!();
     }

-    let img_stride = img.stride(0);
-    let labels_stride = labels.stride(0);
+    let img_stride = img.stride(0) as u32;
+    let labels_stride = labels.stride(0) as u32;

     let img_line_base = y * img_stride + UNIT_POS_X;
-    let labels_line_base = y * labels.stride(0) + UNIT_POS_X;
+    let labels_line_base = y * labels_stride + UNIT_POS_X;

-    let mut distance_y = 0;
+    let mut distance_y = 0u32;
     let mut distance_y_1 = 0;

-    for i in range_stepped(0, img.shape(1), PLANE_DIM) {
+    for i in range_stepped(0, img.shape(1) as u32, PLANE_DIM) {
         let x = UNIT_POS_X + i;

         if x < cols {
@@ -101,14 +108,14 @@ fn strip_labeling(
             let img_index = img_line_base + i;
             let labels_index = labels_line_base + i;

-            let p_y = bool::cast_from(img[img_index]);
+            let p_y = bool::cast_from(img[img_index as usize]);

             let pixels_y = ballot_dyn(UNIT_POS_Y, p_y) & mask;
             let mut s_dist_y = start_distance(pixels_y, UNIT_POS_X);

             if p_y && s_dist_y == 0 {
                 Atomic::store(
-                    &labels[labels_index],
+                    &labels[labels_index as usize],
                     I::cast_from(labels_index - select(UNIT_POS_X == 0, distance_y, 0) + 1),
                 );
             }
@@ -117,7 +124,7 @@ fn strip_labeling(
             sync_cube();

             if UNIT_POS_X == 0 {
-                shared_pixels[UNIT_POS_Y] = pixels_y;
+                shared_pixels[UNIT_POS_Y as usize] = pixels_y;
             }

             sync_cube();
@@ -125,7 +132,7 @@ fn strip_labeling(
             // Requires if and not select, because `select` may execute the then branch even if the
             // condition is false (on non-CUDA backends), which can lead to OOB reads.
             let pixels_y_1 = if UNIT_POS_Y > 0 {
-                shared_pixels[UNIT_POS_Y - 1]
+                shared_pixels[(UNIT_POS_Y - 1) as usize]
             } else {
                 0u32.runtime()
             };
@@ -198,14 +205,14 @@ fn strip_merge(
     #[comptime] connectivity: Connectivity,
 ) {
     let plane_start_x = CUBE_POS_X * (CUBE_DIM_X * CUBE_DIM_Z - PLANE_DIM) + UNIT_POS_Z * PLANE_DIM;
-    let y = (CUBE_POS_Y + 1) * BLOCK_H;
+    let y = (CUBE_POS_Y + 1) * BLOCK_H as u32;
     let x = plane_start_x + UNIT_POS_X;
-    let img_step = img.stride(0);
-    let labels_step = labels.stride(0);
-    let cols = img.shape(1);
+    let img_step = img.stride(0) as u32;
+    let labels_step = labels.stride(0) as u32;
+    let cols = img.shape(1) as u32;

-    if y < labels.shape(0) && x < labels.shape(1) {
+    if y < labels.shape(0) as u32 && x < labels.shape(1) as u32 {
         let mut mask = 0xffffffffu32;
         if cols - plane_start_x < 32 {
             mask >>= 32 - (cols - plane_start_x);
@@ -217,8 +224,8 @@ fn strip_merge(
         let img_index_up = img_index - img_step;
         let labels_index_up = labels_index - labels_step;

-        let p = bool::cast_from(img[img_index]);
-        let p_up = bool::cast_from(img[img_index_up]);
+        let p = bool::cast_from(img[img_index as usize]);
+        let p_up = bool::cast_from(img[img_index_up as usize]);

         let pixels = ballot_dyn(UNIT_POS_Z, p) & mask;
         let pixels_up = ballot_dyn(UNIT_POS_Z, p_up) & mask;
@@ -234,27 +241,27 @@ fn strip_merge(
             }
         }
         Connectivity::Eight => {
-            let mut last_dist_vec = SharedMemory::<u32>::new(32);
-            let mut last_dist_up_vec = SharedMemory::<u32>::new(32);
+            let mut last_dist_vec = SharedMemory::<u32>::new(32usize);
+            let mut last_dist_up_vec = SharedMemory::<u32>::new(32usize);

             let s_dist = start_distance(pixels, UNIT_POS_X);
             let s_dist_up = start_distance(pixels_up, UNIT_POS_X);

             if UNIT_POS_PLANE == PLANE_DIM - 1 {
-                last_dist_vec[UNIT_POS_Z] = start_distance(pixels, 32);
-                last_dist_up_vec[UNIT_POS_Z] = start_distance(pixels_up, 32);
+                last_dist_vec[UNIT_POS_Z as usize] = start_distance(pixels, 32);
+                last_dist_up_vec[UNIT_POS_Z as usize] = start_distance(pixels_up, 32);
             }

             sync_cube();

             if CUBE_POS_X == 0 || UNIT_POS_Z > 0 {
                 let last_dist = if UNIT_POS_Z > 0 {
-                    last_dist_vec[UNIT_POS_Z - 1]
+                    last_dist_vec[(UNIT_POS_Z - 1) as usize]
                 } else {
                     0u32.runtime()
                 };

                 let last_dist_up = if UNIT_POS_Z > 0 {
-                    last_dist_up_vec[UNIT_POS_Z - 1]
+                    last_dist_up_vec[(UNIT_POS_Z - 1) as usize]
                 } else {
                     0u32.runtime()
                 };
@@ -300,10 +307,10 @@ fn relabeling(img: &Tensor, labels: &mut Tensor(img: &Tensor, labels: &mut Tensor(
     let y = ABSOLUTE_POS_Y;
     let x = ABSOLUTE_POS_X;
-    let cols = labels.shape(1);
-    let rows = labels.shape(0);
-    let img_step = img.stride(0);
-    let labels_step = labels.stride(0);
+    let cols = labels.shape(1) as u32;
+    let rows = labels.shape(0) as u32;
+    let img_step = img.stride(0) as u32;
+    let labels_step = labels.stride(0) as u32;

     if x < cols && y < rows {
         let mut mask = 0xffffffffu32;
@@ -363,7 +370,7 @@ fn analysis(
         let img_index = y * img_step + x;
         let labels_index = y * labels_step + x;

-        let p = bool::cast_from(img[img_index]);
+        let p = bool::cast_from(img[img_index as usize]);
         let pixels = ballot_dyn(UNIT_POS_Y, p) & mask;
         let s_dist = start_distance(pixels, UNIT_POS_X);
         let count = end_distance(pixels, UNIT_POS_X);
@@ -372,29 +379,29 @@ fn analysis(
         let mut label = 0u32;

         if p && s_dist == 0 {
-            label = u32::cast_from(labels[labels_index]) - 1;
-            while label != u32::cast_from(labels[label]) - 1 {
-                label = u32::cast_from(labels[label]) - 1;
+            label = u32::cast_from(labels[labels_index as usize]) - 1;
+            while label != u32::cast_from(labels[label as usize]) - 1 {
+                label = u32::cast_from(labels[label as usize]) - 1;
             }
             label += 1;

-            Atomic::add(&area[label], I::cast_from(count));
+            Atomic::add(&area[label as usize], I::cast_from(count));
             if opts.bounds_enabled {
-                Atomic::min(&left[label], I::cast_from(x));
-                Atomic::min(&top[label], I::cast_from(y));
-                Atomic::max(&right[label], I::cast_from(max_x));
-                Atomic::max(&bottom[label], I::cast_from(y));
+                Atomic::min(&left[label as usize], I::cast_from(x));
+                Atomic::min(&top[label as usize], I::cast_from(y));
+                Atomic::max(&right[label as usize], I::cast_from(max_x));
+                Atomic::max(&bottom[label as usize], I::cast_from(y));
             }
             if comptime!(opts.max_label_enabled || opts.compact_labels) {
                 Atomic::max(&max_label[0], I::cast_from(label));
             }
         }

-        label = plane_broadcast(label, UNIT_POS_X - s_dist);
+        label = plane_shuffle(label, UNIT_POS_X - s_dist);

         if p {
-            labels[labels_index] = I::cast_from(label);
+            labels[labels_index as usize] = I::cast_from(label);
         }
     }
 }
@@ -408,16 +415,16 @@ fn compact_labels(
     let x = ABSOLUTE_POS_X;
     let y = ABSOLUTE_POS_Y;

-    let labels_pos = y * labels.stride(0) + x;
+    let labels_pos = y * labels.stride(0) as u32 + x;

-    if labels_pos >= labels.len() {
+    if labels_pos as usize >= labels.len() {
         terminate!();
     }

-    let label = u32::cast_from(labels[labels_pos]);
+    let label = u32::cast_from(labels[labels_pos as usize]);

     if label != 0 {
-        let new_label = remap[label];
-        labels[labels_pos] = new_label;
+        let new_label = remap[label as usize];
+        labels[labels_pos as usize] = new_label;
         Atomic::max(&max_label[0], new_label);
     }
 }
@@ -437,23 +444,23 @@ fn compact_stats(
     remap: &Tensor,
 ) {
     let label = ABSOLUTE_POS_X;
-    if label >= remap.len() {
+    if label as usize >= remap.len() {
         terminate!();
     }

-    let area = area[label];
+    let area = area[label as usize];
     if area == I::new(0) {
         terminate!();
     }

-    let new_label = u32::cast_from(remap[label]);
+    let new_label = u32::cast_from(remap[label as usize]);

-    area_new[new_label] = area;
+    area_new[new_label as usize] = area;

     // This should be gated but there's a problem with the Eq bound only being implemented for tuples
     // up to 12 elems, so I can't pass the opts. It's not unsafe, but potentially unnecessary work.
-    top_new[new_label] = top[label];
-    left_new[new_label] = left[label];
-    right_new[new_label] = right[label];
-    bottom_new[new_label] = bottom[label];
+    top_new[new_label as usize] = top[label as usize];
+    left_new[new_label as usize] = left[label as usize];
+    right_new[new_label as usize] = right[label as usize];
+    bottom_new[new_label as usize] = bottom[label as usize];
 }

 #[allow(clippy::type_complexity)]
@@ -502,7 +509,7 @@ pub fn hardware_accelerated(
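The label-chasing loops in `merge` and `analysis` above are an implicit union-find: entry `l` of the labels tensor stores `root + 1` (0 is reserved for background), and resolution follows parent pointers until a fixpoint. A plain-Rust sketch of the single-threaded core of that idea; the kernels additionally use atomics (`Atomic::load`/`Atomic::min`) because many planes race on the same chains:

```rust
// Follow parent pointers until reaching a root, where labels[l] == l + 1.
fn resolve(labels: &[u32], mut label: u32) -> u32 {
    while label != labels[label as usize] - 1 {
        label = labels[label as usize] - 1;
    }
    label
}

fn main() {
    // Chain 3 -> 1 -> 0, where 0 is a root (labels[0] == 0 + 1).
    let labels = vec![1, 1, 3, 2];
    assert_eq!(resolve(&labels, 3), 0);
}
```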
diff --git a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
--- a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
+++ b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs
@@ -22,11 +22,11 @@ fn prefix_sum_kernel(
     scan_out: &mut Tensor<Line<I>>,
     scan_bump: &Tensor<Atomic<I>>,
     reduction: &Tensor<Atomic<I>>,
-    cube_count_x: u32,
+    cube_count_x: usize,
 ) {
-    let mut broadcast = SharedMemory::<I>::new(1);
+    let mut broadcast = SharedMemory::<I>::new(1usize);
     let mut reduce = SharedMemory::<I>::new(MAX_REDUCE_SIZE);
-    let batch = CUBE_POS_Z;
+    let batch = CUBE_POS_Z as usize;
     let line_spt = comptime!(PART_SIZE / CUBE_SIZE / scan_in.line_size());
     let nums_per_cube = CUBE_SIZE * line_spt;
     let v_last = comptime!(scan_in.line_size() - 1);
@@ -36,11 +36,11 @@ fn prefix_sum_kernel(
         broadcast[0] = Atomic::add(&scan_bump[batch], I::new(1));
     }
     sync_cube();
-    let part_id = u32::cast_from(broadcast[0]);
+    let part_id = usize::cast_from(broadcast[0]);

     let plane_id = UNIT_POS_X / PLANE_DIM;
     let dev_offs = part_id * nums_per_cube;
-    let plane_offs = plane_id * PLANE_DIM * line_spt;
+    let plane_offs = UNIT_POS_X as usize * line_spt;

     // Exit if full plane is out of bounds
     if dev_offs + plane_offs >= scan_in.shape(1) {
@@ -56,9 +56,9 @@ fn prefix_sum_kernel(
     let red_offs = batch * reduction.stride(0);
     let scan_offs = batch * scan_in.stride(0);

-    let mut t_scan = Array::<Line<I>>::vectorized(line_spt, scan_in.line_size());
+    let mut t_scan = Array::<Line<I>>::lined(line_spt, scan_in.line_size());
     {
-        let mut i = dev_offs + plane_offs + UNIT_POS_PLANE;
+        let mut i = dev_offs + plane_offs + UNIT_POS_PLANE as usize;

         if part_id < cube_count_x - 1 {
             for k in 0..line_spt {
@@ -70,7 +70,7 @@ fn prefix_sum_kernel(
                     scan[v] += prev;
                 }
                 t_scan[k] = scan;
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }

@@ -87,7 +87,7 @@ fn prefix_sum_kernel(
                     }
                     t_scan[k] = scan;
                 }
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }

@@ -95,13 +95,13 @@ fn prefix_sum_kernel(
         let plane_mask = PLANE_DIM - 1;
         let circular_shift = (UNIT_POS_PLANE + plane_mask) & plane_mask;
         for k in 0..line_spt {
-            let t = plane_broadcast(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
+            let t = plane_shuffle(plane_inclusive_sum(t_scan[k][v_last]), circular_shift);
             t_scan[k] += Line::cast_from(select(UNIT_POS_PLANE != 0, t, zero) + prev);
-            prev += plane_broadcast(t, 0);
+            prev += plane_broadcast(t, 0u32);
         }

         if UNIT_POS_PLANE == 0 {
-            reduce[plane_id] = prev;
+            reduce[plane_id as usize] = prev;
         }
     }
     sync_cube();
@@ -118,9 +118,9 @@ fn prefix_sum_kernel(
         while j <= aligned_size {
             let i_0 = ((UNIT_POS_X + offset_0) << offset_1) - offset_0;
             let pred_0 = i_0 < spine_size;
-            let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0], zero));
+            let t_0 = plane_inclusive_sum(select(pred_0, reduce[i_0 as usize], zero));
             if pred_0 {
-                reduce[i_0] = t_0;
+                reduce[i_0 as usize] = t_0;
             }
             sync_cube();

@@ -129,9 +129,13 @@ fn prefix_sum_kernel(
                 let i_1 = UNIT_POS_X + rshift;
                 if (i_1 & (j - 1)) >= rshift {
                     let pred_1 = i_1 < spine_size;
-                    let t_1 = select(pred_1, reduce[((i_1 >> offset_1) << offset_1) - 1], zero);
+                    let t_1 = select(
+                        pred_1,
+                        reduce[(((i_1 >> offset_1) << offset_1) - 1) as usize],
+                        zero,
+                    );
                     if pred_1 && ((i_1 + 1) & (rshift - 1)) != 0 {
-                        reduce[i_1] += t_1;
+                        reduce[i_1 as usize] += t_1;
                     }
                 }
             } else {
@@ -148,7 +152,7 @@ fn prefix_sum_kernel(
         if UNIT_POS_X == 0 {
             Atomic::store(
                 &reduction[part_id + red_offs],
-                (reduce[spine_size - 1] << I::new(2))
+                (reduce[(spine_size - 1) as usize] << I::new(2))
                     | select(part_id != 0, flag_reduction, flag_inclusive),
             )
         }
@@ -164,7 +168,8 @@ fn prefix_sum_kernel(
                     prev_reduction += flag_payload >> I::new(2);
                     Atomic::store(
                         &reduction[part_id + red_offs],
-                        ((prev_reduction + reduce[spine_size - 1]) << I::new(2)) | flag_inclusive,
+                        ((prev_reduction + reduce[(spine_size - 1) as usize]) << I::new(2))
+                            | flag_inclusive,
                     );
                     broadcast[0] = prev_reduction;
                     break;
@@ -181,19 +186,19 @@ fn prefix_sum_kernel(

     {
         let prev = if plane_id != 0 {
-            reduce[plane_id - 1]
+            reduce[(plane_id - 1) as usize]
         } else {
             zero
         };
         let prev = Line::cast_from(broadcast[0] + prev);
-        let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt;
+        let s_offset = UNIT_POS_PLANE + plane_id * PLANE_DIM * line_spt as u32;
         let dev_offset = part_id * nums_per_cube;
-        let mut i = s_offset + dev_offset;
+        let mut i = s_offset as usize + dev_offset;

         if part_id < cube_count_x - 1 {
             for k in 0..line_spt {
                 scan_out[i + scan_offs] = t_scan[k] + prev;
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }

@@ -202,7 +207,7 @@ fn prefix_sum_kernel(
                 if i < scan_out.shape(1) {
                     scan_out[i + scan_offs] = t_scan[k] + prev;
                 }
-                i += PLANE_DIM;
+                i += PLANE_DIM as usize;
             }
         }
     }
@@ -217,27 +222,27 @@ fn count_trailing_zeros(num: u32) -> u32 {
 pub fn prefix_sum(input: CubeTensor<R>) -> CubeTensor<R> {
     let client = input.client.clone();
     let device = input.device.clone();
-    let num_elems = input.shape.num_elements() as u32;
-    let numbers = *input.shape.last().unwrap() as u32;
+    let num_elems = input.shape.num_elements();
+    let numbers = *input.shape.last().unwrap();
     let batches = num_elems / numbers;

-    let input = reshape(input, Shape::new([batches as usize, numbers as usize]));
+    let input = reshape(input, Shape::new([batches, numbers]));
     let out = empty_device::<R, I>(client.clone(), device.clone(), input.shape.clone());

     let cubes = numbers.div_ceil(PART_SIZE);
-    let cube_dim = CubeDim::new_1d(CUBE_SIZE);
-    let cube_count = CubeCount::new_3d(cubes, 1, batches);
+    let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32);
+    let cube_count = CubeCount::new_3d(cubes as u32, 1, batches as u32);

     let bump = zeros_client::<R>(
         client.clone(),
         device.clone(),
-        Shape::new([batches as usize]),
+        Shape::new([batches]),
         I::dtype(),
     );
     let reduction = zeros_client::<R>(
         client.clone(),
         device.clone(),
-        Shape::new([batches as usize, cubes as usize]),
+        Shape::new([batches, cubes]),
         I::dtype(),
     );

diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs
index 08d4ded4d7..872feb5df7 100644
--- a/examples/custom-cubecl-kernel/src/kernel.rs
+++ b/examples/custom-cubecl-kernel/src/kernel.rs
@@ -8,9 +8,9 @@ pub fn fused_matmul_add_relu_kernel(
     bias: &Tensor<F>,
     output: &mut Tensor<F>,
 ) {
-    let row = ABSOLUTE_POS_X;
-    let col = ABSOLUTE_POS_Y;
-    let batch = ABSOLUTE_POS_Z;
+    let row = ABSOLUTE_POS_X as usize;
+    let col = ABSOLUTE_POS_Y as usize;
+    let batch = ABSOLUTE_POS_Z as usize;

     let n_rows = output.shape(output.rank() - 2);
     let n_cols = output.shape(output.rank() - 1);