Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,120 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
indent() << genCall(fn_call, template_args, func_args) << ";\n";
}

// Special handling of GroupedBlockQuantizationOp to call the runtime
// function.
void handle(const GroupedBlockQuantizationOp* bqop) final {
  // This operator is plumbed down to a runtime function call.
  // The device runtime expects n consecutive inputs per thread, where n can
  // be 2 or 4 for Float, and 2, 4, or 8 for Half/BFloat16. We achieve this
  // by having the quantized output tv scheduled to have the inner dimension
  // grouped by 2/4/8.
  auto output = bqop->quantizedOutput()->as<kir::TensorIndex>()->view();
  auto output_dtype = output->getDataType();

  // Fail fast: only nvfp4 (Float4_e2m1fn) quantized output is implemented.
  // Validating here, instead of after all arguments have been assembled,
  // avoids doing work on a path that is guaranteed to throw.
  NVF_ERROR(
      output_dtype == DataType::Float4_e2m1fn,
      "only nvfp4 output is implemented");

  // Extract group size from the loop domain. Defaults to 1 if no
  // ParallelType::Group domain with a constant extent is found, which the
  // validation below will reject.
  int64_t group_size = 1;
  const auto& loop_domain = output->getLoopDomain();
  for (const auto* domain : loop_domain) {
    if (domain->getParallelType() == ParallelType::Group &&
        domain->extent()->isConstInt()) {
      group_size = domain->extent()->evaluate().as<int64_t>();
      break;
    }
  }

  // Validate group size based on input data type.
  const auto input_dtype =
      bqop->in()->as<kir::TensorIndex>()->view()->getDataType().value();
  const bool is_half_precision =
      (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half);
  const bool is_valid_group_size = is_half_precision
      ? (group_size == 2 || group_size == 4 || group_size == 8)
      : (group_size == 2 || group_size == 4);

  NVF_ERROR(
      is_valid_group_size,
      "Group size should be ",
      is_half_precision ? "2, 4 or 8" : "2 or 4",
      " for GroupedBlockQuantizationOp with input type ",
      input_dtype,
      ". Found: ",
      group_size,
      ". Expr: ",
      bqop->toString());

  // Build template arguments for the runtime call.
  ArgumentBuilder template_args;
  // No global scale is required when quantizing to mxfp8; the branch is kept
  // for when mxfp8 output support is added.
  if (output_dtype == DataType::Float4_e2m1fn) {
    template_args.arg(bqop->hasGlobalScale());
  }
  switch (bqop->layout()) {
    case BlockScalingFactorLayout::Block128x4:
      template_args.arg(32); // block_row_outer
      template_args.arg(4); // block_row_inner
      template_args.arg(4); // block_col
      break;
    default:
      NVF_THROW("unrecognized layout");
      break;
  }
  template_args.arg(group_size); // ITEMS_PER_THREAD

  // Build function arguments.
  ArgumentBuilder func_args;
  // Use in() for the input data tensor, consistent with the dtype query
  // above (in() is input(0)).
  func_args.arg(genInline(bqop->in()->as<kir::TensorIndex>()->view()));
  func_args.arg(genInline(output)); // quantized output
  func_args.arg(genInline(
      bqop->blockScales()->as<kir::TensorIndex>()->view())); // block scales

  // Logical indices computed during index lowering, passed through so the
  // runtime function can address the tensors.
  func_args.arg(genInline(bqop->attributeVal(2)));
  func_args.arg(genInline(bqop->attributeVal(3)));
  // Per-group offset arrays are passed as raw pointers to their first entry.
  func_args.arg("&").append(genVariableName(bqop->inputOffsets()) + "[0]");
  func_args.arg("&").append(genVariableName(bqop->outputOffsets()) + "[0]");
  func_args.arg(genInline(bqop->k()));
  func_args.arg(genInline(bqop->g()));

  if (output_dtype == DataType::Float4_e2m1fn) {
    // Global scale is optional; emit an empty initializer when absent.
    func_args.arg(
        bqop->hasGlobalScale() ? genInline(bqop->globalScale()) : "{}");
  }

  // Add swizzled allocation domain parameters if needed.
  // This is always skipped when quantizing to mxfp8.
  auto block_scales_tv = bqop->blockScales()->as<kir::TensorIndex>()->view();
  if (block_scales_tv->hasAllocation()) {
    auto logical_domain =
        TensorDomain::noReductions(block_scales_tv->getLogicalDomain());
    auto allocation_domain =
        TensorDomain::noReductions(block_scales_tv->getAllocationDomain());

    // Swizzled layout: 2D logical -> 5D allocation
    if (logical_domain.size() == 2 && allocation_domain.size() == 5) {
      // Add logical domain extent of the inner dimension
      func_args.arg(genInline(logical_domain[1]->extent()));

      // Add all allocation domain extents
      for (const auto* alloc_id : allocation_domain) {
        func_args.arg(genInline(alloc_id->extent()));
      }
    }
  }

  // Generate the runtime function call.
  indent() << genCall(
                  "bq::grouped_block_quantize_to_nvfp4",
                  template_args,
                  func_args)
           << ";\n";
}

std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
std::stringstream lambda;
lambda << "[](" << data_type << " &a, " << data_type << " b) "
Expand Down
7 changes: 6 additions & 1 deletion csrc/device_lower/analysis/non_divisible_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,12 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) {
// mapped to any ID of the input or sibling output.
if (def == nullptr ||
(tv->definition()->isA<BlockQuantizationOp>() &&
tv == tv->definition()->as<BlockQuantizationOp>()->blockScales())) {
tv == tv->definition()->as<BlockQuantizationOp>()->blockScales()) ||
(tv->definition()->isA<GroupedBlockQuantizationOp>() &&
tv ==
tv->definition()
->as<GroupedBlockQuantizationOp>()
->blockScales())) {
continue;
}

