
Commit 311c26d

GroupedBlockQuantizationOp PR1: Adding codegen support (#5776)
## Context

This series of PRs enables a single kernel for quantization and layout handling of the block scaling factor on grouped tensors. The existing solution for nvfp4 quantization of an activation tensor for grouped_mm relies on two operations:

i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
ii. block_scaling_factor needs to be processed by PreprocessGroupedMatmulInputSf in order to satisfy the swizzled layout required by grouped_mm kernels.

This series of PRs merges the two operations into a single one.

### Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

## What's in this PR

1. Adds the Fusion IR node GroupedBlockQuantizationOp. The operation is a combination of BlockQuantizationOp and PreprocessGroupedMatmulInputSf, and inherits all validation / checks from the two operations. It is similar to BlockQuantizationOp, with two exceptions:
   i. the block scaling factor output doesn't have the swizzle logic represented as allocation domain transformations;
   ii. it takes additional inputs (input_offsets and output_offsets) to facilitate group indexing, similar to PreprocessGroupedMatmulInputSf; see the offsets sketch below.
2. Adds a cpp test case for GroupedBlockQuantizationOp.
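To make the role of the two offset inputs concrete, here is a small, self-contained sketch. It is illustrative only: the assumption that each group's scaling-factor rows are padded up to a multiple of 128 (the row tile of the Block128x4 layout) is an inference, not stated by this PR, and all variable names are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t kRowTile = 128; // row tile assumed from the Block128x4 layout
  const std::vector<int64_t> group_rows = {100, 300, 60}; // example group sizes

  std::vector<int64_t> input_offsets{0};
  std::vector<int64_t> output_offsets{0};
  for (const int64_t rows : group_rows) {
    input_offsets.push_back(input_offsets.back() + rows);
    // Assumption: each group's scaling-factor rows are padded up to a
    // multiple of the 128-row tile, so output offsets grow faster than
    // input offsets.
    const int64_t padded = (rows + kRowTile - 1) / kRowTile * kRowTile;
    output_offsets.push_back(output_offsets.back() + padded);
  }
  for (std::size_t g = 0; g < group_rows.size(); ++g) {
    std::printf(
        "group %zu: input rows [%lld, %lld) -> output rows start at %lld\n",
        g,
        static_cast<long long>(input_offsets[g]),
        static_cast<long long>(input_offsets[g + 1]),
        static_cast<long long>(output_offsets[g]));
  }
  return 0;
}
```

Because the padding differs per group, input and output offsets diverge, which is why the fused op needs both to index groups correctly.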
1 parent 352dcbf commit 311c26d

27 files changed: +946 −205 lines

csrc/codegen.cpp

Lines changed: 124 additions & 0 deletions
@@ -1898,6 +1898,130 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
     indent() << genCall(fn_call, template_args, func_args) << ";\n";
   }

+  // Special handling of GroupedBlockQuantizationOp to call the runtime
+  // function.
+  void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
+    // This operator is plumbed down to a runtime function call.
+    // One of the assumptions is that the device runtime expects n
+    // consecutive inputs per thread, where n can be 2 or 4 for Float, and
+    // 2, 4, or 8 for Half. We achieve this by having the quantized output
+    // tv scheduled to have the inner dimension grouped by 2/4/8.
+    auto output =
+        grouped_bqop->quantizedOutput()->as<kir::TensorIndex>()->view();
+    auto output_dtype = output->getDataType();
+
+    // Extract group size from the loop domain
+    int64_t group_size = 1;
+    const auto& loop_domain = output->getLoopDomain();
+    for (const auto* domain : loop_domain) {
+      if (domain->getParallelType() == ParallelType::Group &&
+          domain->extent()->isConstInt()) {
+        group_size = domain->extent()->evaluate().as<int64_t>();
+        break;
+      }
+    }
+
+    // Validate group size based on input data type
+    const auto input_dtype = grouped_bqop->in()
+                                 ->as<kir::TensorIndex>()
+                                 ->view()
+                                 ->getDataType()
+                                 .value();
+    const bool is_half_precision =
+        (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half);
+    const bool is_valid_group_size = is_half_precision
+        ? (group_size == 2 || group_size == 4 || group_size == 8)
+        : (group_size == 2 || group_size == 4);
+
+    NVF_ERROR(
+        is_valid_group_size,
+        "Group size should be ",
+        is_half_precision ? "2, 4 or 8" : "2 or 4",
+        " for GroupedBlockQuantizationOp with input type ",
+        input_dtype,
+        ". Found: ",
+        group_size,
+        ". Expr: ",
+        grouped_bqop->toString());
+
+    // Build template arguments
+    ArgumentBuilder template_args;
+    // No global scale is required when quantizing to mxfp8
+    if (output_dtype == DataType::Float4_e2m1fn) {
+      template_args.arg(grouped_bqop->hasGlobalScale());
+    }
+    switch (grouped_bqop->layout()) {
+      case BlockScalingFactorLayout::Block128x4:
+        template_args.arg(32); // block_row_outer
+        template_args.arg(4); // block_row_inner
+        template_args.arg(4); // block_col
+        break;
+      default:
+        NVF_THROW("unrecognized layout");
+        break;
+    }
+    template_args.arg(group_size); // ITEMS_PER_THREAD
+
+    // Build function arguments
+    ArgumentBuilder func_args;
+    func_args.arg(genInline(
+        grouped_bqop->input(0)->as<kir::TensorIndex>()->view())); // input data
+    func_args.arg(genInline(output)); // quantized output
+    func_args.arg(genInline(grouped_bqop->blockScales()
+                                ->as<kir::TensorIndex>()
+                                ->view())); // block scales
+
+    // generate logical index for runtime function
+    func_args.arg(genInline(grouped_bqop->attributeVal(2)));
+    func_args.arg(genInline(grouped_bqop->attributeVal(3)));
+    func_args.arg("&").append(
+        genVariableName(grouped_bqop->inputOffsets()) + "[0]");
+    func_args.arg("&").append(
+        genVariableName(grouped_bqop->outputOffsets()) + "[0]");
+    func_args.arg(genInline(grouped_bqop->k()));
+    func_args.arg(genInline(grouped_bqop->g()));
+
+    if (output_dtype == DataType::Float4_e2m1fn) {
+      func_args.arg(
+          grouped_bqop->hasGlobalScale()
+              ? genInline(grouped_bqop->globalScale())
+              : "{}");
+    }
+
+    // Add swizzled allocation domain parameters if needed
+    // This is always skipped when quantizing to mxfp8
+    auto block_scales_tv =
+        grouped_bqop->blockScales()->as<kir::TensorIndex>()->view();
+    if (block_scales_tv->hasAllocation()) {
+      auto logical_domain =
+          TensorDomain::noReductions(block_scales_tv->getLogicalDomain());
+      auto allocation_domain =
+          TensorDomain::noReductions(block_scales_tv->getAllocationDomain());
+
+      // Swizzled layout: 2D logical -> 5D allocation
+      if (logical_domain.size() == 2 && allocation_domain.size() == 5) {
+        // Add logical domain extent of the inner dimension
+        func_args.arg(genInline(logical_domain[1]->extent()));
+
+        // Add all allocation domain extents
+        for (const auto* alloc_id : allocation_domain) {
+          func_args.arg(genInline(alloc_id->extent()));
+        }
+      }
+    }
+
+    NVF_ERROR(
+        output_dtype == DataType::Float4_e2m1fn,
+        "only nvfp4 output is implemented");
+
+    // Generate the function call
+    indent() << genCall(
+                    "bq::grouped_block_quantize_to_nvfp4",
+                    template_args,
+                    func_args)
+             << ";\n";
+  }
+
   std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
     std::stringstream lambda;
     lambda << "[](" << data_type << " &a, " << data_type << " b) "

csrc/device_lower/analysis/non_divisible_split.cpp

Lines changed: 6 additions & 1 deletion
@@ -218,7 +218,12 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) {
     // mapped to any ID of the input or sibling output.
     if (def == nullptr ||
         (tv->definition()->isA<BlockQuantizationOp>() &&
-         tv == tv->definition()->as<BlockQuantizationOp>()->blockScales())) {
+         tv == tv->definition()->as<BlockQuantizationOp>()->blockScales()) ||
+        (tv->definition()->isA<GroupedBlockQuantizationOp>() &&
+         tv ==
+             tv->definition()
+                 ->as<GroupedBlockQuantizationOp>()
+                 ->blockScales())) {
       continue;
     }

csrc/device_lower/analysis/sync_information.cpp

Lines changed: 10 additions & 5 deletions
@@ -299,11 +299,16 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) {
         // sync/predication is handled there.
         if ((parallel_type == ParallelType::BIDx ||
             parallel_type == ParallelType::TIDx) &&
-            (consumer->definition()->isA<BlockQuantizationOp>() &&
-             consumer ==
-                 consumer->definition()
-                     ->as<BlockQuantizationOp>()
-                     ->blockScales())) {
+            ((consumer->definition()->isA<BlockQuantizationOp>() &&
+              consumer ==
+                  consumer->definition()
+                      ->as<BlockQuantizationOp>()
+                      ->blockScales()) ||
+             (consumer->definition()->isA<GroupedBlockQuantizationOp>() &&
+              consumer ==
+                  consumer->definition()
+                      ->as<GroupedBlockQuantizationOp>()
+                      ->blockScales()))) {
           continue;
         }

csrc/device_lower/analysis/trivial_broadcast.cpp

Lines changed: 11 additions & 0 deletions
@@ -125,6 +125,17 @@ void ConcretizedBroadcastDomains::handle(BlockQuantizationOp* bq) {
   }
 }

+// GroupedBlockQuantizationOp introduces broadcast domains in the block scales
+// output
+void ConcretizedBroadcastDomains::handle(GroupedBlockQuantizationOp* bq) {
+  auto out = bq->blockScales()->as<TensorView>();
+  auto bcast_id = out->getLogicalDomain().back();
+  if (bcast_id->isBroadcast()) {
+    broadcast_origin_map_.emplace(
+        bcast_id, std::unordered_set<IterDomain*>({bcast_id}));
+  }
+}
+
 void ConcretizedBroadcastDomains::dispatch(Expr* expr) {
   IterVisitor::dispatch(expr);

csrc/device_lower/analysis/trivial_broadcast.h

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,8 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor {

   void handle(BlockQuantizationOp* bq) final;

+  void handle(GroupedBlockQuantizationOp* bq) final;
+
   void dispatch(Expr* expr) final;

   void markAsConcretized(

csrc/device_lower/pass/index.cpp

Lines changed: 54 additions & 0 deletions
@@ -437,6 +437,60 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) {
   GpuLower::current()->propagateExprInfo(bqop, back());
 }

+void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) {
+  const auto in = IrBuilder::create<kir::TensorIndex>(
+      grouped_bqop->in()->as<TensorView>(), grouped_bqop->fusion()->zeroVal());
+
+  const auto out_scales = IrBuilder::create<kir::TensorIndex>(
+      grouped_bqop->blockScales()->as<TensorView>(),
+      grouped_bqop->fusion()->zeroVal());
+  const auto out_quantized = IrBuilder::create<kir::TensorIndex>(
+      grouped_bqop->quantizedOutput()->as<TensorView>(),
+      grouped_bqop->fusion()->zeroVal());
+
+  std::vector<Val*> logical_index = Index::getConsumerPerDimLogicalIndex(
+      grouped_bqop->quantizedOutput()->as<TensorView>(), for_loops_);
+  NVF_ERROR(
+      logical_index.size() == 2,
+      "only matrices are supported in GroupedBlockQuantizationOp");
+
+  // As part of runtime validation, make sure that the inner dimension of
+  // the input is divisible by block size.
+  auto* inner_id =
+      grouped_bqop->in()->as<TensorView>()->getLogicalDomain().back();
+  Val* is_divisible = SimplifyingIrBuilder::eqExpr(
+      SimplifyingIrBuilder::modExpr(
+          inner_id->extent(),
+          IrBuilder::create<Val>(grouped_bqop->blockSize(), DataType::Index)),
+      grouped_bqop->fusion()->zeroVal());
+
+  NVFUSER_LOWER_VALIDATE(
+      is_divisible,
+      "Inner dimension of GroupedBlockQuantizationOp input must be divisible "
+      "by block size (",
+      grouped_bqop->blockSize(),
+      "), but got extent ",
+      inner_id->extent()->toInlineString(),
+      " in ",
+      grouped_bqop->toString());
+
+  pushBack(IrBuilder::create<GroupedBlockQuantizationOp>(
+      out_scales,
+      out_quantized,
+      in,
+      grouped_bqop->inputOffsets(),
+      grouped_bqop->outputOffsets(),
+      grouped_bqop->layout(),
+      grouped_bqop->k(),
+      grouped_bqop->g(),
+      grouped_bqop->globalScale(),
+      grouped_bqop->blockSize(),
+      logical_index[0],
+      logical_index[1]));
+  GpuLower::current()->propagateExprInfo(grouped_bqop, back());
+}
+
 void IndexLowering::handle(const SelectOp* sop) {
   auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0));
   auto lowered_index_cast = lowered_index;
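The divisibility predicate built above is symbolic; per the comment in the handler it is checked as part of runtime validation, once concrete extents are known. A minimal standalone equivalent, assuming a block size of 16 (typical for nvfp4 scaling blocks; the helper name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Host-side equivalent of eqExpr(modExpr(inner_extent, block_size), 0).
bool innerDimDivisible(int64_t inner_extent, int64_t block_size) {
  return inner_extent % block_size == 0;
}

int main() {
  assert(innerDimDivisible(1024, 16));  // 1024 = 64 * 16, passes validation
  assert(!innerDimDivisible(1000, 16)); // would trigger the validation error
  return 0;
}
```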

csrc/device_lower/pass/index.h

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class IndexLowering : private OptOutConstDispatch {
   void handle(const ArgsortOp*) final;
   void handle(const TopKOp*) final;
   void handle(const BlockQuantizationOp*) final;
+  void handle(const GroupedBlockQuantizationOp*) final;
   void handle(const RNGOp*) final;
   void handle(const ReductionOp*) final;
   void handle(const GroupedReductionOp*) final;

csrc/device_lower/utils.cpp

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ bool isTvOp(const Expr* expr) {
         ScanOp,
         PreprocessGroupedMatmulInputSf,
         BlockQuantizationOp,
+        GroupedBlockQuantizationOp,
         LaunchDependentGridOp,
         WaitForPriorGridOp,
         kir::AllocTMem,
