Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ target_link_libraries(codegen_internal PUBLIC
)
if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
target_link_libraries(codegen_internal PUBLIC nvf_cutlass)
target_compile_definitions(codegen_internal PRIVATE "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
target_compile_definitions(codegen_internal PUBLIC "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
endif()

target_link_libraries(codegen_internal PUBLIC LLVM_JIT)
Expand Down
24 changes: 23 additions & 1 deletion csrc/ir/composite_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
const auto& mat1 = inputs[0].as<at::Tensor>();
const auto& mat2 = inputs[1].as<at::Tensor>();
const auto& scale1 = inputs[2].as<at::Tensor>();
Expand All @@ -1722,6 +1721,29 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const auto& problem_sizes = inputs[5].as<at::Tensor>();
const auto& expert_offsets = inputs[6].as<at::Tensor>();
const auto& sf_offsets = inputs[7].as<at::Tensor>();

// Meta-device fast path outside of torch version guard
if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
expert_offsets.is_meta() || sf_offsets.is_meta()) {
// For nvfp4_scaled_grouped_mm, the output shape is [M, N]
// where M = mat1.size(0) and N = mat2.size(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the output size N should be mat2.size(2), not mat2.size(1) — see, e.g., line 1767 below, where the non-meta path computes it that way.

std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(1)};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing dimension validation in the meta path: mat1 is expected to be 2-D and mat2 to be 3-D, but the shape calculation proceeds without any rank checks (unlike GroupedMmaOp::evaluate, which validates dimensions before computing the output shape).


at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
auto options =
mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
at::Tensor result = at::empty(result_sizes, options);

if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
rfactor_did_idx != -1) {
result = result.unsqueeze(rfactor_did_idx);
}

return {result};
}

