3 changes: 1 addition & 2 deletions benchmarks/python/benchmark_inference.py
@@ -104,8 +104,7 @@ def nvfp4_grouped_mm_translator(
nv_offsets = getnv(offsets, fd, lc_to_nv_map)
nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map)
nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map)
- fp4_mat1, fp8_scale1 = fd.ops.nv_block_quantize(nv_act)
- layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets)
+ fp4_mat1, layout_fp8_scale1 = fd.ops.nv_grouped_block_quantize(nv_act, nv_offsets, nv_blocksf_offsets)
out = fd.ops.cutlass_nvfp4_grouped_mm(
fp4_mat1,
nv_fp4_w,
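For context, the fused `nv_grouped_block_quantize` op replaces the previous two-step decomposition. A minimal sketch of the equivalence, reusing the operand names from the hunk above (this assumes an active FusionDefinition `fd` with those tensors already defined, so it is illustrative rather than standalone):

    # Before: quantize, then reorder the block scaling factors for the grouped matmul.
    fp4_mat1, fp8_scale1 = fd.ops.nv_block_quantize(nv_act)
    layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(
        fp8_scale1, nv_offsets, nv_blocksf_offsets
    )

    # After: one fused op yields the quantized tensor and the scaling factors
    # already in the grouped-matmul layout.
    fp4_mat1, layout_fp8_scale1 = fd.ops.nv_grouped_block_quantize(
        nv_act, nv_offsets, nv_blocksf_offsets
    )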
49 changes: 49 additions & 0 deletions python/python_direct/ops.cpp
@@ -3556,6 +3556,55 @@ tuple[TensorView, TensorView]
- block_scales: Per-block scaling factors
)",
py::return_value_policy::reference);
ops.def(
"nv_grouped_block_quantize",
[](TensorView* input,
TensorView* input_offsets,
TensorView* output_offsets,
TensorView* global_scale,
int64_t block_size,
PrimDataType dtype) -> py::tuple {
auto output = groupedBlockQuantize(
input,
input_offsets,
output_offsets,
BlockScalingFactorLayout::Block128x4,
global_scale,
block_size,
dtype);
return py::make_tuple(output.quantized_tensor, output.block_scales);
},
py::arg("input"),
py::arg("input_offsets"),
py::arg("output_offsets"),
py::arg("global_scale").none(true) = py::none(),
py::arg("block_size") = 16,
py::arg("dtype") = DataType::Float4_e2m1fn,
R"(
Grouped block-quantize a tensor to NVFP4 format.
Parameters
----------
input : TensorView
Input tensor to quantize. Must be a floating point tensor.
input_offsets : TensorView
A 1D tensor of length `number of groups`.
Each value is the offset of the starting token of the corresponding group in the input tensor view.
output_offsets : TensorView
A 1D tensor of length `number of groups`.
Each value is the offset of the starting token of the corresponding group in the output tensor view.
global_scale : TensorView, optional
Global scale factor applied during quantization. Default is None.
block_size : int, optional
Block size for quantization. Default is 16.
dtype : PrimDataType, optional
Data type of the quantized output. Default is DataType::Float4_e2m1fn.
Returns
-------
tuple[TensorView, TensorView]
A tuple containing (quantized_tensor, block_scales) where:
- quantized_tensor: Quantized tensor in NVFP4 format
- block_scales: Per-block scaling factors (swizzled in storage)
)",
py::return_value_policy::reference);
}

void bindRandomOps(py::module_& ops) {
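A short usage sketch of the new binding, with argument names and defaults taken from the `ops.def` above; the import path and the illustrative shapes are assumptions, not part of this diff:

    # Illustrative sketch; assumes the nvfuser_direct package layout.
    from nvfuser_direct import FusionDefinition, DataType

    def fusion_func(fd: FusionDefinition) -> None:
        inp = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
        input_offsets = fd.define_tensor(shape=[-1], dtype=DataType.Int32)
        output_offsets = fd.define_tensor(shape=[-1], dtype=DataType.Int32)
        # global_scale=None, block_size=16, and dtype=Float4_e2m1fn are the defaults.
        quantized, block_scales = fd.ops.nv_grouped_block_quantize(
            inp, input_offsets, output_offsets
        )
        fd.add_output(quantized)
        fd.add_output(block_scales)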
25 changes: 25 additions & 0 deletions python/python_direct/python_translate.cpp
@@ -1652,6 +1652,31 @@ class PythonTranslator : public OptInConstDispatch {
std::vector<const nvfuser::Val*>{bqop->output(0), bqop->output(1)});
}

void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
NVF_ERROR(grouped_bqop != nullptr);
visited_vals_.insert(grouped_bqop->output(0));
visited_vals_.insert(grouped_bqop->output(1));

static const auto default_args = std::make_tuple(
KeywordArgument<decltype(grouped_bqop->globalScale())>{
"global_scale", nullptr},
KeywordArgument<int64_t>{"block_size", 16},
KeywordArgument<DataType>{"dtype", DataType::Float4_e2m1fn});

auto dtype = grouped_bqop->quantizedOutput()->as<TensorView>()->dtype();
printer_.generateKwargsOperation(
"fd.ops.nv_grouped_block_quantize",
std::make_tuple(
grouped_bqop->in(),
grouped_bqop->inputOffsets(),
grouped_bqop->outputOffsets()),
default_args,
std::make_tuple(
grouped_bqop->globalScale(), grouped_bqop->blockSize(), dtype),
std::vector<const nvfuser::Val*>{
grouped_bqop->output(0), grouped_bqop->output(1)});
}

private:
//! Convert CPP values to python syntax.
PythonPrinter printer_;
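For reference, the handler above serializes the op back into Python-frontend syntax. Assuming the translator's usual tN naming for values (an assumption based on the sibling handlers, which are not shown in this hunk), the emitted line would look roughly like:

    # Hypothetical translator output; tensor names are illustrative.
    t3, t4 = fd.ops.nv_grouped_block_quantize(
        t0, t1, t2, global_scale=None, block_size=16, dtype=DataType.Float4_e2m1fn
    )

Since the actual values are compared against `default_args`, a fusion that uses the defaults would presumably print only the three positional arguments.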
167 changes: 167 additions & 0 deletions tests/python/direct/test_narrow_precision.py
@@ -828,3 +828,170 @@ def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
assert (
large_diff_ratio < 0.1
), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"


# This is adapted from the decomposed version in test_block_quantize_op_and_layout_op.
@pytest.mark.skipif(
is_pre_blackwell(), reason="Only supported on Blackwell and newer devices."
)
@pytest.mark.skipif(
not microarchitecture_is_pre(12), reason="Not supported on Blackwell compute capability 12.0"
)
@pytest.mark.parametrize("config", [[1024, 128, 256]])
@pytest.mark.parametrize("tokens_per_expert_neg_one", [[115, 144, 8]])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_grouped_block_quantize_op(
nvfuser_direct_test,
config,
tokens_per_expert_neg_one,
out_dtype,
):
BLOCK_SIZE = 16

# The k dimension is a multiple of 4 * 16 to avoid padding of the block scaling factors.
m, n, k = config
assert k % 64 == 0
tokens_per_expert = list(tokens_per_expert_neg_one)
tokens_per_expert.append(m - sum(tokens_per_expert))
g = len(tokens_per_expert)

mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
# layout is (g, n, k) instead of (g, k, n)
mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")

offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")

# prepare quantization for mat2
mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
scale2 = torch.empty(
(g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
)

acc_tokens = 0
rounded_acc_tokens = 0
mat2_scaled = torch.empty(
(g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
)

for i in range(g):
global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max()
offsets[i] = acc_tokens
blockscale_offsets[i] = rounded_acc_tokens
acc_tokens += tokens_per_expert[i]
# Note: we technically don't need to round up, since k is perfectly sized.
rounded_acc_tokens += round_up(tokens_per_expert[i], 128)

problem_sizes[i][0] = tokens_per_expert[i]
problem_sizes[i][1] = n
problem_sizes[i][2] = k

scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf)
mat2_gs[i] = 1.0 / global_sf
mat2_scaled[i] = scaled_mat2_i
scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)

def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
mat1 = fd.define_tensor(
shape=[-1, -1],
contiguity=True,
dtype=DataType.Float,
is_cpu=False,
)
mat2 = fd.define_tensor(
shape=[-1, -1, -1],
contiguity=True,
dtype=DataType.Float4_e2m1fn,
is_cpu=False,
stride_order=[2, 0, 1],
)
scale2 = fd.define_tensor(
shape=[-1, -1, -1],
contiguity=True,
dtype=DataType.Float8_e4m3fn,
is_cpu=False,
)
alpha = fd.define_tensor(
shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
)
problem_sizes = fd.define_tensor(
shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
)
offsets = fd.define_tensor(
shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
)
blockscale_offsets = fd.define_tensor(
shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
)

fp4_mat1, fp8_scale1 = fd.ops.nv_grouped_block_quantize(
mat1, offsets, blockscale_offsets
)

out = fd.ops.cutlass_nvfp4_grouped_mm(
fp4_mat1,
mat2,
fp8_scale1,
scale2,
alpha,
problem_sizes,
offsets,
blockscale_offsets,
DataType.BFloat16,
)
fd.add_output(out)

inputs = [
mat1,
mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
scale2,
mat2_gs,
problem_sizes,
offsets,
blockscale_offsets,
]

o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
# Quantizing the activation is needed for the reference computation.
# Note: following the sglang implementation, we do not compute a global scaling
# factor for mat1; consequently, we also don't need to apply mat1_gs to alpha.
mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
mat1_fp4, scale1 = activation_scale_to_nvfp4(
mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
)
o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
for i in range(g):
l = offsets[i]
l_sf = blockscale_offsets[i]
if i == g - 1:
r = m
else:
r = offsets[i + 1]
r_sf = round_up(tokens_per_expert[i], 128) + l_sf
# Feeding mat2_gs[i] as alpha to the torch kernel triggers a cuBLAS
# invalid value error, so we multiply the result by mat2_gs[i] instead.
o_decomposed_ref[l:r] = (
torch._scaled_mm(
mat1_fp4[l:r],
mat2_scaled[i].transpose(-1, -2),
scale1[l_sf:r_sf],
scale2[i],
None,
None,
torch.bfloat16,
)
* mat2_gs[i]
)

# Validate: nvfuser quantization should match baseline
abs_diff = torch.abs(o[0] - o_decomposed_ref)
max_diff = torch.max(abs_diff)
assert max_diff <= 10.0, f"Max difference {max_diff:.4f} exceeds threshold of 10.0"

# Check that large differences (> 5.0) are rare (< 10% of elements)
large_diff_count = torch.count_nonzero(torch.gt(abs_diff, 5.0))
large_diff_ratio = large_diff_count / abs_diff.numel()
assert (
large_diff_ratio < 0.1
), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"