Merged
Changes from 15 commits
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -557,7 +557,7 @@ target_link_libraries(codegen_internal PUBLIC
 )
 if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
   target_link_libraries(codegen_internal PUBLIC nvf_cutlass)
-  target_compile_definitions(codegen_internal PRIVATE "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
+  target_compile_definitions(codegen_internal PUBLIC "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
 endif()

 target_link_libraries(codegen_internal PUBLIC LLVM_JIT)
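A note on the change above: PRIVATE compile definitions apply only while building codegen_internal itself, whereas PUBLIC definitions also propagate to targets that link it. The new test in tests/cpp/test_meta.cpp guards on this macro, which is presumably what motivates the switch (assuming the test binary links codegen_internal). The consuming side of the definition looks like:

    #if NVFUSER_CUTLASS_KERNEL_ENABLED
    // CUTLASS-dependent code, compiled only when the definition propagates here
    #else
    GTEST_SKIP() << "Test requires CUTLASS support";
    #endif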
24 changes: 23 additions & 1 deletion csrc/ir/composite_nodes.cpp
@@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
 std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
-#if NVFUSER_CUTLASS_KERNEL_ENABLED
   const auto& mat1 = inputs[0].as<at::Tensor>();
   const auto& mat2 = inputs[1].as<at::Tensor>();
   const auto& scale1 = inputs[2].as<at::Tensor>();
@@ -1722,6 +1721,29 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
   const auto& problem_sizes = inputs[5].as<at::Tensor>();
   const auto& expert_offsets = inputs[6].as<at::Tensor>();
   const auto& sf_offsets = inputs[7].as<at::Tensor>();
+
+  // Meta-device fast path outside of torch version guard
+  if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
+      scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
+      expert_offsets.is_meta() || sf_offsets.is_meta()) {
+    // For nvfp4_scaled_grouped_mm, the output shape is [M, N]
+    // where M = mat1.size(0) and N = mat2.size(1)
Collaborator:
I thought the output size n is mat2.size(2), e.g. if you look at line 1767 below.
+    std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(1)};
Contributor:
Missing dimension validation in the meta path: mat1 is expected to be 2D and mat2 to be 3D, but the shape calculation proceeds without checks (unlike GroupedMmaOp::evaluate, which validates dimensions).


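A minimal sketch of the check this comment asks for, mirroring the NVF_CHECK pattern used later in this function (the exact messages are illustrative, not from the PR):

    NVF_CHECK(mat1.dim() == 2, "Expected mat1 to be 2D, got ", mat1.dim(), "D");
    NVF_CHECK(mat2.dim() == 3, "Expected mat2 to be 3D, got ", mat2.dim(), "D");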
+
+    at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
+    auto options =
+        mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
+    at::Tensor result = at::empty(result_sizes, options);
+
+    if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
+        rfactor_did_idx != -1) {
+      result = result.unsqueeze(rfactor_did_idx);
+    }
+
+    return {result};
+  }
+
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
   NVF_CHECK(
       mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
       mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
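For context, a standalone sketch (not part of the PR) of the meta-device behavior the fast path relies on: a tensor on c10::kMeta carries sizes, strides, and dtype but owns no storage, so output shapes can be derived without touching a GPU.

    #include <ATen/ATen.h>

    // Allocate a shape-only tensor on the meta device; no memory is backed.
    at::Tensor t = at::empty(
        {128, 80}, at::TensorOptions().device(at::kMeta).dtype(at::kBFloat16));
    // t.is_meta() == true, t.sizes() == [128, 80], no allocation happened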
142 changes: 142 additions & 0 deletions tests/cpp/test_meta.cpp
@@ -533,4 +533,146 @@ TEST_F(MetaTest, Matmul1D) {
EXPECT_EQ(meta_out.strides(), real_out.strides());
}

// Test CutlassNvfp4GroupedMmaOp with meta device
TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Choose an example where all M, N, K, and K/2 are different:
// M = 128, N = 80, K = 192, K/2 = 96
// Shapes:
// mat1: [M, K] = [128, 192] (unpacked dimensions for fusion)
// mat2: [G, N, K] = [4, 80, 192] (unpacked dimensions for fusion)
// output: [M, N] = [128, 80]
// Note: Use unpacked type Float4_e2m1fn with UNPACKED dimensions
auto mat1 = makeContigConcreteTensor({128, 192}, DataType::Float4_e2m1fn);
auto mat2 =
makeContigConcreteTensor({4, 80, 192}, DataType::Float4_e2m1fn);
// Block-scaling factors have last dim K / 16 = 192 / 16 = 12
auto scale1 = makeContigConcreteTensor({128, 12}, DataType::Float8_e4m3fn);
auto scale2 = makeContigConcreteTensor({4, 80, 12}, DataType::Float8_e4m3fn);
auto alpha = makeContigConcreteTensor({4}, DataType::Float);
auto problem_sizes = makeContigConcreteTensor({4, 3}, DataType::Index);
auto expert_offsets = makeContigConcreteTensor({4}, DataType::Index);
auto sf_offsets = makeContigConcreteTensor({4}, DataType::Index);

fusion->addInput(mat1);
fusion->addInput(mat2);
fusion->addInput(scale1);
fusion->addInput(scale2);
fusion->addInput(alpha);
fusion->addInput(problem_sizes);
fusion->addInput(expert_offsets);
fusion->addInput(sf_offsets);

auto result = cutlass_nvfp4_grouped_mm(
mat1,
mat2,
scale1,
scale2,
alpha,
problem_sizes,
expert_offsets,
sf_offsets,
DataType::BFloat16);
fusion->addOutput(result);

// Create real inputs with appropriate data types
auto options_uint8 =
at::TensorOptions().dtype(at::ScalarType::Byte).device(at::kCUDA, 0);
auto options_fp4 =
at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
auto options_fp8 =
at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
auto options_fp32 =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);

// FP4 tensors must be created as UInt8 and viewed as Float4
at::Tensor mat1_uint8 = at::randint(0, 256, {128, 96}, options_uint8);
at::Tensor mat1_input =
mat1_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);
NVF_CHECK(mat1_input.is_contiguous(), "mat1_input is not contiguous!");

at::Tensor mat2_uint8 = at::randint(0, 256, {4, 80, 96}, options_uint8);
at::Tensor mat2_input =
mat2_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);
NVF_CHECK(mat2_input.is_contiguous(), "mat2_input is not contiguous!");
// FP8 tensors can be created from FP32 tensors
at::Tensor scale1_input =
at::randn({128, 12}, options_fp32).to(at::kFloat8_e4m3fn);
at::Tensor scale2_input =
at::randn({4, 80, 12}, options_fp32).to(at::kFloat8_e4m3fn);
at::Tensor alpha_input = at::ones({4}, options_fp32);
// problem_sizes uses unpacked per-expert dimensions: M = 128/4 = 32, N = 80, K = 192
at::Tensor problem_sizes_input = at::tensor(
{32, 80, 192, 32, 80, 192, 32, 80, 192, 32, 80, 192},
options_int).reshape({4, 3});
at::Tensor expert_offsets_input = at::tensor({0, 32, 64, 96}, options_int);
at::Tensor sf_offsets_input = at::tensor({0, 32, 64, 96}, options_int);

// CUDA path
ExpressionEvaluator ee_cuda;
ee_cuda.bind(fusion->inputs().at(0), mat1_input);
ee_cuda.bind(fusion->inputs().at(1), mat2_input);
ee_cuda.bind(fusion->inputs().at(2), scale1_input);
ee_cuda.bind(fusion->inputs().at(3), scale2_input);
ee_cuda.bind(fusion->inputs().at(4), alpha_input);
ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

// Meta evaluation
ExpressionEvaluator ee_meta;
auto meta_mat1 = at::empty_strided(
mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
auto meta_mat2 = at::empty_strided(
mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
auto meta_scale1 = at::empty_strided(
scale1_input.sizes(),
scale1_input.strides(),
options_fp8.device(at::kMeta));
auto meta_scale2 = at::empty_strided(
scale2_input.sizes(),
scale2_input.strides(),
options_fp8.device(at::kMeta));
auto meta_alpha = at::empty_strided(
alpha_input.sizes(),
alpha_input.strides(),
options_fp32.device(at::kMeta));
auto meta_problem_sizes = at::empty_strided(
problem_sizes_input.sizes(),
problem_sizes_input.strides(),
options_int.device(at::kMeta));
auto meta_expert_offsets = at::empty_strided(
expert_offsets_input.sizes(),
expert_offsets_input.strides(),
options_int.device(at::kMeta));
auto meta_sf_offsets = at::empty_strided(
sf_offsets_input.sizes(),
sf_offsets_input.strides(),
options_int.device(at::kMeta));

ee_meta.bind(fusion->inputs().at(0), meta_mat1);
ee_meta.bind(fusion->inputs().at(1), meta_mat2);
ee_meta.bind(fusion->inputs().at(2), meta_scale1);
ee_meta.bind(fusion->inputs().at(3), meta_scale2);
ee_meta.bind(fusion->inputs().at(4), meta_alpha);
ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

// Checks
EXPECT_TRUE(meta_out.is_meta());
EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
EXPECT_EQ(meta_out.sizes(), real_out.sizes());
EXPECT_EQ(meta_out.strides(), real_out.strides());
#else
GTEST_SKIP() << "Test requires CUTLASS support";
#endif
}

} // namespace nvfuser