diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index e12ed646540..9b89216aae8 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1898,6 +1898,120 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall(fn_call, template_args, func_args) << ";\n"; } + // Special handling of GroupedBlockQuantizationOp to call the runtime + // function. + void handle(const GroupedBlockQuantizationOp* bqop) final { + // This operator is plumbed down to a runtime function call. + // One of the assumptions is that the device runtime expects + // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, + // 4, or 8 for Half. We achieve this by having the quantized output tv + // scheduled to have the inner dimension grouped by 2/4/8. + auto output = bqop->quantizedOutput()->as()->view(); + auto output_dtype = output->getDataType(); + + // Extract group size from the loop domain + int64_t group_size = 1; + const auto& loop_domain = output->getLoopDomain(); + for (const auto* domain : loop_domain) { + if (domain->getParallelType() == ParallelType::Group && + domain->extent()->isConstInt()) { + group_size = domain->extent()->evaluate().as(); + break; + } + } + + // Validate group size based on input data type + const auto input_dtype = + bqop->in()->as()->view()->getDataType().value(); + const bool is_half_precision = + (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half); + const bool is_valid_group_size = is_half_precision + ? (group_size == 2 || group_size == 4 || group_size == 8) + : (group_size == 2 || group_size == 4); + + NVF_ERROR( + is_valid_group_size, + "Group size should be ", + is_half_precision ? "2, 4 or 8" : "2 or 4", + " for GroupedBlockQuantizationOp with input type ", + input_dtype, + ". Found: ", + group_size, + ". Expr: ", + bqop->toString()); + + // Build template arguments + ArgumentBuilder template_args; + // No global scale is required when quantizing to mxfp8 + if (output_dtype == DataType::Float4_e2m1fn) { + template_args.arg(bqop->hasGlobalScale()); + } + switch (bqop->layout()) { + case BlockScalingFactorLayout::Block128x4: + template_args.arg(32); // block_row_outer + template_args.arg(4); // block_row_inner + template_args.arg(4); // block_col + break; + default: + NVF_THROW("unrecognized layout"); + break; + } + template_args.arg(group_size); // ITEMS_PER_THREAD + + // Build function arguments + ArgumentBuilder func_args; + func_args.arg(genInline( + bqop->input(0)->as()->view())); // input data + func_args.arg(genInline(output)); // quantized output + func_args.arg(genInline( + bqop->blockScales()->as()->view())); // block scales + + // generate logical index for runtime function + func_args.arg(genInline(bqop->attributeVal(2))); + func_args.arg(genInline(bqop->attributeVal(3))); + func_args.arg("&").append(genVariableName(bqop->inputOffsets()) + "[0]"); + func_args.arg("&").append(genVariableName(bqop->outputOffsets()) + "[0]"); + func_args.arg(genInline(bqop->k())); + func_args.arg(genInline(bqop->g())); + + if (output_dtype == DataType::Float4_e2m1fn) { + func_args.arg( + bqop->hasGlobalScale() ? 
genInline(bqop->globalScale()) : "{}"); + } + + // Add swizzled allocation domain parameters if needed + // This is always skipped when quantizing to mxfp8 + auto block_scales_tv = bqop->blockScales()->as()->view(); + if (block_scales_tv->hasAllocation()) { + auto logical_domain = + TensorDomain::noReductions(block_scales_tv->getLogicalDomain()); + auto allocation_domain = + TensorDomain::noReductions(block_scales_tv->getAllocationDomain()); + + // Swizzled layout: 2D logical -> 5D allocation + if (logical_domain.size() == 2 && allocation_domain.size() == 5) { + // Add logical domain extent of the inner dimension + func_args.arg(genInline(logical_domain[1]->extent())); + + // Add all allocation domain extents + for (const auto* alloc_id : allocation_domain) { + func_args.arg(genInline(alloc_id->extent())); + } + } + } + + NVF_ERROR( + output_dtype == DataType::Float4_e2m1fn, + "only nvfp4 output is implemented"); + + // Generate the function call + indent() << genCall( + "bq::grouped_block_quantize_to_nvfp4", + template_args, + func_args) + << ";\n"; + } + std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index 0effe1618e1..8a46bfa352e 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -218,7 +218,12 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { // mapped to any ID of the input or sibling output. if (def == nullptr || (tv->definition()->isA() && - tv == tv->definition()->as()->blockScales())) { + tv == tv->definition()->as()->blockScales()) || + (tv->definition()->isA() && + tv == + tv->definition() + ->as() + ->blockScales())) { continue; } diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 178408860e2..bba2ff3e2a7 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -299,11 +299,16 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { // sync/predication is handled there. 
if ((parallel_type == ParallelType::BIDx || parallel_type == ParallelType::TIDx) && - (consumer->definition()->isA() && - consumer == - consumer->definition() - ->as() - ->blockScales())) { + ((consumer->definition()->isA() && + consumer == + consumer->definition() + ->as() + ->blockScales()) || + (consumer->definition()->isA() && + consumer == + consumer->definition() + ->as() + ->blockScales()))) { continue; } diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 34655598e25..2c0938f79ec 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -125,6 +125,17 @@ void ConcretizedBroadcastDomains::handle(BlockQuantizationOp* bq) { } } +// GroupedBlockQuantizationOp introduces broadcast domains in the block scales +// output +void ConcretizedBroadcastDomains::handle(GroupedBlockQuantizationOp* bq) { + auto out = bq->blockScales()->as(); + auto bcast_id = out->getLogicalDomain().back(); + if (bcast_id->isBroadcast()) { + broadcast_origin_map_.emplace( + bcast_id, std::unordered_set({bcast_id})); + } +} + void ConcretizedBroadcastDomains::dispatch(Expr* expr) { IterVisitor::dispatch(expr); diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index b002d73e976..d17a0d0bd94 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -53,6 +53,8 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor { void handle(BlockQuantizationOp* bq) final; + void handle(GroupedBlockQuantizationOp* bq) final; + void dispatch(Expr* expr) final; void markAsConcretized( diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index a2d2877402e..e70f092fae8 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -437,6 +437,60 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) { GpuLower::current()->propagateExprInfo(bqop, back()); } +void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) { + const auto in = IrBuilder::create( + grouped_bqop->in()->as(), grouped_bqop->fusion()->zeroVal()); + + const auto out_scales = IrBuilder::create( + grouped_bqop->blockScales()->as(), + grouped_bqop->fusion()->zeroVal()); + const auto out_quantized = IrBuilder::create( + grouped_bqop->quantizedOutput()->as(), + grouped_bqop->fusion()->zeroVal()); + + std::vector logical_index = Index::getConsumerPerDimLogicalIndex( + grouped_bqop->quantizedOutput()->as(), for_loops_); + NVF_ERROR( + logical_index.size() == 2, + "only matrices are supported in GroupedBlockQuantizationOp"); + + // As part of runtime validation + // make sure that the inner dimension of the input is divisible by block size. 
+ auto* inner_id = + grouped_bqop->in()->as()->getLogicalDomain().back(); + Val* is_divisible = SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr( + inner_id->extent(), + IrBuilder::create(grouped_bqop->blockSize(), DataType::Index)), + grouped_bqop->fusion()->zeroVal()); + + NVFUSER_LOWER_VALIDATE( + is_divisible, + "Inner dimension of GroupedBlockQuantizationOp input must be divisible " + "by block " + "size (", + grouped_bqop->blockSize(), + "), but got extent ", + inner_id->extent()->toInlineString(), + " in ", + grouped_bqop->toString()); + + pushBack(IrBuilder::create( + out_scales, + out_quantized, + in, + grouped_bqop->inputOffsets(), + grouped_bqop->outputOffsets(), + grouped_bqop->layout(), + grouped_bqop->k(), + grouped_bqop->g(), + grouped_bqop->globalScale(), + grouped_bqop->blockSize(), + logical_index[0], + logical_index[1])); + GpuLower::current()->propagateExprInfo(grouped_bqop, back()); +} + void IndexLowering::handle(const SelectOp* sop) { auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0)); auto lowered_index_cast = lowered_index; diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index b47a9f9b36a..be4a33d25aa 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -58,6 +58,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const ArgsortOp*) final; void handle(const TopKOp*) final; void handle(const BlockQuantizationOp*) final; + void handle(const GroupedBlockQuantizationOp*) final; void handle(const RNGOp*) final; void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index efb5933aee7..d305348384a 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -152,6 +152,7 @@ bool isTvOp(const Expr* expr) { ScanOp, PreprocessGroupedMatmulInputSf, BlockQuantizationOp, + GroupedBlockQuantizationOp, LaunchDependentGridOp, WaitForPriorGridOp, kir::AllocTMem, diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index f786f6847ac..8be4cdd0b76 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -756,6 +756,8 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); + // [ NOTE: check scheduling requirements for block quantization ] + // // M K // │ │ // ▼ ▼ @@ -862,6 +864,198 @@ class ExprValidator : public OptOutDispatch { "contiguous IDs from the logical domain for BlockQuantizationOp: ", quantized_output->toString()); } + + void handle(GroupedBlockQuantizationOp* bqop) final { + auto inp_tv = bqop->input(0)->as(); + auto quantized_output = bqop->quantizedOutput()->as(); + auto block_scaling_factor = bqop->blockScales()->as(); + auto output_dtype = quantized_output->dtype(); + + NVF_ERROR_EQ( + inp_tv->getMemoryType(), + MemoryType::Local, + "Input must be a local memory tensor. Found: ", + inp_tv->getMemoryType()); + + NVF_ERROR_EQ( + quantized_output->getMemoryType(), + MemoryType::Local, + "Quantized output must be a local memory tensor. Found: ", + quantized_output->getMemoryType()); + + NVF_ERROR_EQ( + block_scaling_factor->getMemoryType(), + MemoryType::Global, + "Block scaling factor must be a global memory tensor. 
Found: ", + block_scaling_factor->getMemoryType()); + + NVF_ERROR( + output_dtype != DataType::Float8_e4m3fn, + "output of Float8_e4m3fn is not yet implemented"); + + if (bqop->hasGlobalScale()) { + auto global_scale = bqop->globalScale()->as(); + + NVF_ERROR_EQ( + global_scale->getMemoryType(), + MemoryType::Global, + "Global scaling factor must be a global memory tensor. Found: ", + global_scale->getMemoryType()); + + NVF_ERROR_EQ( + global_scale->dtype(), + DataType::Float, + "Global scaling factor must be of type float. Found: ", + global_scale->dtype()); + } + + // Outputs have the same allocation domain + // as the logical domain - no allocation domain. + NVF_ERROR( + !quantized_output->hasAllocation(), + "Quantized output must not have an allocation domain."); + + IterDomain* grouped_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + grouped_id = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDx) { + thread_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDx) { + block_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDz) { + thread_z = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDz) { + block_z = loop_id; + } else if ( + loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial/Unswitch/Unroll ID in " + "GroupedBlockQuantizationOp"); + NVF_ERROR_EQ( + loop_id->extent()->evaluate().as(), + 1, + "Expected non-TID/BID/Group ID to have extent of 1 for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + } + } + + NVF_ERROR( + grouped_id != nullptr, + "One of the output IDs must be grouped for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both TIDx and BIDx when using " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + auto inner_extent = grouped_id->extent()->evaluate().as(); + auto input_dtype = inp_tv->dtype(); + + NVF_ERROR( + ((inner_extent == 4 || inner_extent == 2) && + input_dtype == DataType::Float) || + ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The group dimension must be 2/4 (FP32) or 2/4/8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + bqop->toString()); + + // see [ NOTE: check scheduling requirements for block quantization ] + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()}, + {quantized_output->getLoopDomain().begin(), + quantized_output->getLoopDomain().end()}); + + std::vector ids_to_transform = + quantized_output->getLogicalDomain(); + + std::deque frontier( + quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()); + + // This will get the xforms from logical to loop and apply them on the + // logical domain. 
We will get a loop domain minus the reordering. + // This pass also removes all IDs from frontier that were derived using + // non-contiguous merges. + scheduler_utils::applyTransforms( + ids_to_transform, transform_exprs, [&frontier](Expr* expr) { + traverseFrontierWithContiguityCheck(frontier, expr); + }); + + // The grouped ID must correspond to the innermost loop-like domain + NVF_ERROR( + ids_to_transform.back() == grouped_id, + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for GroupedBlockQuantizationOp. " + "TV: ", + quantized_output->toString()); + + // Iterate from the back to find TIDx, skipping group_id (last element) + // Ensure all IDs between group_id and TIDx have extent 1 + bool found_tidx = false; + for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); + ++it) { + if (*it == thread_x) { + found_tidx = true; + break; + } + // All non-TIDx IDs between Group ID and TIDx must have extent of 1 + NVF_ERROR( + (*it)->extent()->isConstInt() && + (*it)->extent()->evaluate().as() == 1, + "Expected IDs between Group ID and TIDx to have extent of 1 for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + } + + NVF_ERROR( + found_tidx, + "TIDx must follow the Group ID in the schedule for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + + // Check if grouped_id in frontier + auto grouped_it = std::ranges::find(frontier, grouped_id); + NVF_ERROR( + grouped_it != frontier.end(), + "All merge operations deriving the grouped ID must combine " + "contiguous IDs from the logical domain for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + // Do the same for thread_x + auto threadx_it = + std::ranges::find(frontier.begin(), frontier.end(), thread_x); + NVF_ERROR( + threadx_it != frontier.end(), + "All merge operations deriving the TIDx ID must combine " + "contiguous IDs from the logical domain for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + } }; } // namespace @@ -1869,7 +2063,8 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || - def->isA(), + def->isA() || + def->isA(), "Invalid use of ParallelType::Group: ", def->toString()); diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 822ababb149..f46423b498a 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -112,6 +112,7 @@ class Val; f(CutlassNvfp4GroupedMmaOp); \ f(PreprocessGroupedMatmulInputSf); \ f(BlockQuantizationOp); \ + f(GroupedBlockQuantizationOp); \ f(TopKOp); \ f(ScanOp); \ f(Merge); \ diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index d0dc64cbe2a..f472d41f061 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1878,7 +1878,12 @@ std::pair> SegmentedFusion::makeFusion( for (auto inp : getAllInputs(sg)) { auto clone_tv = complete_to_segment_map.clone(inp); fusion_segment->addInput(clone_tv); - if (inp->isDefinitionType()) { + if (inp->isDefinitionType() || + (inp->isDefinitionType() && + (inp == + inp->definition() + ->as() + ->blockScales()))) { // NOTE: inp is an input to fusion segment. 
       //
       // There's no point of replaying allocation domain if we cannot index into
diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp
index 44cd766dcd1..0410b0d6148 100644
--- a/csrc/ir/composite_nodes.cpp
+++ b/csrc/ir/composite_nodes.cpp
@@ -1829,4 +1829,62 @@ std::vector<PolymorphicValue> BlockQuantizationOp::evaluate(
 
 NVFUSER_DEFINE_CLONE_AND_CREATE(BlockQuantizationOp)
 
+GroupedBlockQuantizationOp::GroupedBlockQuantizationOp(
+    IrBuilderPasskey passkey,
+    Val* output_scales,
+    Val* output,
+    Val* input,
+    Val* input_offsets,
+    Val* output_offsets,
+    BlockScalingFactorLayout layout,
+    Val* k,
+    Val* g,
+    Val* global_scale,
+    int64_t block_size,
+    Val* row_idx,
+    Val* col_idx)
+    : Expr(passkey) {
+  addOutput(output);
+  addOutput(output_scales);
+  addInput(input);
+  addInput(input_offsets);
+  addInput(output_offsets);
+  addInput(k);
+  addInput(g);
+  if (global_scale) {
+    addInput(global_scale);
+  }
+  addDataAttribute(block_size);
+  addDataAttribute(layout);
+  if (row_idx != nullptr) {
+    addAttribute(row_idx);
+  }
+  if (col_idx != nullptr) {
+    addAttribute(col_idx);
+  }
+}
+
+std::string GroupedBlockQuantizationOp::toString(int indent_size) const {
+  std::stringstream ss;
+  indent(ss, indent_size) << "(" << blockScales()->toString() << ",\n "
+                          << quantizedOutput()->toString() << ")\n"
+                          << " = grouped_block_quantize(" << in()->toString()
+                          << ",\n " << inputOffsets()->toString() << ",\n "
+                          << outputOffsets()->toString() << ")\n";
+  return ss.str();
+}
+
+std::string GroupedBlockQuantizationOp::toInlineString(int indent_size) const {
+  NVF_CHECK(false, "GroupedBlockQuantizationOp can not be printed inline");
+}
+
+std::vector<PolymorphicValue> GroupedBlockQuantizationOp::evaluate(
+    const ExpressionEvaluator& ee,
+    const std::vector<PolymorphicValue>& inputs) const {
+  // This is a placeholder; currently we don't have a fallback kernel available
+  NVF_THROW("GroupedBlockQuantizationOp evaluation not yet implemented");
+}
+
+NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedBlockQuantizationOp)
+
 } // namespace nvfuser
diff --git a/csrc/ir/composite_nodes.h b/csrc/ir/composite_nodes.h
index 5a40ac4b9b3..629cbe55f3b 100644
--- a/csrc/ir/composite_nodes.h
+++ b/csrc/ir/composite_nodes.h
@@ -1038,4 +1038,96 @@ class BlockQuantizationOp : public Expr {
       const std::vector<PolymorphicValue>& inputs) const override;
 };
 
+class GroupedBlockQuantizationOp : public Expr {
+ public:
+  using Expr::Expr;
+
+  // This op takes in a high precision input (input)
+  // and returns the quantized output (output) along with the block scaling
+  // factors (output_scales). It can also take as an optional input the global
+  // scaling factor and block size (though we currently only support 16).
+  // logical_index is used for the internal implementation. This op is
+  // currently implemented via a runtime function. During index computation,
+  // we compute the index of the output_scales and pass it to the runtime
+  // function.
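+  //
+  // Illustrative shape sketch only (see groupedBlockQuantize in
+  // csrc/ops/arith.cpp and the layout tests for the authoritative contract),
+  // assuming a 2D input and the nvfp4 path with block_size 16:
+  //
+  //   input           : [M, K]       Float / Half / BFloat16
+  //   input_offsets   : [G]          per-group starting row in the input
+  //   output_offsets  : [G]          per-group starting row in the padded,
+  //                                  laid-out block-scale output
+  //   quantizedOutput : [M, K]       Float4_e2m1fn
+  //   blockScales     : [M, K / 16]  Float8_e4m3fn, with a swizzled/padded
+  //                                  allocation domain per the chosen
+  //                                  BlockScalingFactorLayout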
+ GroupedBlockQuantizationOp( + IrBuilderPasskey, + Val* output_scales, + Val* output, + Val* input, + Val* input_offsets, + Val* output_offsets, + BlockScalingFactorLayout layout, + Val* k, + Val* g, + Val* global_scale = nullptr, + int64_t block_size = 16, + Val* row_idx = nullptr, + Val* col_idx = nullptr); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + Val* blockScales() const { + return output(1); + } + + Val* quantizedOutput() const { + return output(0); + } + + Val* in() const { + return input(0); + } + + int64_t blockSize() const { + return attribute(0); + } + + bool hasGlobalScale() const { + if (inputs().size() > 5) { + return true; + } + return false; + } + + Val* globalScale() const { + if (hasGlobalScale()) { + return input(5); + } + return nullptr; + } + + const char* getOpString() const override { + return "GroupedBlockQuantizationOp"; + } + + TensorView* inputOffsets() const { + return input(1)->as(); + } + + TensorView* outputOffsets() const { + return input(2)->as(); + } + + // get scalar - column size + Val* k() const { + return input(3); + } + + // get scalar - number of groups + Val* g() const { + return input(4); + } + + // get enum - block scaling factor layout + BlockScalingFactorLayout layout() const { + return attribute(1); + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; +}; + } // namespace nvfuser diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9418b137688..d24fdd1c879 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1337,7 +1337,11 @@ bool hasTrivialAllocationDomain(const TensorView* tv) { alloc | TensorDomain::kNoReductions | TensorDomain::kNoBroadcasts); } bool hasUniformSiblings(Expr* expr) { - return !expr->isOneOf(); + return !expr->isOneOf< + SdpaFwdOp, + SdpaBwdOp, + BlockQuantizationOp, + GroupedBlockQuantizationOp>(); } bool mayRequireAllocation(const TensorView* tv, IterDomain* id) { diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index 55552d2638d..76cea7675bd 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -324,6 +324,10 @@ class KernelIrScanner : private IrVisitor { summary_.has_block_quantize_op = true; } + void handle(GroupedBlockQuantizationOp* bqop) final { + summary_.has_block_quantize_op = true; + } + void handle(ScanOp* scan) final { summary_.has_scan = true; } diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 26f6200a757..56ea5b324ef 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -141,14 +141,33 @@ std::pair, bool> getNonMappingDomainInfo( // as it's extent is reduced by a factor of the block size // for example [i0, i1] => [i0, i1/16] where 16 is the block size. // Make sure the producer isn't the global scale. - if (consumer_tv == - consumer_tv->definition() - ->as() - ->blockScales() && - producer_tv != - consumer_tv->definition() - ->as() - ->globalScale()) { + Val* block_scales = + consumer_tv->definition()->as()->blockScales(); + Val* global_scale = + consumer_tv->definition()->as()->globalScale(); + + if (consumer_tv == block_scales && producer_tv != global_scale) { + auto producer_logical = + TensorDomain::noReductions(producer_tv->getLogicalDomain()); + auto last_logical_dim = producer_logical.size() - 1; + non_mapping_ids.insert(producer_logical.at(last_logical_dim)); + // We are mapping everything but the last ID. 
+ has_consumer_id = true; + } + } else if ( + auto grouped_bqop = dynamic_cast( + consumer_tv->definition())) { + if (producer_tv != grouped_bqop->in()) { + auto producer_logical = + TensorDomain::noReductions(producer_tv->getLogicalDomain()); + non_mapping_ids.insert(producer_logical.begin(), producer_logical.end()); + // we are not mapping anything, `has_consumer_id` doesn't matter. + has_consumer_id = false; + } else if (consumer_tv == grouped_bqop->blockScales()) { + // We don't map the inner-most dimension of the block scaling factors + // as it's extent is reduced by a factor of the block size + // for example [i0, i1] => [i0, i1/16] where 16 is the block size. + // Make sure the producer isn't the global scale. auto producer_logical = TensorDomain::noReductions(producer_tv->getLogicalDomain()); auto last_logical_dim = producer_logical.size() - 1; @@ -1387,7 +1406,8 @@ void ComputeAtLogicalDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { NVF_ERROR( expr->isA() || expr->isA() || expr->isA() || expr->isA() || - expr->isA(), + expr->isA() || + expr->isA(), "Unknown multi-output Expr type ", expr->getOpString(), " is found"); diff --git a/csrc/logical_domain_map.h b/csrc/logical_domain_map.h index 3d76758ddb6..133d526fbf6 100644 --- a/csrc/logical_domain_map.h +++ b/csrc/logical_domain_map.h @@ -550,6 +550,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(op); } + void handle(GroupedBlockQuantizationOp* op) override { + mapPointwiseLikeOp(op); + } + void handle(TensorView* tv) override; //! Maps all pending mappings. diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 64f68cbf1a4..9baf7536e82 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -2753,4 +2754,144 @@ BlockQuantizationResults blockQuantize( return BlockQuantizationResults(quantized_tensor, block_scales); } +BlockQuantizationResults groupedBlockQuantize( + TensorView* input, + TensorView* input_offsets, + TensorView* output_offsets, + BlockScalingFactorLayout layout, + TensorView* global_scaling_factor, + int64_t block_size, + DataType out_dtype) { + NVF_CHECK( + out_dtype == DataType::Float4_e2m1fn || + out_dtype == DataType::Float8_e4m3fn, + "Currently only output data type of Float4_e2m1fn or Float8_e4m3fn is " + "supported"); + if (out_dtype == DataType::Float4_e2m1fn) { + NVF_ERROR_EQ( + block_size, + 16, + "Block size must be 16 for Float4_e2m1fn, got ", + block_size); + } else if (out_dtype == DataType::Float8_e4m3fn) { + NVF_ERROR_EQ( + block_size, + 32, + "Block size must be 32 for Float8_e4m3fn, got ", + block_size); + NVF_CHECK( + !global_scaling_factor, + "global_scaling_factor must be nullptr for Float8_e4m3fn"); + } + + // Validate input data type + // We'll only support FP32 or BF16/FP16 + NVF_CHECK( + input->getDataType().value() == DataType::Float || + input->getDataType().value() == DataType::BFloat16 || + input->getDataType().value() == DataType::Half, + "Grouped block quantization expects floating point input but got ", + input->getDataType().value()); + + // Check that if global_scaling_factor in non-null + // then it is a scalar float TensorView + if (global_scaling_factor != nullptr) { + NVF_CHECK( + TensorDomain::noReductions(global_scaling_factor->getLogicalDomain()) + .empty(), + "Global scaling factor for grouped block quantization must be a scalar " + "tensor"); + NVF_CHECK( + global_scaling_factor->getDataType().value() == DataType::Float, + "Global scaling 
factor for grouped block quantization must be of float " + "data " + "type"); + } + + auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain()); + + // Validate input tensor is 2d + NVF_ERROR_EQ( + inp_domain.size(), + 2, + "Grouped block quantization only supports 2-dimensional tensors"); + + // Create output domain for quantized tensor (same shape as input) + std::vector quantized_out_domain; + quantized_out_domain.reserve(inp_domain.size()); + + for (auto inp_domain_ptr : inp_domain) { + quantized_out_domain.push_back(inp_domain_ptr->cloneWithoutRFactor()); + } + + // Create output tensors + TensorView* quantized_tensor = IrBuilder::create( + IrBuilder::create( + quantized_out_domain, + TensorDomain::getContiguityFilledWith(quantized_out_domain, true)), + out_dtype); + + // Create output blocked scaling factor + auto block_scales_dtype = (out_dtype == DataType::Float4_e2m1fn) + ? DataType::Float8_e4m3fn + : DataType::Float8_e8m0fnu; + + // This is used for both root and loop domain on output + // maps directly to input's logical domain. + std::vector scales_out_domain; + scales_out_domain.reserve(inp_domain.size()); + + for (auto inp_id : inp_domain) { + if (inp_id == inp_domain.back()) { + scales_out_domain.push_back( + IterDomainBuilder( + inp_id->start(), + SimplifyingIrBuilder::divExpr( + inp_id->extent(), + IrBuilder::create(block_size, DataType::Index))) + .build()); + + } else { + scales_out_domain.push_back(inp_id->cloneWithoutRFactor()); + } + } + + std::vector offset_logical_dom = + TensorDomain::noReductions(input_offsets->getLogicalDomain()); + Val* num_groups = offset_logical_dom[0]->extent(); + + // Create the allocation domain of output. + std::vector out_alloc_dom = + layoutAllocationDomain(scales_out_domain, num_groups, layout); + + // Create block scaling factors + TensorView* block_scales = IrBuilder::create( + IrBuilder::create( + /*root_domain=*/std::vector(), + /*logical_domain=*/scales_out_domain, + /*allocation=*/out_alloc_dom, + /*loop_domain=*/scales_out_domain, + /*alternate_loop_domain=*/std::nullopt, + /*contiguity=*/ + TensorDomain::getContiguityFilledWith(out_alloc_dom, true), + /*additional_ids=*/std::vector(), + /*skip_checks=*/true), + block_scales_dtype); + + // Create the grouped block quantization operation + IrBuilder::create( + block_scales, + quantized_tensor, + input, + input_offsets, + output_offsets, + layout, + inp_domain[1]->getMaybeExpandedExtent(), + num_groups, + global_scaling_factor, + block_size); + + return BlockQuantizationResults(quantized_tensor, block_scales); +} + } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index e49a1a416f1..f4595d36561 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -855,4 +855,13 @@ NVF_API BlockQuantizationResults blockQuantize( bool swizzle_scales = false, DataType out_dtype = DataType::Float4_e2m1fn); +NVF_API BlockQuantizationResults groupedBlockQuantize( + TensorView* input, + TensorView* input_offsets, + TensorView* output_offsets, + BlockScalingFactorLayout layout, + TensorView* global_scaling_factor = nullptr, + int64_t block_size = 16, + DataType out_dtype = DataType::Float4_e2m1fn); + } // namespace nvfuser diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index a86e2f08495..11500f55381 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -300,7 +300,9 @@ bool PointWiseScheduler::canScheduleRunTime( data_cache, [fusion]() { return std::make_unique( - !ir_utils::getOpsOfType(fusion).empty()); + 
!ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); }) .get(); @@ -417,6 +419,26 @@ std::unique_ptr PointWiseScheduler::computeHeuristics( fusion, runtime_info, data_cache, prop); } NVF_ERROR(pparams != nullptr); + + // cap vectorization when block quantization op is encountered, since there's + // a validation during device_lower + auto has_block_quantization_ops = + HeuristicDataCacheEntry( + data_cache, + [fusion]() { + return std::make_unique( + !ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); + }) + .get(); + if (has_block_quantization_ops) { + // FIXME: this needs to be done per input dtype. I'm capping it as 4 for + // simplicity for now. + pparams->as()->vectorization_factor = std::min( + 4, pparams->as()->vectorization_factor); + } + return pparams; } diff --git a/csrc/scheduler/pointwise_non_tma.cpp b/csrc/scheduler/pointwise_non_tma.cpp index 37d30b8049c..bdf66b8b14c 100644 --- a/csrc/scheduler/pointwise_non_tma.cpp +++ b/csrc/scheduler/pointwise_non_tma.cpp @@ -148,7 +148,9 @@ int64_t getUnrollFactor( data_cache, [fusion]() { return std::make_unique( - !ir_utils::getOpsOfType(fusion).empty()); + !ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); }) .get(); @@ -603,11 +605,16 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // We do so as the runtime function for block quantization expects 2/4/8 // elements per thread. auto bq_ops = ir_utils::getOpsOfType(fusion); + auto gbq_ops = ir_utils::getOpsOfType(fusion); std::vector nvfp4_quantized_outputs = {}; for (auto bq_op : bq_ops) { nvfp4_quantized_outputs.push_back( bq_op->quantizedOutput()->as()); } + for (auto gbq_op : gbq_ops) { + nvfp4_quantized_outputs.push_back( + gbq_op->quantizedOutput()->as()); + } if (pparams->vectorization_factor > 1) { // Grab all tensor views that should be vectorized diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 490ffd6d036..8b1c545f2f6 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -851,6 +851,13 @@ bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { if (!block_scales->isFusionOutput()) { return true; } + } else if (expr->isA()) { + auto block_scales = expr->as() + ->blockScales() + ->as(); + if (!block_scales->isFusionOutput()) { + return true; + } } } return false; @@ -1083,6 +1090,19 @@ bool SchedulerTopologyChecker::rejectScheduleFusionGlobalBufferRequirement( layout_op, layout_op->outputOffsets(), scheduler_type)) { return true; } + } else if (expr->isA()) { + // The runtime function of GroupedBlockQuantizationOp needs: + // 1. Write scale output directly to global memory + // 2. 
Read two offset inputs directly from global memory + auto grouped_bop = expr->as(); + if (rejectScheduleFusionOutputRequirement( + grouped_bop, grouped_bop->blockScales(), scheduler_type) || + rejectScheduleFusionInputRequirement( + grouped_bop, grouped_bop->inputOffsets(), scheduler_type) || + rejectScheduleFusionInputRequirement( + grouped_bop, grouped_bop->outputOffsets(), scheduler_type)) { + return true; + } } } return false; diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 55abe2d0914..8d9cc77c57e 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -66,6 +66,13 @@ bool canIgnoreIndexedInputDomainID( input_tv == layout->outputOffsets()) { continue; } + } else if (auto layout = dynamic_cast(use)) { + // since we don't index into offsets, scheduler doesn't need to cover + // offset TVs ID. + if (input_tv == layout->inputOffsets() || + input_tv == layout->outputOffsets()) { + continue; + } } else { // If the input TV is used by any other ops return false; @@ -420,6 +427,12 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { output_tv == output_tv->definition() ->as() + ->blockScales()) || + (output_tv->definition() && + output_tv->definition()->isA() && + output_tv == + output_tv->definition() + ->as() ->blockScales())) { continue; } diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 17ac3c44328..f32a0f8eb93 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1354,17 +1354,18 @@ std::vector> cacheInputs( // TODO: we might need to explicitly promote offsets to global memory // We expect offsets to remain in global memory, so we do not add it to // cache - auto isPreprocessGroupedMatmulInputSfOffsets = [tv](Expr* use) { - if (!use->isA()) { - return false; + auto isGroupOffsets = [tv](Expr* use) { + if (auto op = dynamic_cast(use)) { + return tv == op->inputOffsets() || tv == op->outputOffsets(); + } else if (auto op = dynamic_cast(use)) { + return tv == op->inputOffsets() || tv == op->outputOffsets(); } - auto layout = use->as(); - return tv == layout->inputOffsets() || tv == layout->outputOffsets(); + return false; }; std::vector cached_uses; for (auto use : tv->uses()) { if (!use->isOneOf() && !isGatherLookUpTvInUse(use) && - !isPreprocessGroupedMatmulInputSfOffsets(use)) { + !isGroupOffsets(use)) { cached_uses.push_back(use); } } @@ -1408,7 +1409,11 @@ std::vector> cacheAndForkOutputs( ->isOneOf() || (output->definition()->isA() && output->definition()->as()->blockScales() == - output)) { + output) || + (output->definition()->isA() && + output->definition() + ->as() + ->blockScales() == output)) { continue; } if (!output->uses().empty()) { diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 89d83a97eea..7ab282644c9 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -356,6 +356,12 @@ inferAndValidateAllocationSizesAndStrides( if (bqop->isSwizzledScales() && tv == bqop->blockScales()) { skip_validation = true; } + } else if ( + tv->definition() && tv->definition()->isA()) { + auto bqop = tv->definition()->as(); + if (tv == bqop->blockScales()) { + skip_validation = true; + } } // Skip validation for scale input to ScaledMmaOp as it will be swizzled. 
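The test below passes hand-written group offsets. As a rough illustration of where those numbers come from (an assumption based on the test's comments, not an API added by this patch), input offsets are running sums of each group's row count, while output offsets are running sums of the counts rounded up to the 128-row tile implied by Block128x4:

// Hypothetical helper, for illustration only: derives the two offset tensors
// used by the GroupedBlockQuantizeOp test from per-group token counts,
// assuming each group's scale rows are padded up to a 128-row tile.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<int32_t>, std::vector<int32_t>> makeGroupOffsets(
    const std::vector<int64_t>& tokens_per_group,
    int64_t row_tile = 128) {
  std::vector<int32_t> input_offsets;
  std::vector<int32_t> output_offsets;
  int64_t in_acc = 0;
  int64_t out_acc = 0;
  for (int64_t n : tokens_per_group) {
    input_offsets.push_back(static_cast<int32_t>(in_acc));
    output_offsets.push_back(static_cast<int32_t>(out_acc));
    in_acc += n;
    // Round each group up to a full row tile before the next group starts.
    out_acc += ((n + row_tile - 1) / row_tile) * row_tile;
  }
  return {input_offsets, output_offsets};
}

// For tokens_per_group = {100, 150, 262} this yields {0, 100, 250} and
// {0, 128, 384}, with the padded output spanning 128 + 256 + 384 = 768 rows,
// matching t1/t2 in the test below.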
diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp index 1a2fa739f27..1535f45e428 100644 --- a/tests/cpp/test_layout_op.cpp +++ b/tests/cpp/test_layout_op.cpp @@ -62,7 +62,8 @@ bool validateGroupedLayout( .transpose(1, 3) .reshape({mn_tile * 4 * 32, k_tile * 4}) .slice(0, 0, m_g) - .slice(1, 0, k); + .slice(1, 0, k) + .to(ref.dtype()); auto ref_g = ref.slice( 0, expert_offsets[i].item().to(), @@ -376,4 +377,71 @@ TEST_F(LayoutOpTest, Inlining) { EXPECT_EQ(inp_cache->getComputeAtPosition(), 2); } +TEST_F(LayoutOpTest, GroupedBlockQuantizeOp) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto inp = makeSymbolicTensor(2); + auto offsets = makeSymbolicTensor(1, DataType::Int32); + auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32); + fusion.addInput(inp); + fusion.addInput(offsets); + fusion.addInput(rounded_offsets); + + auto outs = groupedBlockQuantize( + inp, offsets, rounded_offsets, BlockScalingFactorLayout::Block128x4); + fusion.addOutput(castOp(DataType::Float, outs.quantized_tensor)); + fusion.addOutput(outs.block_scales); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int m = 512; + int k = 9 * 16; // note: padded column size needs to be a multiple of 16 + auto t0 = at::randn({m, k}, options); + + // tokens per group are [100, 150, 262] respectively, so each group would be + // padded to multiple of 128. Hence the total output row span would cover a + // length of 128 + 256 + 384 = 768. + auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt)); + auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt)); + + // automatic scheduling. + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + at::Tensor ref_block_sf; + at::Tensor ref_scaled_out; + // producing reference + { + std::unique_ptr fusion_new_op = std::make_unique(); + FusionGuard fg2(fusion_new_op.get()); + auto tv_in = makeContigTensor(2); + fusion_new_op->addInput(tv_in); + auto quantization_results = + blockQuantize(tv_in, nullptr, /*block_size=*/16, false); + + fusion_new_op->addOutput(quantization_results.block_scales); + fusion_new_op->addOutput( + castOp(DataType::Float, quantization_results.quantized_tensor)); + FusionExecutorCache executor_cache(std::move(fusion_new_op)); + auto outputs_new_op = executor_cache.runFusionWithInputs({t0}); + ref_block_sf = outputs_new_op[0].as().to(at::kFloat); + ref_scaled_out = outputs_new_op[1].as(); + } + + // check scaled output + EXPECT_TRUE(at::allclose(ref_scaled_out, outputs[0].as())); + // check block scaling factor + ASSERT_TRUE(validateGroupedLayout( + BlockScalingFactorLayout::Block128x4, + outputs[1].as(), + ref_block_sf, + t1, + t2)); + + EXPECT_THAT( + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(), + UnorderedElementsAre(HeuristicIs(SchedulerType::PointWise))); +} + } // namespace nvfuser
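For reference, the per-thread group-size constraint that both the codegen path and the lowering validator above enforce can be restated standalone as follows (a sketch mirroring is_valid_group_size, not code from this patch):

#include <cstdint>

// Half-precision inputs (Half/BFloat16) may group 2, 4, or 8 elements per
// thread; Float inputs may only group 2 or 4.
bool isValidGroupSize(int64_t group_size, bool is_half_precision_input) {
  return is_half_precision_input
      ? (group_size == 2 || group_size == 4 || group_size == 8)
      : (group_size == 2 || group_size == 4);
}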