#if NVFUSER_CUTLASS_KERNEL_ENABLED
NVF_CHECK(
mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
Expand Down
127 changes: 127 additions & 0 deletions tests/cpp/test_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,4 +533,131 @@ TEST_F(MetaTest, Matmul1D) {
EXPECT_EQ(meta_out.strides(), real_out.strides());
}

// Test CutlassNvfp4GroupedMmaOp with meta device: evaluating the op on
// meta-device inputs must produce a meta tensor whose dtype, sizes, and
// strides match what the real CUDA evaluation produces, without allocating
// or touching device memory.
TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  // Choose an example where all M, N, K, and K/2 are different so that a
  // transposed/mixed-up dimension in the meta shape computation cannot
  // accidentally produce the right answer:
  // M = 128, N = 80, K = 192, K/2 = 96
  // Shapes:
  // mat1: [M, K/2] = [128, 96] (packed FP4; two 4-bit values per byte)
  // mat2: [G, N, K/2] = [4, 80, 96] (packed FP4, G = 4 expert groups)
  // output: [M, N] = [128, 80]
  auto mat1 = makeContigConcreteTensor({128, 96}, DataType::Float4_e2m1fn_x2);
  auto mat2 =
      makeContigConcreteTensor({4, 80, 96}, DataType::Float4_e2m1fn_x2);
  // Block-scaling factors have last dim K / 16 = 192 / 16 = 12
  // (one FP8 scale per 16-element block along K).
  auto scale1 = makeContigConcreteTensor({128, 12}, DataType::Float8_e4m3fn);
  auto scale2 = makeContigConcreteTensor({4, 80, 12}, DataType::Float8_e4m3fn);
  auto alpha = makeContigConcreteTensor({4}, DataType::Float);
  auto problem_sizes = makeContigConcreteTensor({4, 3}, DataType::Index);
  auto expert_offsets = makeContigConcreteTensor({4}, DataType::Index);
  auto sf_offsets = makeContigConcreteTensor({4}, DataType::Index);

  // Register fusion inputs in the same positional order used for binding
  // below (index 0..7).
  fusion->addInput(mat1);
  fusion->addInput(mat2);
  fusion->addInput(scale1);
  fusion->addInput(scale2);
  fusion->addInput(alpha);
  fusion->addInput(problem_sizes);
  fusion->addInput(expert_offsets);
  fusion->addInput(sf_offsets);

  auto result = cutlass_nvfp4_grouped_mm(
      mat1,
      mat2,
      scale1,
      scale2,
      alpha,
      problem_sizes,
      expert_offsets,
      sf_offsets,
      DataType::BFloat16);
  fusion->addOutput(result);

  // Create real inputs with appropriate data types
  // NOTE(review): at::randn on packed-FP4 / FP8 tensor options is unusual —
  // confirm these factory calls actually succeed for kFloat4_e2m1fn_x2 and
  // kFloat8_e4m3fn rather than throwing; only sizes/strides matter here.
  auto options_fp4 =
      at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
  auto options_fp8 =
      at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
  auto options_fp32 =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);

  at::Tensor mat1_input = at::randn({128, 96}, options_fp4);
  at::Tensor mat2_input = at::randn({4, 80, 96}, options_fp4);
  at::Tensor scale1_input = at::randn({128, 12}, options_fp8);
  at::Tensor scale2_input = at::randn({4, 80, 12}, options_fp8);
  at::Tensor alpha_input = at::ones({4}, options_fp32);
  // Each of the 4 groups is an (m=32, n=80, k=192) problem; the per-group
  // m values sum to M = 128.
  at::Tensor problem_sizes_input = at::tensor(
      {32, 80, 192, 32, 80, 192, 32, 80, 192, 32, 80, 192},
      options_int).reshape({4, 3});
  at::Tensor expert_offsets_input = at::tensor({0, 32, 64, 96}, options_int);
  // NOTE(review): sf_offsets reuses the row offsets {0, 32, 64, 96} —
  // confirm scale-factor offsets have no extra alignment requirement.
  at::Tensor sf_offsets_input = at::tensor({0, 32, 64, 96}, options_int);

  // CUDA path: evaluate with real device tensors to obtain the reference
  // output metadata (dtype/sizes/strides).
  ExpressionEvaluator ee_cuda;
  ee_cuda.bind(fusion->inputs().at(0), mat1_input);
  ee_cuda.bind(fusion->inputs().at(1), mat2_input);
  ee_cuda.bind(fusion->inputs().at(2), scale1_input);
  ee_cuda.bind(fusion->inputs().at(3), scale2_input);
  ee_cuda.bind(fusion->inputs().at(4), alpha_input);
  ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
  ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
  ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
  auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

  // Meta evaluation: mirror each real input as a meta tensor with identical
  // sizes and strides (empty_strided preserves the exact layout).
  ExpressionEvaluator ee_meta;
  auto meta_mat1 = at::empty_strided(
      mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
  auto meta_mat2 = at::empty_strided(
      mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
  auto meta_scale1 = at::empty_strided(
      scale1_input.sizes(),
      scale1_input.strides(),
      options_fp8.device(at::kMeta));
  auto meta_scale2 = at::empty_strided(
      scale2_input.sizes(),
      scale2_input.strides(),
      options_fp8.device(at::kMeta));
  auto meta_alpha = at::empty_strided(
      alpha_input.sizes(),
      alpha_input.strides(),
      options_fp32.device(at::kMeta));
  auto meta_problem_sizes = at::empty_strided(
      problem_sizes_input.sizes(),
      problem_sizes_input.strides(),
      options_int.device(at::kMeta));
  auto meta_expert_offsets = at::empty_strided(
      expert_offsets_input.sizes(),
      expert_offsets_input.strides(),
      options_int.device(at::kMeta));
  auto meta_sf_offsets = at::empty_strided(
      sf_offsets_input.sizes(),
      sf_offsets_input.strides(),
      options_int.device(at::kMeta));

  ee_meta.bind(fusion->inputs().at(0), meta_mat1);
  ee_meta.bind(fusion->inputs().at(1), meta_mat2);
  ee_meta.bind(fusion->inputs().at(2), meta_scale1);
  ee_meta.bind(fusion->inputs().at(3), meta_scale2);
  ee_meta.bind(fusion->inputs().at(4), meta_alpha);
  ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
  ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
  ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
  auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

  // Checks: meta result must be on the meta device and agree with the real
  // CUDA result in dtype, shape, and strides.
  EXPECT_TRUE(meta_out.is_meta());
  EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
  EXPECT_EQ(meta_out.sizes(), real_out.sizes());
  EXPECT_EQ(meta_out.strides(), real_out.strides());
#else
  // Without CUTLASS support the op cannot be constructed, so the meta path
  // is untestable in this build configuration.
  GTEST_SKIP() << "Test requires CUTLASS support";
#endif
}

} // namespace nvfuser
Loading