diff --git a/CMakeLists.txt b/CMakeLists.txt
index 61006e83cbf..396cba89e7a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -557,7 +557,7 @@ target_link_libraries(codegen_internal PUBLIC
 )
 if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
   target_link_libraries(codegen_internal PUBLIC nvf_cutlass)
-  target_compile_definitions(codegen_internal PRIVATE "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
+  target_compile_definitions(codegen_internal PUBLIC "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
 endif()
 
 target_link_libraries(codegen_internal PUBLIC LLVM_JIT)
diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index b3ec5e39c3a..27e56883e2e 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
 std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
-#if NVFUSER_CUTLASS_KERNEL_ENABLED
   const auto& mat1 = inputs[0].as<at::Tensor>();
   const auto& mat2 = inputs[1].as<at::Tensor>();
   const auto& scale1 = inputs[2].as<at::Tensor>();
@@ -1722,6 +1721,31 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
   const auto& problem_sizes = inputs[5].as<at::Tensor>();
   const auto& expert_offsets = inputs[6].as<at::Tensor>();
   const auto& sf_offsets = inputs[7].as<at::Tensor>();
+
+  // Meta-device fast path, outside of the torch version guard.
+  if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
+      scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
+      expert_offsets.is_meta() || sf_offsets.is_meta()) {
+    // For nvfp4_scaled_grouped_mm, the output shape is [M, N],
+    // where M = mat1.size(0) and N = mat2.size(2).
+    // Note: CutlassNvfp4GroupedMmaOp expects mat2 to be [G, K/2, N] (packed)
+    // at runtime and transposes it before calling into CUTLASS.
+    std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(2)};
+
+    at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
+    auto options =
+        mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
+    at::Tensor result = at::empty(result_sizes, options);
+
+    if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
+        rfactor_did_idx != -1) {
+      result = result.unsqueeze(rfactor_did_idx);
+    }
+
+    return {result};
+  }
+
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
   NVF_CHECK(
       mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
       mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
diff --git a/tests/cpp/test_meta.cpp b/tests/cpp/test_meta.cpp
index f0e2b4bcb3b..613ce97a988 100644
--- a/tests/cpp/test_meta.cpp
+++ b/tests/cpp/test_meta.cpp
@@ -15,6 +15,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -533,4 +534,204 @@ TEST_F(MetaTest, Matmul1D) {
   EXPECT_EQ(meta_out.strides(), real_out.strides());
 }
 
+// Test CutlassNvfp4GroupedMmaOp with meta device
+TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
+  // NVFP4 CUTLASS grouped MMA relies on SM100+ kernels (Blackwell) and TMA.
+  // On older GPUs the CUTLASS kernel may compile but fail at runtime when
+  // initializing TMA descriptors (e.g., status 801).
+  if (!deviceMajorMinorCheck(10)) {
+    GTEST_SKIP() << "CutlassNvfp4GroupedMma requires SM100+ (compute "
+                    "capability >= 10.0)";
+  }
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  // Choose an example where M, N, K, and K/2 are all different and the
+  // CUTLASS alignment constraints are satisfied.
+  // Larger shapes also tend to be more robust with respect to TMA/tile
+  // constraints:
+  //   num_experts = 4
+  //   m_per_expert = 64 => M = 256
+  //   N = 96
+  //   K = 320, K/2 = 160
+  constexpr int64_t G = 4;
+  constexpr int64_t M_PER_EXPERT = 64;
+  constexpr int64_t M = G * M_PER_EXPERT;
+  constexpr int64_t N = 96;
+  constexpr int64_t K = 320;
+  constexpr int64_t K_DIV_2 = K / 2; // 160
+  constexpr int64_t K_DIV_16 = K / 16; // 20
+
+  // Shapes:
+  //   mat1:   [M, K] = [256, 320] (logical/unpacked shape)
+  //   mat2:   [G, K, N] = [4, 320, 96] (logical/unpacked shape)
+  //   output: [M, N] = [256, 96]
+  // Note: The packed dtype Float4_e2m1fn_x2 is not allowed in the IR. We use
+  // the unpacked dtype (Float4_e2m1fn) and the logical K dimension. When
+  // binding a packed ATen tensor (K/2), the last dim is adjusted (K/2 -> K).
+  auto mat1 = makeContigConcreteTensor({M, K}, DataType::Float4_e2m1fn);
+  // mat2 is expected as [G, K, N] logically. When binding packed FP4 inputs
+  // (Float4_e2m1fn_x2) with shape [G, K/2, N], the evaluator adjusts the K dim
+  // (K/2 -> K). We also set a stride order so the adjustment applies to the K
+  // dimension even though it isn't the last logical dimension.
+  auto mat2 = TensorViewBuilder()
+                  .shape({G, K, N})
+                  .dtype(DataType::Float4_e2m1fn)
+                  .contiguity({true, true, true})
+                  .strideOrder({2, 0, 1})
+                  .build();
+  // Block-scaling factors have last dim K / 16.
+  auto scale1 =
+      makeContigConcreteTensor({M, K_DIV_16}, DataType::Float8_e4m3fn);
+  auto scale2 =
+      makeContigConcreteTensor({G, N, K_DIV_16}, DataType::Float8_e4m3fn);
+  auto alpha = makeContigConcreteTensor({G}, DataType::Float);
+  auto problem_sizes = makeContigConcreteTensor({G, 3}, DataType::Index);
+  auto expert_offsets = makeContigConcreteTensor({G}, DataType::Index);
+  auto sf_offsets = makeContigConcreteTensor({G}, DataType::Index);
+
+  fusion->addInput(mat1);
+  fusion->addInput(mat2);
+  fusion->addInput(scale1);
+  fusion->addInput(scale2);
+  fusion->addInput(alpha);
+  fusion->addInput(problem_sizes);
+  fusion->addInput(expert_offsets);
+  fusion->addInput(sf_offsets);
+
+  auto result = cutlass_nvfp4_grouped_mm(
+      mat1,
+      mat2,
+      scale1,
+      scale2,
+      alpha,
+      problem_sizes,
+      expert_offsets,
+      sf_offsets,
+      DataType::BFloat16);
+  fusion->addOutput(result);
+
+  // Create real inputs with the appropriate data types.
+  auto options_uint8 =
+      at::TensorOptions().dtype(at::ScalarType::Byte).device(at::kCUDA, 0);
+  auto options_fp4 =
+      at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
+  auto options_fp8 =
+      at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
+  auto options_fp32 =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);
+
+  // FP4 tensors must be created as UInt8 and viewed as Float4.
+  at::Tensor mat1_uint8 = at::randint(0, 256, {M, K_DIV_2}, options_uint8);
+  at::Tensor mat1_input = mat1_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);
+
+  // IMPORTANT: CutlassNvfp4GroupedMmaOp::evaluate transposes mat2 before
+  // calling into CUTLASS, and the transposed result must be contiguous.
+  // Construct mat2 as a transpose-view of a contiguous [G, N, K/2] tensor so
+  // that this internal transpose produces a contiguous tensor.
+  at::Tensor mat2_base_uint8 =
+      at::randint(0, 256, {G, N, K_DIV_2}, options_uint8);
+  at::Tensor mat2_base =
+      mat2_base_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);
+  at::Tensor mat2_input = mat2_base.transpose(-1, -2); // [G, K/2, N]
+  // FP8 tensors can be created from FP32 tensors.
+  at::Tensor scale1_input =
+      at::randn({M, K_DIV_16}, options_fp32).to(at::kFloat8_e4m3fn);
+  at::Tensor scale2_input =
+      at::randn({G, N, K_DIV_16}, options_fp32).to(at::kFloat8_e4m3fn);
+  at::Tensor alpha_input = at::ones({G}, options_fp32);
+  // problem_sizes uses unpacked dimensions per expert: (m_i, n, k).
+  const std::vector<int64_t> problem_sizes_vec{
+      // expert 0
+      M_PER_EXPERT,
+      N,
+      K,
+      // expert 1
+      M_PER_EXPERT,
+      N,
+      K,
+      // expert 2
+      M_PER_EXPERT,
+      N,
+      K,
+      // expert 3
+      M_PER_EXPERT,
+      N,
+      K,
+  };
+  at::Tensor problem_sizes_input =
+      at::tensor(problem_sizes_vec, options_int).reshape({G, 3});
+  const std::vector<int64_t> expert_offsets_vec{
+      0, M_PER_EXPERT, 2 * M_PER_EXPERT, 3 * M_PER_EXPERT};
+  at::Tensor expert_offsets_input = at::tensor(expert_offsets_vec, options_int);
+  const std::vector<int64_t> sf_offsets_vec{
+      0, M_PER_EXPERT, 2 * M_PER_EXPERT, 3 * M_PER_EXPERT};
+  at::Tensor sf_offsets_input = at::tensor(sf_offsets_vec, options_int);
+
+  // CUDA path
+  ExpressionEvaluator ee_cuda;
+  ee_cuda.bind(fusion->inputs().at(0), mat1_input);
+  ee_cuda.bind(fusion->inputs().at(1), mat2_input);
+  ee_cuda.bind(fusion->inputs().at(2), scale1_input);
+  ee_cuda.bind(fusion->inputs().at(3), scale2_input);
+  ee_cuda.bind(fusion->inputs().at(4), alpha_input);
+  ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
+  ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
+  ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
+  auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();
+
+  // Meta evaluation
+  ExpressionEvaluator ee_meta;
+  auto meta_mat1 = at::empty_strided(
+      mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
+  auto meta_mat2 = at::empty_strided(
+      mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
+  auto meta_scale1 = at::empty_strided(
+      scale1_input.sizes(),
+      scale1_input.strides(),
+      options_fp8.device(at::kMeta));
+  auto meta_scale2 = at::empty_strided(
+      scale2_input.sizes(),
+      scale2_input.strides(),
+      options_fp8.device(at::kMeta));
+  auto meta_alpha = at::empty_strided(
+      alpha_input.sizes(),
+      alpha_input.strides(),
+      options_fp32.device(at::kMeta));
+  auto meta_problem_sizes = at::empty_strided(
+      problem_sizes_input.sizes(),
+      problem_sizes_input.strides(),
+      options_int.device(at::kMeta));
+  auto meta_expert_offsets = at::empty_strided(
+      expert_offsets_input.sizes(),
+      expert_offsets_input.strides(),
+      options_int.device(at::kMeta));
+  auto meta_sf_offsets = at::empty_strided(
+      sf_offsets_input.sizes(),
+      sf_offsets_input.strides(),
+      options_int.device(at::kMeta));
+
+  ee_meta.bind(fusion->inputs().at(0), meta_mat1);
+  ee_meta.bind(fusion->inputs().at(1), meta_mat2);
+  ee_meta.bind(fusion->inputs().at(2), meta_scale1);
+  ee_meta.bind(fusion->inputs().at(3), meta_scale2);
+  ee_meta.bind(fusion->inputs().at(4), meta_alpha);
+  ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
+  ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
+  ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
+  auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();
+
+  // Checks
+  EXPECT_TRUE(meta_out.is_meta());
+  EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
+  EXPECT_EQ(meta_out.sizes(), real_out.sizes());
+  EXPECT_EQ(meta_out.strides(), real_out.strides());
+#else
+  GTEST_SKIP() << "Test requires CUTLASS support";
+#endif
+}
+
 } // namespace nvfuser
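For reference, a minimal sketch of how the grouped-MM bookkeeping inputs used in the test above (problem_sizes, expert_offsets, sf_offsets) could be derived from per-expert row counts instead of being written out by hand. GroupedMmProblem and makeGroupedMmProblem are hypothetical names, not an existing nvFuser or ATen API, and the sketch assumes one scale-factor row per activation row, which is why sf_offsets coincides with expert_offsets, as it does in the test:

#include <cstdint>
#include <vector>

#include <ATen/ATen.h>

// Hypothetical helper (illustration only): build the per-expert metadata
// consumed by the grouped MM from a list of per-expert row counts.
struct GroupedMmProblem {
  at::Tensor problem_sizes; // [G, 3], one (m_i, N, K) row per expert
  at::Tensor expert_offsets; // [G], prefix sum of m_i (starting row per expert)
  at::Tensor sf_offsets; // [G], assumed equal to expert_offsets (see above)
};

inline GroupedMmProblem makeGroupedMmProblem(
    const std::vector<int64_t>& m_per_expert,
    int64_t N,
    int64_t K,
    const at::TensorOptions& int_options) {
  const auto G = static_cast<int64_t>(m_per_expert.size());
  std::vector<int64_t> sizes;
  std::vector<int64_t> offsets;
  sizes.reserve(3 * G);
  offsets.reserve(G);
  int64_t running_m = 0;
  for (int64_t g = 0; g < G; ++g) {
    // problem_sizes stores unpacked (logical) dimensions, as in the test.
    sizes.insert(sizes.end(), {m_per_expert[g], N, K});
    offsets.push_back(running_m);
    running_m += m_per_expert[g];
  }
  at::Tensor problem_sizes = at::tensor(sizes, int_options).reshape({G, 3});
  at::Tensor expert_offsets = at::tensor(offsets, int_options);
  // Assumption: one scale-factor row per activation row, so sf_offsets
  // is the same prefix sum as expert_offsets.
  at::Tensor sf_offsets = expert_offsets.clone();
  return {problem_sizes, expert_offsets, sf_offsets};
}

With m_per_expert = {64, 64, 64, 64}, N = 96, and K = 320, this reproduces the problem_sizes_input, expert_offsets_input, and sf_offsets_input constructed inline in the test.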