Add device type meta support for CutlassNvfp4GroupedMmaOp::evaluate (#5695)
Changes to CutlassNvfp4GroupedMmaOp::evaluate:

@@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
 std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
-#if NVFUSER_CUTLASS_KERNEL_ENABLED
   const auto& mat1 = inputs[0].as<at::Tensor>();
   const auto& mat2 = inputs[1].as<at::Tensor>();
   const auto& scale1 = inputs[2].as<at::Tensor>();
@@ -1722,6 +1721,29 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
   const auto& problem_sizes = inputs[5].as<at::Tensor>();
   const auto& expert_offsets = inputs[6].as<at::Tensor>();
   const auto& sf_offsets = inputs[7].as<at::Tensor>();
+
+  // Meta-device fast path outside of torch version guard
+  if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
+      scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
+      expert_offsets.is_meta() || sf_offsets.is_meta()) {
+    // For nvfp4_scaled_grouped_mm, the output shape is [M, N]
+    // where M = mat1.size(0) and N = mat2.size(1)
+    std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(1)};
+
+    at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
+    auto options =
+        mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
+    at::Tensor result = at::empty(result_sizes, options);
+
+    if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
+        rfactor_did_idx != -1) {
+      result = result.unsqueeze(rfactor_did_idx);
+    }
+
+    return {result};
+  }
+
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
   NVF_CHECK(
       mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
       mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
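The fast path needs only metadata: a meta tensor records dtype, sizes, and strides but allocates no storage, so the output shape can be derived without a GPU or a CUTLASS build. The standalone ATen sketch below illustrates the same shape rule; it is an illustration for this writeup, not code from the PR, and `at::kByte` is a placeholder standing in for the packed FP4 dtype:

```cpp
#include <ATen/ATen.h>

int main() {
  // Meta tensors carry shape/stride/dtype metadata but no storage,
  // so this runs on a machine with no GPU at all.
  at::Tensor mat1 =
      at::empty({128, 64}, at::device(at::kMeta).dtype(at::kByte));
  at::Tensor mat2 =
      at::empty({4, 128, 64}, at::device(at::kMeta).dtype(at::kByte));

  // Same shape rule as the fast path above:
  // [M, N] = [mat1.size(0), mat2.size(1)].
  at::Tensor out = at::empty(
      {mat1.size(0), mat2.size(1)},
      at::device(at::kMeta).dtype(at::kBFloat16));

  TORCH_CHECK(out.is_meta());
  TORCH_CHECK(out.size(0) == 128 && out.size(1) == 128);
  return 0;
}
```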
New MetaTest covering the meta-device path:

@@ -533,4 +533,127 @@ TEST_F(MetaTest, Matmul1D) {
   EXPECT_EQ(meta_out.strides(), real_out.strides());
 }
 
+// Test CutlassNvfp4GroupedMmaOp with meta device
+TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  // mat1: [M, K/2] = [128, 64] (packed FP4)
+  // mat2: [G, N, K/2] = [4, 128, 64] (packed FP4)
+  // output: [M, N] = [128, 128]
+  auto mat1 = makeContigConcreteTensor({128, 64}, DataType::Float4_e2m1fn_x2);
+  auto mat2 =
+      makeContigConcreteTensor({4, 128, 64}, DataType::Float4_e2m1fn_x2);
+  auto scale1 = makeContigConcreteTensor({128, 8}, DataType::Float8_e4m3fn);
+  auto scale2 = makeContigConcreteTensor({4, 128, 8}, DataType::Float8_e4m3fn);
+  auto alpha = makeContigConcreteTensor({4}, DataType::Float);
+  auto problem_sizes = makeContigConcreteTensor({4, 3}, DataType::Index);
+  auto expert_offsets = makeContigConcreteTensor({4}, DataType::Index);
+  auto sf_offsets = makeContigConcreteTensor({4}, DataType::Index);
+
+  fusion->addInput(mat1);
+  fusion->addInput(mat2);
+  fusion->addInput(scale1);
+  fusion->addInput(scale2);
+  fusion->addInput(alpha);
+  fusion->addInput(problem_sizes);
+  fusion->addInput(expert_offsets);
+  fusion->addInput(sf_offsets);
+
+  auto result = cutlass_nvfp4_grouped_mm(
+      mat1,
+      mat2,
+      scale1,
+      scale2,
+      alpha,
+      problem_sizes,
+      expert_offsets,
+      sf_offsets,
+      DataType::BFloat16);
+  fusion->addOutput(result);
+
+  // Create real inputs with appropriate data types
+  auto options_fp4 =
+      at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
+  auto options_fp8 =
+      at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
+  auto options_fp32 =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);
+
+  at::Tensor mat1_input = at::randn({128, 64}, options_fp4);
+  at::Tensor mat2_input = at::randn({4, 128, 64}, options_fp4);
+  at::Tensor scale1_input = at::randn({128, 8}, options_fp8);
+  at::Tensor scale2_input = at::randn({4, 128, 8}, options_fp8);
+  at::Tensor alpha_input = at::ones({4}, options_fp32);
+  at::Tensor problem_sizes_input = at::tensor(
+      {{32, 128, 128}, {32, 128, 128}, {32, 128, 128}, {32, 128, 128}},
+      options_int);
+  at::Tensor expert_offsets_input = at::tensor({0, 32, 64, 96}, options_int);
+  at::Tensor sf_offsets_input = at::tensor({0, 32, 64, 96}, options_int);
+
+  // CUDA path
+  ExpressionEvaluator ee_cuda;
+  ee_cuda.bind(fusion->inputs().at(0), mat1_input);
+  ee_cuda.bind(fusion->inputs().at(1), mat2_input);
+  ee_cuda.bind(fusion->inputs().at(2), scale1_input);
+  ee_cuda.bind(fusion->inputs().at(3), scale2_input);
+  ee_cuda.bind(fusion->inputs().at(4), alpha_input);
+  ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
+  ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
+  ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
+  auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();
+
+  // Meta evaluation
+  ExpressionEvaluator ee_meta;
+  auto meta_mat1 = at::empty_strided(
+      mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
+  auto meta_mat2 = at::empty_strided(
+      mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
+  auto meta_scale1 = at::empty_strided(
+      scale1_input.sizes(),
+      scale1_input.strides(),
+      options_fp8.device(at::kMeta));
+  auto meta_scale2 = at::empty_strided(
+      scale2_input.sizes(),
+      scale2_input.strides(),
+      options_fp8.device(at::kMeta));
+  auto meta_alpha = at::empty_strided(
+      alpha_input.sizes(),
+      alpha_input.strides(),
+      options_fp32.device(at::kMeta));
+  auto meta_problem_sizes = at::empty_strided(
+      problem_sizes_input.sizes(),
+      problem_sizes_input.strides(),
+      options_int.device(at::kMeta));
+  auto meta_expert_offsets = at::empty_strided(
+      expert_offsets_input.sizes(),
+      expert_offsets_input.strides(),
+      options_int.device(at::kMeta));
+  auto meta_sf_offsets = at::empty_strided(
+      sf_offsets_input.sizes(),
+      sf_offsets_input.strides(),
+      options_int.device(at::kMeta));
+
+  ee_meta.bind(fusion->inputs().at(0), meta_mat1);
+  ee_meta.bind(fusion->inputs().at(1), meta_mat2);
+  ee_meta.bind(fusion->inputs().at(2), meta_scale1);
+  ee_meta.bind(fusion->inputs().at(3), meta_scale2);
+  ee_meta.bind(fusion->inputs().at(4), meta_alpha);
+  ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
+  ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
+  ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
+  auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();
+
+  // Checks
+  EXPECT_TRUE(meta_out.is_meta());
+  EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
+  EXPECT_EQ(meta_out.sizes(), real_out.sizes());
+  EXPECT_EQ(meta_out.strides(), real_out.strides());
+#else
+  GTEST_SKIP() << "Test requires CUTLASS support";
+#endif
+}
+
 } // namespace nvfuser
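The test mirrors each real input onto the meta device with `at::empty_strided`, preserving sizes, strides, and dtype. A small helper could fold away the eight near-identical calls; the sketch below is a hypothetical refactoring, not part of this PR, and the name `to_meta` is mine:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper: build a meta tensor with the same sizes,
// strides, and dtype as `src`, allocating no storage.
at::Tensor to_meta(const at::Tensor& src) {
  return at::empty_strided(
      src.sizes(), src.strides(), src.options().device(at::kMeta));
}
```

With such a helper, each binding in the test would reduce to a single line, e.g. `ee_meta.bind(fusion->inputs().at(0), to_meta(mat1_input));`.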
Review comment on the output-shape computation: I thought the output size N is mat2.size(2), e.g. if you look at line 1767 below.