2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -557,7 +557,7 @@ target_link_libraries(codegen_internal PUBLIC
)
if (BUILD_CUTLASS AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
target_link_libraries(codegen_internal PUBLIC nvf_cutlass)
-target_compile_definitions(codegen_internal PRIVATE "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
+target_compile_definitions(codegen_internal PUBLIC "-DNVFUSER_CUTLASS_KERNEL_ENABLED")
endif()

target_link_libraries(codegen_internal PUBLIC LLVM_JIT)
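Why the switch to PUBLIC matters, as a minimal hedged sketch (the consumer file below is hypothetical): any target that links against codegen_internal and branches on the macro at compile time only sees it if the definition propagates to dependents.

// consumer.cpp (hypothetical) -- a translation unit in some target that
// links against codegen_internal, e.g. the test binary further below.
#include <iostream>

void reportCutlassSupport() {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
  // Visible only if the definition propagates, i.e. with PUBLIC.
  std::cout << "CUTLASS kernels enabled\n";
#else
  // With PRIVATE, consumers always fall into this branch, even when
  // codegen_internal itself was built with CUTLASS support.
  std::cout << "CUTLASS kernels disabled\n";
#endif
}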
26 changes: 25 additions & 1 deletion csrc/ir/composite_nodes.cpp
@@ -1713,7 +1713,6 @@ std::string CutlassNvfp4GroupedMmaOp::toInlineString(int indent_size) const {
std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
-#if NVFUSER_CUTLASS_KERNEL_ENABLED
const auto& mat1 = inputs[0].as<at::Tensor>();
const auto& mat2 = inputs[1].as<at::Tensor>();
const auto& scale1 = inputs[2].as<at::Tensor>();
@@ -1722,6 +1721,31 @@ std::vector<PolymorphicValue> CutlassNvfp4GroupedMmaOp::evaluate(
const auto& problem_sizes = inputs[5].as<at::Tensor>();
const auto& expert_offsets = inputs[6].as<at::Tensor>();
const auto& sf_offsets = inputs[7].as<at::Tensor>();

+  // Meta-device fast path, kept outside the NVFUSER_CUTLASS_KERNEL_ENABLED
+  // guard so output shape inference works even in builds without the
+  // CUTLASS kernels.
+  if (mat1.is_meta() || mat2.is_meta() || scale1.is_meta() ||
+      scale2.is_meta() || alpha.is_meta() || problem_sizes.is_meta() ||
+      expert_offsets.is_meta() || sf_offsets.is_meta()) {
+    // For nvfp4_scaled_grouped_mm, the output shape is [M, N],
+    // where M = mat1.size(0) and N = mat2.size(2).
+    // Note: CutlassNvfp4GroupedMmaOp expects mat2 to be [G, K/2, N] (packed)
+    // at runtime and transposes it before calling into CUTLASS.
+    std::vector<int64_t> result_sizes = {mat1.size(0), mat2.size(2)};
+
+    at::ScalarType out_dtype = data_type_to_aten(out()->dtype());
+    auto options =
+        mat1.options().device(c10::Device(c10::kMeta)).dtype(out_dtype);
+    at::Tensor result = at::empty(result_sizes, options);
+
+    if (const auto rfactor_did_idx = getRFactorDeviceDimensionIndex(out());
+        rfactor_did_idx != -1) {
+      result = result.unsqueeze(rfactor_did_idx);
+    }
+
+    return {result};
+  }
+
+#if NVFUSER_CUTLASS_KERNEL_ENABLED
NVF_CHECK(
mat1.scalar_type() == at::ScalarType::Float4_e2m1fn_x2 &&
mat2.scalar_type() == at::ScalarType::Float4_e2m1fn_x2);
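A minimal usage sketch of the new fast path (assuming standard ATen APIs; shapes mirror the test below): binding meta tensors produces an output whose shape is derived purely from input metadata, so no CUTLASS kernel, and no GPU, is needed.

// mat1 is [M, K/2] packed FP4; mat2 is [G, K/2, N] packed FP4.
auto meta = at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kMeta);
at::Tensor mat1 = at::empty({256, 160}, meta);
at::Tensor mat2 = at::empty({4, 160, 96}, meta);
// The fast path returns an empty meta tensor of shape
// [mat1.size(0), mat2.size(2)] = [256, 96] in the requested output dtype:
at::Tensor out =
    at::empty({mat1.size(0), mat2.size(2)}, meta.dtype(at::kBFloat16));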
201 changes: 201 additions & 0 deletions tests/cpp/test_meta.cpp
@@ -15,6 +15,7 @@
#include <tests/cpp/utils.h>

#include <array>
+#include <cstdint>
#include <string>
#include <tuple>
#include <utility>
@@ -533,4 +534,204 @@ TEST_F(MetaTest, Matmul1D) {
EXPECT_EQ(meta_out.strides(), real_out.strides());
}

// Test CutlassNvfp4GroupedMmaOp with meta device
TEST_F(MetaTest, CutlassNvfp4GroupedMma) {
#if NVFUSER_CUTLASS_KERNEL_ENABLED
// NVFP4 CUTLASS grouped MMA relies on SM100+ kernels (Blackwell) and TMA.
// On older GPUs the CUTLASS kernel may compile but fail at runtime when
// initializing TMA descriptors (e.g., status 801).
if (!deviceMajorMinorCheck(10)) {
GTEST_SKIP() << "CutlassNvfp4GroupedMma requires SM100+ (compute "
"capability >= 10.0)";
}

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

// Choose an example where M, N, K, and K/2 are all distinct and CUTLASS
// alignment constraints are satisfied. Larger shapes also tend to be more
// robust against TMA/tile constraints:
// num_experts = 4
// m_per_expert = 64 => M = 256
// N = 96
// K = 320, K/2 = 160
constexpr int64_t G = 4;
constexpr int64_t M_PER_EXPERT = 64;
constexpr int64_t M = G * M_PER_EXPERT;
constexpr int64_t N = 96;
constexpr int64_t K = 320;
constexpr int64_t K_DIV_2 = K / 2; // 160
constexpr int64_t K_DIV_16 = K / 16; // 20

// Shapes:
// mat1: [M, K] = [256, 320] (logical/unpacked shape)
// mat2: [G, K, N] = [4, 320, 96] (logical/unpacked shape)
// output: [M, N] = [256, 96]
// Note: Packed dtype Float4_e2m1fn_x2 is not allowed in IR. We use the
// unpacked dtype (Float4_e2m1fn) and the logical K dimension. When binding a
// packed ATen tensor (K/2), the last dim is adjusted (K/2 -> K).
auto mat1 = makeContigConcreteTensor({M, K}, DataType::Float4_e2m1fn);
// mat2 is expected as [G, K, N] logically. When binding packed FP4 inputs
// (Float4_e2m1fn_x2) with shape [G, K/2, N], the evaluator adjusts the K dim
// (K/2 -> K). We also set a stride order so the adjustment applies to the K
// dimension even though it isn't the last logical dimension.
auto mat2 = TensorViewBuilder()
.shape({G, K, N})
.dtype(DataType::Float4_e2m1fn)
.contiguity({true, true, true})
.strideOrder({2, 0, 1})
.build();
// Block-scaling factors have last dim K / 16
auto scale1 =
makeContigConcreteTensor({M, K_DIV_16}, DataType::Float8_e4m3fn);
auto scale2 =
makeContigConcreteTensor({G, N, K_DIV_16}, DataType::Float8_e4m3fn);
auto alpha = makeContigConcreteTensor({G}, DataType::Float);
auto problem_sizes = makeContigConcreteTensor({G, 3}, DataType::Index);
auto expert_offsets = makeContigConcreteTensor({G}, DataType::Index);
auto sf_offsets = makeContigConcreteTensor({G}, DataType::Index);

fusion->addInput(mat1);
fusion->addInput(mat2);
fusion->addInput(scale1);
fusion->addInput(scale2);
fusion->addInput(alpha);
fusion->addInput(problem_sizes);
fusion->addInput(expert_offsets);
fusion->addInput(sf_offsets);

auto result = cutlass_nvfp4_grouped_mm(
mat1,
mat2,
scale1,
scale2,
alpha,
problem_sizes,
expert_offsets,
sf_offsets,
DataType::BFloat16);
fusion->addOutput(result);

// Create real inputs with appropriate data types
auto options_uint8 =
at::TensorOptions().dtype(at::ScalarType::Byte).device(at::kCUDA, 0);
auto options_fp4 =
at::TensorOptions().dtype(at::kFloat4_e2m1fn_x2).device(at::kCUDA, 0);
auto options_fp8 =
at::TensorOptions().dtype(at::kFloat8_e4m3fn).device(at::kCUDA, 0);
auto options_fp32 =
at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, 0);

// FP4 tensors must be created as UInt8 and viewed as Float4
at::Tensor mat1_uint8 = at::randint(0, 256, {M, K_DIV_2}, options_uint8);
at::Tensor mat1_input = mat1_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);

// IMPORTANT: CutlassNvfp4GroupedMmaOp::evaluate transposes mat2 before
// calling into CUTLASS, which requires the transposed result to be
// contiguous. Construct mat2 as a transpose-view of a contiguous [G, N, K/2]
// tensor so that CutlassNvfp4GroupedMmaOp::evaluate's internal transpose
// produces a contiguous tensor.
at::Tensor mat2_base_uint8 =
at::randint(0, 256, {G, N, K_DIV_2}, options_uint8);
at::Tensor mat2_base =
mat2_base_uint8.contiguous().view(at::kFloat4_e2m1fn_x2);
at::Tensor mat2_input = mat2_base.transpose(-1, -2); // [G, K/2, N]
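// In other words: transposing mat2_input back just recovers the contiguous
// mat2_base, so the evaluator's internal transpose produces a contiguous
// tensor without any copy.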
// FP8 tensors can be created from FP32 tensors
at::Tensor scale1_input =
at::randn({M, K_DIV_16}, options_fp32).to(at::kFloat8_e4m3fn);
at::Tensor scale2_input =
at::randn({G, N, K_DIV_16}, options_fp32).to(at::kFloat8_e4m3fn);
at::Tensor alpha_input = at::ones({G}, options_fp32);
// problem_sizes uses unpacked dimensions per expert: (m_i, n, k)
const std::vector<int64_t> problem_sizes_vec{
// expert 0
M_PER_EXPERT,
N,
K,
// expert 1
M_PER_EXPERT,
N,
K,
// expert 2
M_PER_EXPERT,
N,
K,
// expert 3
M_PER_EXPERT,
N,
K,
};
at::Tensor problem_sizes_input =
at::tensor(problem_sizes_vec, options_int).reshape({G, 3});
const std::vector<int64_t> expert_offsets_vec{
0, M_PER_EXPERT, 2 * M_PER_EXPERT, 3 * M_PER_EXPERT};
at::Tensor expert_offsets_input = at::tensor(expert_offsets_vec, options_int);
const std::vector<int64_t> sf_offsets_vec{
0, M_PER_EXPERT, 2 * M_PER_EXPERT, 3 * M_PER_EXPERT};
at::Tensor sf_offsets_input = at::tensor(sf_offsets_vec, options_int);

// CUDA path
ExpressionEvaluator ee_cuda;
ee_cuda.bind(fusion->inputs().at(0), mat1_input);
ee_cuda.bind(fusion->inputs().at(1), mat2_input);
ee_cuda.bind(fusion->inputs().at(2), scale1_input);
ee_cuda.bind(fusion->inputs().at(3), scale2_input);
ee_cuda.bind(fusion->inputs().at(4), alpha_input);
ee_cuda.bind(fusion->inputs().at(5), problem_sizes_input);
ee_cuda.bind(fusion->inputs().at(6), expert_offsets_input);
ee_cuda.bind(fusion->inputs().at(7), sf_offsets_input);
auto real_out = ee_cuda.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

// Meta evaluation
ExpressionEvaluator ee_meta;
auto meta_mat1 = at::empty_strided(
mat1_input.sizes(), mat1_input.strides(), options_fp4.device(at::kMeta));
auto meta_mat2 = at::empty_strided(
mat2_input.sizes(), mat2_input.strides(), options_fp4.device(at::kMeta));
auto meta_scale1 = at::empty_strided(
scale1_input.sizes(),
scale1_input.strides(),
options_fp8.device(at::kMeta));
auto meta_scale2 = at::empty_strided(
scale2_input.sizes(),
scale2_input.strides(),
options_fp8.device(at::kMeta));
auto meta_alpha = at::empty_strided(
alpha_input.sizes(),
alpha_input.strides(),
options_fp32.device(at::kMeta));
auto meta_problem_sizes = at::empty_strided(
problem_sizes_input.sizes(),
problem_sizes_input.strides(),
options_int.device(at::kMeta));
auto meta_expert_offsets = at::empty_strided(
expert_offsets_input.sizes(),
expert_offsets_input.strides(),
options_int.device(at::kMeta));
auto meta_sf_offsets = at::empty_strided(
sf_offsets_input.sizes(),
sf_offsets_input.strides(),
options_int.device(at::kMeta));

ee_meta.bind(fusion->inputs().at(0), meta_mat1);
ee_meta.bind(fusion->inputs().at(1), meta_mat2);
ee_meta.bind(fusion->inputs().at(2), meta_scale1);
ee_meta.bind(fusion->inputs().at(3), meta_scale2);
ee_meta.bind(fusion->inputs().at(4), meta_alpha);
ee_meta.bind(fusion->inputs().at(5), meta_problem_sizes);
ee_meta.bind(fusion->inputs().at(6), meta_expert_offsets);
ee_meta.bind(fusion->inputs().at(7), meta_sf_offsets);
auto meta_out = ee_meta.evaluate(fusion->outputs().at(0)).as<at::Tensor>();

// Checks
EXPECT_TRUE(meta_out.is_meta());
EXPECT_EQ(meta_out.scalar_type(), at::kBFloat16);
EXPECT_EQ(meta_out.sizes(), real_out.sizes());
EXPECT_EQ(meta_out.strides(), real_out.strides());
#else
GTEST_SKIP() << "Test requires CUTLASS support";
#endif
}

} // namespace nvfuser
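A short sketch of the packed-FP4 bookkeeping the test relies on (assuming ATen's dtype-view semantics; values mirror the test): each byte holds two e2m1 nibbles, so a [M, K/2] Byte tensor viewed as Float4_e2m1fn_x2 represents M x K logical FP4 values. That is why the IR tensors use the unpacked K while the bound ATen tensors use K/2.

// Packed NVFP4 construction, as in the test above.
constexpr int64_t M = 256, K = 320;
auto bytes = at::randint(
    0, 256, {M, K / 2},
    at::TensorOptions().dtype(at::ScalarType::Byte).device(at::kCUDA, 0));
// Same bytes, reinterpreted: each element packs two e2m1 values, so this
// [256, 160] tensor logically holds 256 x 320 FP4 values.
auto packed = bytes.view(at::kFloat4_e2m1fn_x2);
// When bound against the Float4_e2m1fn IR input, the evaluator adjusts the
// packed dim back to the logical K (160 -> 320).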