Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion csrc/ir/composite_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
const auto& mat1 = inputs[0].as<at::Tensor>();
const auto& mat2 = inputs[1].as<at::Tensor>();
const auto& scale1 = inputs[2].as<at::Tensor>();
Expand All @@ -1722,6 +1721,29 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const auto& problem_sizes = inputs[5].as<at::Tensor>();
const auto& expert_offsets = inputs[6].as<at::Tensor>();
const auto& sf_offsets = inputs[7].as<at::Tensor>();

// Meta-device fast path outside of torch version guard
if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
expert_offsets.is_meta() || sf_offsets.is_meta()) {
// For nvfp4_scaled_grouped_mm, the output shape is [M, N]
// where M = mat1.size(0) and N = mat2.size(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the output size n is mat2.size(2).
e.g. if you look at line 1767 below.

std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(1)};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing dimension validation in meta path - mat1 expected to be 2D and mat2 expected to be 3D, but shape calculation proceeds without checks (unlike GroupedMmaOp::evaluate which validates dimensions)


at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
auto options =
mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
at::Tensor result = at::empty(result_sizes, options);

if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
rfactor_did_idx != -1) {
result = result.unsqueeze(rfactor_did_idx);
}

return {result};
}

#if NVFUSER_CUTLASS_KERNEL_ENABLED
NVF_CHECK(
mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
Expand Down
123 changes: 123 additions & 0 deletions tests/cpp/test_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,4 +533,127 @@ TEST_F(MetaTest, Matmul1D) {
EXPECT_EQ(meta_out.strides(), real_out.strides());
}

// Test CutlassNvfp4GroupedMmaOp with meta device: the meta fast path must
// produce an output whose dtype, sizes, and strides match what the real
// CUDA kernel produces, without touching device memory.
TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // mat1: [M, K/2] = [128, 64] (packed FP4)
  // mat2: [G, N, K/2] = [4, 128, 64] (packed FP4)
  // output: [M, N] = [128, 128]
  // NOTE(review): with these shapes the unpacked K equals N (both 128), so a
  // mixed-up size index in the meta path (mat2.size(1) vs mat2.size(2))
  // would go undetected — consider distinct K and N for better coverage.
  auto mat1 = makeContigConcreteTensor({128, 64}, DataType::Float4_e2m1fn_x2);
  auto mat2 =
      makeContigConcreteTensor({4, 128, 64}, DataType::Float4_e2m1fn_x2);
  auto scale1 = makeContigConcreteTensor({128, 8}, DataType::Float8_e4m3fn);
  auto scale2 = makeContigConcreteTensor({4, 128, 8}, DataType::Float8_e4m3fn);
  auto alpha = makeContigConcreteTensor({4}, DataType::Float);
  auto problem_sizes = makeContigConcreteTensor({4, 3}, DataType::Index);
  auto expert_offsets = makeContigConcreteTensor({4}, DataType::Index);
  auto sf_offsets = makeContigConcreteTensor({4}, DataType::Index);

  fusion->addInput(mat1);
  fusion->addInput(mat2);
  fusion->addInput(scale1);
  fusion->addInput(scale2);
  fusion->addInput(alpha);
  fusion->addInput(problem_sizes);
  fusion->addInput(expert_offsets);
  fusion->addInput(sf_offsets);

  auto result = cutlass_nvfp4_grouped_mm(
      mat1,
      mat2,
      scale1,
      scale2,
      alpha,
      problem_sizes,
      expert_offsets,
      sf_offsets,
      DataType::BFloat16);
  fusion->addOutput(result);

  // Create real inputs with appropriate data types
  auto options_fp4 =
      at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
  auto options_fp8 =
      at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
  auto options_fp32 =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);

  at::Tensor mat1_input = at::randn({128, 64}, options_fp4);
  at::Tensor mat2_input = at::randn({4, 128, 64}, options_fp4);
  at::Tensor scale1_input = at::randn({128, 8}, options_fp8);
  at::Tensor scale2_input = at::randn({4, 128, 8}, options_fp8);
  at::Tensor alpha_input = at::ones({4}, options_fp32);
  // Per-group (m, n, k) problem sizes: four experts of 32 rows each.
  at::Tensor problem_sizes_input = at::tensor(
      {{32, 128, 128}, {32, 128, 128}, {32, 128, 128}, {32, 128, 128}},
      options_int);
  at::Tensor expert_offsets_input = at::tensor({0, 32, 64, 96}, options_int);
  at::Tensor sf_offsets_input = at::tensor({0, 32, 64, 96}, options_int);

  // CUDA path: run the real kernel to obtain the reference output metadata.
  ExpressionEvaluator ee_cuda;
  ee_cuda.bind(fusion->inputs().at(0), mat1_input);
  ee_cuda.bind(fusion->inputs().at(1), mat2_input);
  ee_cuda.bind(fusion->inputs().at(2), scale1_input);
  ee_cuda.bind(fusion->inputs().at(3), scale2_input);
  ee_cuda.bind(fusion->inputs().at(4), alpha_input);
  ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
  ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
  ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
  auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

  // Meta evaluation: bind meta-device tensors mirroring the real inputs'
  // sizes and strides, so only shape propagation is exercised.
  ExpressionEvaluator ee_meta;
  auto meta_mat1 = at::empty_strided(
      mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
  auto meta_mat2 = at::empty_strided(
      mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
  auto meta_scale1 = at::empty_strided(
      scale1_input.sizes(),
      scale1_input.strides(),
      options_fp8.device(at::kMeta));
  auto meta_scale2 = at::empty_strided(
      scale2_input.sizes(),
      scale2_input.strides(),
      options_fp8.device(at::kMeta));
  auto meta_alpha = at::empty_strided(
      alpha_input.sizes(),
      alpha_input.strides(),
      options_fp32.device(at::kMeta));
  auto meta_problem_sizes = at::empty_strided(
      problem_sizes_input.sizes(),
      problem_sizes_input.strides(),
      options_int.device(at::kMeta));
  auto meta_expert_offsets = at::empty_strided(
      expert_offsets_input.sizes(),
      expert_offsets_input.strides(),
      options_int.device(at::kMeta));
  auto meta_sf_offsets = at::empty_strided(
      sf_offsets_input.sizes(),
      sf_offsets_input.strides(),
      options_int.device(at::kMeta));

  ee_meta.bind(fusion->inputs().at(0), meta_mat1);
  ee_meta.bind(fusion->inputs().at(1), meta_mat2);
  ee_meta.bind(fusion->inputs().at(2), meta_scale1);
  ee_meta.bind(fusion->inputs().at(3), meta_scale2);
  ee_meta.bind(fusion->inputs().at(4), meta_alpha);
  ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
  ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
  ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
  auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

  // Checks: meta output must mirror the real output's metadata exactly.
  EXPECT_TRUE(meta_out.is_meta());
  EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
  EXPECT_EQ(meta_out.sizes(), real_out.sizes());
  EXPECT_EQ(meta_out.strides(), real_out.strides());
#else
  GTEST_SKIP() << "Test requires CUTLASS support";
#endif
}

} // namespace nvfuser