diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py
index 06fa5488e1b..992b4766889 100644
--- a/benchmarks/python/benchmark_inference.py
+++ b/benchmarks/python/benchmark_inference.py
@@ -104,8 +104,7 @@ def nvfp4_grouped_mm_translator(
     nv_offsets = getnv(offsets, fd, lc_to_nv_map)
     nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map)
     nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map)
-    fp4_mat1, fp8_scale1 = fd.ops.nv_block_quantize(nv_act)
-    layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets)
+    fp4_mat1, layout_fp8_scale1 = fd.ops.nv_grouped_block_quantize(nv_act, nv_offsets, nv_blocksf_offsets)
     out = fd.ops.cutlass_nvfp4_grouped_mm(
         fp4_mat1,
         nv_fp4_w,
diff --git a/python/python_direct/ops.cpp b/python/python_direct/ops.cpp
index 11ffe6a264b..5e20be398ab 100644
--- a/python/python_direct/ops.cpp
+++ b/python/python_direct/ops.cpp
@@ -3556,6 +3556,55 @@ tuple[TensorView, TensorView]
     - block_scales: Per-block scaling factors
 )",
       py::return_value_policy::reference);
+  ops.def(
+      "nv_grouped_block_quantize",
+      [](TensorView* input,
+         TensorView* input_offsets,
+         TensorView* output_offsets,
+         TensorView* global_scale,
+         int64_t block_size,
+         PrimDataType dtype) -> py::tuple {
+        auto output = groupedBlockQuantize(
+            input,
+            input_offsets,
+            output_offsets,
+            BlockScalingFactorLayout::Block128x4,
+            global_scale,
+            block_size,
+            dtype);
+        return py::make_tuple(output.quantized_tensor, output.block_scales);
+      },
+      py::arg("input"),
+      py::arg("input_offsets"),
+      py::arg("output_offsets"),
+      py::arg("global_scale").none(true) = py::none(),
+      py::arg("block_size") = 16,
+      py::arg("dtype") = DataType::Float4_e2m1fn,
+      R"(
+Grouped block-quantize a tensor to NVFP4 format.
+Parameters
+----------
+input : TensorView
+    Input tensor to quantize. Must be a floating point tensor.
+input_offsets : TensorView
+    A 1D tensor whose length is the number of groups. Each value is the offset
+    of the starting token of the corresponding group in the input tensor view.
+output_offsets : TensorView
+    A 1D tensor whose length is the number of groups. Each value is the offset
+    of the starting token of the corresponding group in the output tensor view.
+global_scale : TensorView, optional
+block_size : int, optional
+    Block size for quantization. Default is 16.
+dtype : PrimDataType, optional
+    Data type of the quantized output. Default is DataType::Float4_e2m1fn.
+Returns
+-------
+tuple[TensorView, TensorView]
+    A tuple containing (quantized_tensor, block_scales) where:
+    - quantized_tensor: Quantized tensor in NVFP4 format
+    - block_scales: Per-block scaling factors (swizzled in storage)
+)",
+      py::return_value_policy::reference);
 }
 
 void bindRandomOps(py::module_& ops) {
diff --git a/python/python_direct/python_translate.cpp b/python/python_direct/python_translate.cpp
index 067a7a2a392..bfa1f77f742 100644
--- a/python/python_direct/python_translate.cpp
+++ b/python/python_direct/python_translate.cpp
@@ -1652,6 +1652,31 @@ class PythonTranslator : public OptInConstDispatch {
         std::vector{bqop->output(0), bqop->output(1)});
   }
 
+  void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
+    NVF_ERROR(grouped_bqop != nullptr);
+    visited_vals_.insert(grouped_bqop->output(0));
+    visited_vals_.insert(grouped_bqop->output(1));
+
+    static const auto default_args = std::make_tuple(
+        KeywordArgument<decltype(grouped_bqop->globalScale())>{
+            "global_scale", nullptr},
+        KeywordArgument{"block_size", 16},
+        KeywordArgument{"dtype", DataType::Float4_e2m1fn});
+
+    auto dtype = grouped_bqop->quantizedOutput()->as<TensorView>()->dtype();
+    printer_.generateKwargsOperation(
+        "fd.ops.nv_grouped_block_quantize",
+        std::make_tuple(
+            grouped_bqop->in(),
+            grouped_bqop->inputOffsets(),
+            grouped_bqop->outputOffsets()),
+        default_args,
+        std::make_tuple(
+            grouped_bqop->globalScale(), grouped_bqop->blockSize(), dtype),
+        std::vector{
+            grouped_bqop->output(0), grouped_bqop->output(1)});
+  }
+
  private:
   //! Convert CPP values to python syntax.
   PythonPrinter printer_;
diff --git a/tests/python/direct/test_narrow_precision.py b/tests/python/direct/test_narrow_precision.py
index 8e16b7abc09..efd7f5c7c46 100644
--- a/tests/python/direct/test_narrow_precision.py
+++ b/tests/python/direct/test_narrow_precision.py
@@ -828,3 +828,170 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
     assert (
         large_diff_ratio < 0.1
     ), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"
+
+
+# This is adapted from the decomposed version test_block_quantize_op_and_layout_op
+@pytest.mark.skipif(
+    is_pre_blackwell(), reason="Only supported on blackwell and newer devices."
+)
+@pytest.mark.skipif(
+    not microarchitecture_is_pre(12), reason="Does not support blackwell compute 12.0"
+)
+@pytest.mark.parametrize("config", [[1024, 128, 256]])
+@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
+def test_grouped_block_quantize_op(
+    nvfuser_direct_test,
+    config,
+    tokens_per_expert_neg_one,
+    out_dtype,
+):
+    BLOCK_SIZE = 16
+
+    # k dimension is a multiple of 4 * 16 to avoid padding on the block scaling factor
+    m, n, k = config
+    assert k % 64 == 0
+    tokens_per_expert = list(tokens_per_expert_neg_one)
+    tokens_per_expert.append(m - sum(tokens_per_expert))
+    g = len(tokens_per_expert)
+
+    mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
+    # format is g, n, k instead of g, k, n
+    mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")
+
+    offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
+    blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
+    problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")
+
+    # prepare quantization for mat2
+    mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
+    scale2 = torch.empty(
+        (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
+    )
+
+    acc_tokens = 0
+    rounded_acc_tokens = 0
+    mat2_scaled = torch.empty(
+        (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
+    )
+
+    for i in range(g):
+        global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max()
+        offsets[i] = acc_tokens
+        blockscale_offsets[i] = rounded_acc_tokens
+        acc_tokens += tokens_per_expert[i]
+        # Note: we technically don't need to round up, since k is perfectly sized.
+        rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
+
+        problem_sizes[i][0] = tokens_per_expert[i]
+        problem_sizes[i][1] = n
+        problem_sizes[i][2] = k
+
+        scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf)
+        mat2_gs[i] = 1.0 / global_sf
+        mat2_scaled[i] = scaled_mat2_i
+        scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)
+
+    def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
+        mat1 = fd.define_tensor(
+            shape=[-1, -1],
+            contiguity=True,
+            dtype=DataType.Float,
+            is_cpu=False,
+        )
+        mat2 = fd.define_tensor(
+            shape=[-1, -1, -1],
+            contiguity=True,
+            dtype=DataType.Float4_e2m1fn,
+            is_cpu=False,
+            stride_order=[2, 0, 1],
+        )
+        scale2 = fd.define_tensor(
+            shape=[-1, -1, -1],
+            contiguity=True,
+            dtype=DataType.Float8_e4m3fn,
+            is_cpu=False,
+        )
+        alpha = fd.define_tensor(
+            shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
+        )
+        problem_sizes = fd.define_tensor(
+            shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
+        )
+        offsets = fd.define_tensor(
+            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
+        )
+        blockscale_offsets = fd.define_tensor(
+            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
+        )
+
+        fp4_mat1, fp8_scale1 = fd.ops.nv_grouped_block_quantize(
+            mat1, offsets, blockscale_offsets
+        )
+
+        out = fd.ops.cutlass_nvfp4_grouped_mm(
+            fp4_mat1,
+            mat2,
+            fp8_scale1,
+            scale2,
+            alpha,
+            problem_sizes,
+            offsets,
+            blockscale_offsets,
+            DataType.BFloat16,
+        )
+        fd.add_output(out)
+
+    inputs = [
+        mat1,
+        mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
+        scale2,
+        mat2_gs,
+        problem_sizes,
+        offsets,
+        blockscale_offsets,
+    ]
+
+    o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
+    # quantization for activation is needed for reference.
+    # note: following the sglang implementation, we do not compute a global scaling factor for mat1
+    # similarly, we don't need to apply mat1_gs to alpha
+    mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
+    mat1_fp4, scale1 = activation_scale_to_nvfp4(
+        mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
+    )
+    o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
+    for i in range(g):
+        l = offsets[i]
+        l_sf = blockscale_offsets[i]
+        if i == g - 1:
+            r = m
+        else:
+            r = offsets[i + 1]
+        r_sf = round_up(tokens_per_expert[i], 128) + l_sf
+        # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
+        # This triggers a cuBLAS invalid value error.
+        o_decomposed_ref[l:r] = (
+            torch._scaled_mm(
+                mat1_fp4[l:r],
+                mat2_scaled[i].transpose(-1, -2),
+                scale1[l_sf:r_sf],
+                scale2[i],
+                None,
+                None,
+                torch.bfloat16,
+            )
+            * mat2_gs[i]
+        )
+
+    # Validate: nvfuser quantization should match baseline
+    abs_diff = torch.abs(o[0] - o_decomposed_ref)
+    max_diff = torch.max(abs_diff)
+    assert max_diff <= 10.0, f"Max difference {max_diff:.4f} exceeds threshold of 10.0"
+
+    # Check that large differences (> 5.0) are rare (< 10% of elements)
+    large_diff_count = torch.count_nonzero(torch.gt(abs_diff, 5.0))
+    large_diff_ratio = large_diff_count / abs_diff.numel()
+    assert (
+        large_diff_ratio < 0.1
+    ), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"
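Note (outside the patch itself): a minimal usage sketch of the new fused op, mirroring the benchmark_inference.py and test_narrow_precision.py changes above. It replaces the former two-step nv_block_quantize + preprocess_grouped_matmul_input_sf sequence with a single call; the import path, tensor shapes, and the surrounding harness are assumptions for illustration, not part of the patch.

# Sketch only; assumes the nvfuser_direct package and a 2D [tokens, hidden] activation.
from nvfuser_direct import FusionDefinition, DataType

def define_fusion(fd: FusionDefinition) -> None:
    # Full-precision activations, one row per token across all groups.
    act = fd.define_tensor(shape=[-1, -1], contiguity=True, dtype=DataType.Float, is_cpu=False)
    # Per-group starting-token offsets into the activation, and the corresponding
    # offsets into the (row-rounded) block-scale storage.
    offsets = fd.define_tensor(shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False)
    blockscale_offsets = fd.define_tensor(shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False)

    # One fused op: quantize to NVFP4 and emit block scales already laid out
    # (swizzled) for cutlass_nvfp4_grouped_mm.
    fp4_act, layout_scales = fd.ops.nv_grouped_block_quantize(act, offsets, blockscale_offsets)
    fd.add_output(fp4_act)
    fd.add_output(layout_scales)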