Expand Down
15 changes: 10 additions & 5 deletions csrc/device_lower/analysis/sync_information.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,16 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) {
// sync/predication is handled there.
if ((parallel_type == ParallelType::BIDx ||
parallel_type == ParallelType::TIDx) &&
(consumer->definition()->isA<BlockQuantizationOp>() &&
consumer ==
consumer->definition()
->as<BlockQuantizationOp>()
->blockScales())) {
((consumer->definition()->isA<BlockQuantizationOp>() &&
consumer ==
consumer->definition()
->as<BlockQuantizationOp>()
->blockScales()) ||
(consumer->definition()->isA<GroupedBlockQuantizationOp>() &&
consumer ==
consumer->definition()
->as<GroupedBlockQuantizationOp>()
->blockScales()))) {
continue;
}

Expand Down
11 changes: 11 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ void ConcretizedBroadcastDomains::handle(BlockQuantizationOp* bq) {
}
}

// GroupedBlockQuantizationOp introduces broadcast domains in the block scales
// output.
void ConcretizedBroadcastDomains::handle(GroupedBlockQuantizationOp* bq) {
  auto* scales_tv = bq->blockScales()->as<TensorView>();
  auto* innermost_id = scales_tv->getLogicalDomain().back();
  if (!innermost_id->isBroadcast()) {
    return;
  }
  // The broadcast ID originates at this op, so it maps to itself.
  broadcast_origin_map_.emplace(
      innermost_id, std::unordered_set<IterDomain*>({innermost_id}));
}

void ConcretizedBroadcastDomains::dispatch(Expr* expr) {
IterVisitor::dispatch(expr);

Expand Down
2 changes: 2 additions & 0 deletions csrc/device_lower/analysis/trivial_broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor {

void handle(BlockQuantizationOp* bq) final;

void handle(GroupedBlockQuantizationOp* bq) final;

void dispatch(Expr* expr) final;

void markAsConcretized(
Expand Down
54 changes: 54 additions & 0 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,60 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) {
GpuLower::current()->propagateExprInfo(bqop, back());
}

void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) {
const auto in = IrBuilder::create<kir::TensorIndex>(
grouped_bqop->in()->as<TensorView>(), grouped_bqop->fusion()->zeroVal());

const auto out_scales = IrBuilder::create<kir::TensorIndex>(
grouped_bqop->blockScales()->as<TensorView>(),
grouped_bqop->fusion()->zeroVal());
const auto out_quantized = IrBuilder::create<kir::TensorIndex>(
grouped_bqop->quantizedOutput()->as<TensorView>(),
grouped_bqop->fusion()->zeroVal());

std::vector<Val*> logical_index = Index::getConsumerPerDimLogicalIndex(
grouped_bqop->quantizedOutput()->as<TensorView>(), for_loops_);
NVF_ERROR(
logical_index.size() == 2,
"only matrices are supported in GroupedBlockQuantizationOp");

// As part of runtime validation
// make sure that the inner dimension of the input is divisible by block size.
auto* inner_id =
grouped_bqop->in()->as<TensorView>()->getLogicalDomain().back();
Val* is_divisible = SimplifyingIrBuilder::eqExpr(
SimplifyingIrBuilder::modExpr(
inner_id->extent(),
IrBuilder::create<Val>(grouped_bqop->blockSize(), DataType::Index)),
grouped_bqop->fusion()->zeroVal());

NVFUSER_LOWER_VALIDATE(
is_divisible,
"Inner dimension of GroupedBlockQuantizationOp input must be divisible "
"by block "
"size (",
grouped_bqop->blockSize(),
"), but got extent ",
inner_id->extent()->toInlineString(),
" in ",
grouped_bqop->toString());

pushBack(IrBuilder::create<GroupedBlockQuantizationOp>(
out_scales,
out_quantized,
in,
grouped_bqop->inputOffsets(),
grouped_bqop->outputOffsets(),
grouped_bqop->layout(),
grouped_bqop->k(),
grouped_bqop->g(),
grouped_bqop->globalScale(),
16,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the block_size parameter is hardcoded to 16, but it should use grouped_bqop->blockSize() to respect the original operation's block_size parameter

Suggested change
16,
grouped_bqop->blockSize(),

The GroupedBlockQuantizationOp constructor accepts a block_size parameter (line 1063 in composite_nodes.h), and the operation stores this value as an attribute accessible via blockSize() method (line 1081-1083). However, during index lowering, this value is being replaced with a hardcoded 16, which means any non-default block size specified by the user will be ignored.

logical_index[0],
logical_index[1]));
GpuLower::current()->propagateExprInfo(grouped_bqop, back());
}

void IndexLowering::handle(const SelectOp* sop) {
auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0));
auto lowered_index_cast = lowered_index;
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const ArgsortOp*) final;
void handle(const TopKOp*) final;
void handle(const BlockQuantizationOp*) final;
void handle(const GroupedBlockQuantizationOp*) final;
void handle(const RNGOp*) final;
void handle(const ReductionOp*) final;
void handle(const GroupedReductionOp*) final;
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ bool isTvOp(const Expr* expr) {
ScanOp,
PreprocessGroupedMatmulInputSf,
BlockQuantizationOp,
GroupedBlockQuantizationOp,
LaunchDependentGridOp,
WaitForPriorGridOp,
kir::AllocTMem,
Expand Down
Loading
Loading