Closed

141 commits
8e3842f
creating a new node
protonu Sep 11, 2025
9c06080
removing commented out code
protonu Sep 29, 2025
fd79ac1
codegen the indices for the outputs
protonu Sep 30, 2025
60af605
adding a new test for 2D sched
protonu Sep 30, 2025
73295a1
write quantized output to regs
protonu Oct 1, 2025
f97264a
clean up
protonu Oct 1, 2025
a587641
clean up
protonu Oct 1, 2025
a44a6fc
clean up trivial broadcast
protonu Oct 1, 2025
727b663
clean up the tests
protonu Oct 1, 2025
e8b0ade
minor cleanup
protonu Oct 1, 2025
55bb2c3
address reviewer comments
protonu Oct 7, 2025
69c315e
reviewer comment
protonu Oct 7, 2025
ab5e60e
move comments around
protonu Oct 8, 2025
8eedb76
clean up
protonu Oct 8, 2025
a7b6d58
edit comments
protonu Oct 8, 2025
75699a1
remove setting parallel type for BIDx TIDx
protonu Oct 8, 2025
c7f1d8d
remove setting parallel type for BIDx TIDx - cleanup
protonu Oct 8, 2025
44740dc
adding support for parallel type group
protonu Oct 8, 2025
9130435
address reviewer comments
protonu Oct 10, 2025
434a1b1
adding a comment
protonu Oct 10, 2025
ce0b820
removing a comment
protonu Oct 10, 2025
fb4ad21
modifying a check
protonu Oct 10, 2025
2ce601f
merge
protonu Oct 10, 2025
beb2f06
merge
protonu Oct 10, 2025
9bd7936
Merge branch 'main' into pbasu_fp4_node
protonu Oct 10, 2025
b04d974
add validation, update test
protonu Oct 17, 2025
7c79c32
clean up from merge
protonu Oct 17, 2025
8564431
remove a utility function
protonu Oct 17, 2025
774e27d
support half and bfloat in tests
protonu Oct 17, 2025
d43096e
removing vectorize
protonu Oct 17, 2025
3725a51
Merge branch 'main' into pbasu_fp4_node
protonu Oct 17, 2025
44d192c
updating comment for validation fn
protonu Oct 17, 2025
4634fc7
Merge branch 'main' into pbasu_fp4_node
protonu Oct 21, 2025
11fff3d
working index
protonu Oct 22, 2025
c729ada
remove header
protonu Oct 23, 2025
5e55689
Merge branch 'main' into pbasu_fp4_node
protonu Oct 27, 2025
c0cd7f9
Merge branch 'main' into pbasu_fp4_node
protonu Oct 28, 2025
0cadefc
Merge branch 'main' into pbasu_fp4_node
protonu Oct 30, 2025
535075d
Update csrc/codegen.cpp
protonu Oct 31, 2025
115ddbf
Update runtime/block_quantization_kernels.cu
protonu Oct 31, 2025
6363032
Update runtime/block_quantization_kernels.cu
protonu Oct 31, 2025
65c8f25
Update runtime/block_quantization_kernels.cu
protonu Oct 31, 2025
1f9829a
Update csrc/ops/arith.cpp
protonu Oct 31, 2025
aaa591e
address reviewer comments
protonu Oct 31, 2025
d49f09d
Merge branch 'main' into pbasu_fp4_node
protonu Oct 31, 2025
64a921e
runtime validation for inner dim size
protonu Oct 31, 2025
7e3835c
address reviewer comments
protonu Oct 31, 2025
9dd3a7a
removing code for validation
protonu Oct 31, 2025
25fe4e2
Merge branch 'main' into pbasu_fp4_node
protonu Oct 31, 2025
a444ea7
update comments
protonu Oct 31, 2025
d0b5635
Merge branch 'main' into pbasu_fp4_node
protonu Oct 31, 2025
375feae
edit comments
protonu Oct 31, 2025
1104e6a
Merge branch 'main' into pbasu_fp4_node
protonu Oct 31, 2025
389cdb4
Merge branch 'main' into pbasu_fp4_node
protonu Nov 1, 2025
d24562a
adding validation checks and initial tests
protonu Nov 1, 2025
02bf8ce
merge
protonu Nov 2, 2025
9ac8261
merge
protonu Nov 2, 2025
a1612c8
adding comments to tests
protonu Nov 3, 2025
fb6bd53
more comments for validation
protonu Nov 3, 2025
fe787e0
Update tests/cpp/test_low_precision_recipe.cpp
protonu Nov 3, 2025
5b11a74
Apply suggestion from @greptile-apps[bot]
protonu Nov 3, 2025
4ea7bae
fix weird duplicated code that showed up
protonu Nov 3, 2025
bb21123
removing stale comment
protonu Nov 3, 2025
918bd94
changes to tests based on reviewer comments
protonu Nov 4, 2025
82990c4
Apply suggestion from @greptile-apps[bot]
protonu Nov 4, 2025
b87aa9a
remove vectorization test and address less involved reviewer comments
protonu Nov 4, 2025
8daae2f
adding new validation
protonu Nov 4, 2025
dfd7553
format
protonu Nov 4, 2025
d34936c
Merge branch 'main' into pbasu_fp4_validation
protonu Nov 4, 2025
00f5f94
address greptile comments
protonu Nov 4, 2025
0b4f880
Merge branch 'main' into pbasu_fp4_validation
protonu Nov 4, 2025
e8ee708
Merge branch 'pbasu_fp4_validation' of github.com:nvidia/fuser into p…
protonu Nov 4, 2025
6b61060
edit assert error handler
protonu Nov 5, 2025
490b6a7
cleanup
protonu Nov 5, 2025
25da8be
more clean up
protonu Nov 5, 2025
2059878
refactor
protonu Nov 5, 2025
a20f0ea
Merge branch 'main' into pbasu_fp4_validation
protonu Nov 5, 2025
b67f205
refactor using lambda and edit comments
protonu Nov 5, 2025
d1f2cbc
Update csrc/device_lower/validation.cpp
protonu Nov 5, 2025
9f582a8
Apply suggestion from @greptile-apps[bot]
protonu Nov 5, 2025
f30ab2e
edit comments
protonu Nov 5, 2025
b90b309
Merge branch 'main' into pbasu_fp4_validation
protonu Nov 5, 2025
521461c
Apply suggestions from code review
naoyam Nov 6, 2025
ce25215
Merge branch 'main' into pbasu_fp4_validation
naoyam Nov 6, 2025
4510a77
allows bq kernel to take 2,4,8 elem per thread
protonu Nov 6, 2025
9afefd1
update comments
protonu Nov 6, 2025
6a6b46d
extend pointwise scheduler to accept block quantization op
protonu Nov 6, 2025
3f124b6
merge
protonu Nov 6, 2025
5cb3e50
edit comments
protonu Nov 6, 2025
ac602ab
Merge branch 'pbasu_fp4_modified_runtime_fn' into pbasu_fp4_new_auto_…
protonu Nov 6, 2025
63cf7b9
address comments from greptile
protonu Nov 6, 2025
07eb934
merge
protonu Nov 8, 2025
f56e742
Merge branch 'main' into pbasu_fp4_new_auto_sched
protonu Nov 8, 2025
4ca4d8e
address reviewer comment - move code around
protonu Nov 11, 2025
031c280
Update csrc/scheduler/pointwise.cpp
protonu Nov 11, 2025
479691b
correct typo
protonu Nov 12, 2025
fb03f91
Merge branch 'main' into pbasu_fp4_new_auto_sched
protonu Nov 12, 2025
6a5a553
move checks to canScheduleRunTime
protonu Nov 13, 2025
76a9603
Merge branch 'main' into pbasu_fp4_new_auto_sched
protonu Nov 13, 2025
1514022
Apply suggestion from @greptile-apps[bot]
protonu Nov 13, 2025
19e43df
Merge branch 'main' into pbasu_fp4_new_auto_sched
protonu Nov 13, 2025
1e92c4d
cache check for BQ ops
protonu Nov 14, 2025
7d74583
move unroll factor computation bypass
protonu Nov 14, 2025
4a0d708
cleanup redundant code
protonu Nov 14, 2025
0321c32
merge
protonu Nov 14, 2025
6f31bd4
move data_cache access to getUnroll
protonu Nov 14, 2025
fd87633
Merge branch 'main' into pbasu_fp4_new_auto_sched
protonu Nov 17, 2025
03b374f
support global scale
protonu Nov 17, 2025
6029ecc
merge
protonu Nov 17, 2025
9d2cc3d
minor edits
protonu Nov 17, 2025
a6bb1fb
validation
protonu Nov 17, 2025
d4beb20
wip
protonu Nov 17, 2025
16593c5
merge
protonu Nov 18, 2025
8445011
modify tests and validation
protonu Nov 18, 2025
79c7dc2
format
protonu Nov 18, 2025
91e9c59
Merge branch 'main' into pbasu_working_swizzle
protonu Nov 18, 2025
989b974
Merge branch 'main' into pbasu_working_swizzle
protonu Nov 18, 2025
483069f
address reviewer comments
protonu Nov 18, 2025
11eb6c6
Merge branch 'pbasu_working_swizzle' of github.com:nvidia/fuser into …
protonu Nov 18, 2025
190802f
Apply suggestions from code review
protonu Nov 18, 2025
de431d3
add validation for swizzling
protonu Nov 18, 2025
18a51af
more tests for swizzle validation
protonu Nov 18, 2025
b664575
Merge branch 'main' into pbasu_working_swizzle
protonu Nov 18, 2025
646b136
clean up
protonu Nov 19, 2025
5651c24
Merge branch 'pbasu_working_swizzle' of github.com:nvidia/fuser into …
protonu Nov 19, 2025
d0c0e02
clean up
protonu Nov 19, 2025
06f7987
Merge branch 'main' into pbasu_working_swizzle
protonu Nov 19, 2025
741e530
handle clang-tidy error
protonu Nov 19, 2025
661e6ea
Merge branch 'pbasu_working_swizzle' of github.com:nvidia/fuser into …
protonu Nov 19, 2025
317445d
better comment
protonu Nov 19, 2025
c13125c
Merge branch 'main' into pbasu_working_swizzle
protonu Nov 19, 2025
e249df9
cleanup validation
protonu Nov 20, 2025
2ebb93f
cleanup
protonu Nov 20, 2025
896e45e
python API for block quantization
protonu Nov 20, 2025
3a41bf8
Merge branch 'main' into pbasu_bq_py_api
protonu Nov 21, 2025
553daa4
wip
protonu Nov 21, 2025
2962ac4
wip
protonu Nov 21, 2025
740eea9
wip
protonu Nov 21, 2025
a6e767d
adding test against TE
protonu Nov 24, 2025
2434f76
almost working tests with TE
protonu Nov 24, 2025
5e0b02f
scaled_mm test
protonu Nov 24, 2025
7 changes: 4 additions & 3 deletions csrc/tensor_metadata.cpp
@@ -344,9 +344,10 @@ inferAndValidateAllocationSizesAndStrides(
auto [allocation_sizes, allocation_strides] =
inferAllocationSizesAndStrides(tensor, tv, ee);
// Only validate final sizes and strides when we have a non-empty tensor.
- if (tensor.numel() != 0) {
-   validateAllocationSizesAndStrides(tv, allocation_sizes, allocation_strides);
- }
+ // if (tensor.numel() != 0) {
+ //   validateAllocationSizesAndStrides(tv, allocation_sizes,
+ //   allocation_strides);
+ // }
return {std::move(allocation_sizes), std::move(allocation_strides)};
}

73 changes: 73 additions & 0 deletions python/python_direct/ops.cpp
@@ -3486,6 +3486,78 @@ tuple[TensorView, TensorView, TensorView]
py::return_value_policy::reference);
}

namespace {

// Helper function to apply swizzle transformation to block scaling factors
// for FP4 quantization. This transforms the memory layout to optimize access
// patterns.
void swizzleBlockScale(TensorView* tv_block_scale_fp8) {
// auto original_loop = tv_block_scale_fp8->getLoopDomain();
tv_block_scale_fp8->split(0, 128);
// m/128, 128, k
tv_block_scale_fp8->split(1, 32);
// m/128, 4(m_o), 32(m_i), k
tv_block_scale_fp8->split(3, 4);
// m/128, 4(m_o), 32(m_i), k/4, 4(k)
std::vector<IterDomain*> tv_block_scale_fp8_alloc{
tv_block_scale_fp8->axis(0),
tv_block_scale_fp8->axis(3),
tv_block_scale_fp8->axis(2),
tv_block_scale_fp8->axis(1),
tv_block_scale_fp8->axis(4)};
// m/128, k/4, 32(m_i), 4(m_o), 4(k)
tv_block_scale_fp8->setAllocationDomain(tv_block_scale_fp8_alloc, true);

// back to a 2D logical domain.
// m/128, 4(m_o), 32(m_i), k/4, 4(k) ->
// m/32, 32, k/4, 4(k)
tv_block_scale_fp8->merge(0);
// m/32, 32, k/4, 4(k) -> m, k/4, 4(k)
tv_block_scale_fp8->merge(0);
// m, k/4, 4(k) -> m, k
tv_block_scale_fp8->merge(-2);
}

} // namespace

void bindQuantizationOps(py::module_& ops) {
ops.def(
"nv_block_quantize",
[](TensorView* input,
TensorView* global_scale,
bool swizzle_block_scales,
int64_t block_size,
PrimDataType dtype) -> py::tuple {
auto output = blockQuantize(input, global_scale, block_size, dtype);
if (swizzle_block_scales) {
swizzleBlockScale(output.block_scales);
}
return py::make_tuple(output.quantized_tensor, output.block_scales);
},
py::arg("input"),
py::arg("global_scale").none(true) = py::none(),
py::arg("swizzle_block_scales") = false,
py::arg("block_size") = 16,
py::arg("dtype") = DataType::Float4_e2m1fn,
R"(
Block quantize tensor to NVFP4 format.
Parameters
----------
input : TensorView
Input tensor to quantize. Must be a floating point tensor.
global_scale : TensorView, optional
Optional global scale folded into the block scales. Default is None.
swizzle_block_scales : bool, optional
If True, store the block scales in the swizzled allocation layout produced by swizzleBlockScale. Default is False.
block_size : int, optional
Block size for quantization. Default is 16.
dtype : DataType, optional
Quantized output data type. Default is Float4_e2m1fn.
Returns
-------
tuple[TensorView, TensorView]
A tuple containing (quantized_tensor, block_scales) where:
- quantized_tensor: Quantized tensor in NVFP4 format
- block_scales: Per-block scaling factors
)",
py::return_value_policy::reference);
}

template <
class ShapeType,
TensorView* (*RandomFuncWithSeed)(
@@ -3638,6 +3710,7 @@ void bindOperations(py::module& nvfuser) {
bindSearchOps(nvf_ops);
bindSdpaOps(nvf_ops);
bindRandomOps(nvf_ops);
bindQuantizationOps(nvf_ops);
}

} // namespace nvfuser::python
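
For reference, the allocation order set in swizzleBlockScale above — (m/128, k/4, 32(m_i), 4(m_o), 4(k)) — corresponds to the logical-to-linear index mapping sketched below. This is a minimal Python illustration only; the function name and the padding assumption (rows a multiple of 128, columns a multiple of 4) are ours, not part of this PR.

```python
def swizzled_scale_offset(row: int, col: int, num_cols: int) -> int:
    """Linear offset of block-scale element (row, col) under the
    allocation order (m/128, k/4, 32(m_i), 4(m_o), 4(k))."""
    m_tile, r = divmod(row, 128)  # split(0, 128)
    m_o, m_i = divmod(r, 32)      # split(1, 32) -> 4(m_o), 32(m_i)
    k_tile, k_i = divmod(col, 4)  # split(3, 4)  -> k/4, 4(k)
    # Row-major walk over extents [m/128, k/4, 32, 4, 4]
    return (((m_tile * (num_cols // 4) + k_tile) * 32 + m_i) * 4 + m_o) * 4 + k_i
```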
27 changes: 27 additions & 0 deletions python/python_direct/python_translate.cpp
@@ -1589,6 +1589,33 @@ class PythonTranslator : public OptInConstDispatch {
{out_tv});
}

// Map BlockQuantizationOp to python frontend
void handle(const BlockQuantizationOp* bqop) final {
NVF_ERROR(bqop != nullptr);
visited_vals_.insert(bqop->output(0));
visited_vals_.insert(bqop->output(1));

static const auto default_args = std::make_tuple(
KeywordArgument<decltype(bqop->globalScale())>{"global_scale", nullptr},
KeywordArgument<int64_t>{"block_size", 16},
KeywordArgument<bool>{"swizzle_block_scales", false},
KeywordArgument<DataType>{"dtype", DataType::Float4_e2m1fn});

auto dtype = bqop->quantizedOutput()->as<TensorView>()->dtype();
auto swizzled_block_scale =
bqop->blockScales()->as<TensorView>()->hasAllocation();
printer_.generateKwargsOperation(
"fd.ops.nv_block_quantize",
std::make_tuple(bqop->in()),
default_args,
std::make_tuple(
bqop->globalScale(),
bqop->blockSize(),
swizzled_block_scale,
dtype),
std::vector<const nvfuser::Val*>{bqop->output(0), bqop->output(1)});
}

// Map EmbeddingFwdOp to python frontend
void handle(const EmbeddingFwdOp* eop) final {
NVF_ERROR(eop != nullptr);
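
A hedged sketch of how the call emitted by this translator might be written in the Python frontend. Only the fd.ops.nv_block_quantize call, its argument names, and the (quantized_tensor, block_scales) return order come from this diff; the module name and FusionDefinition boilerplate are assumptions for illustration.

```python
import torch
from nvfuser_direct import FusionDefinition, DataType  # module/API names assumed

with FusionDefinition() as fd:
    inp = fd.define_tensor(shape=[-1, -1], dtype=DataType.BFloat16)
    # Returns (quantized_tensor, block_scales), per the binding above.
    quantized, block_scales = fd.ops.nv_block_quantize(
        inp, swizzle_block_scales=True, block_size=16
    )
    fd.add_output(quantized)
    fd.add_output(block_scales)

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
q, scales = fd.execute([x])
```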
25 changes: 13 additions & 12 deletions runtime/block_quantization_kernels.cu
@@ -86,7 +86,7 @@ __device__ void block_quantize_to_nvfp4(
if constexpr (std::is_same<T, float>::value) {
vec_in[i] = input[i];
} else if constexpr (std::is_same<T, __bfloat>::value) {
- vec_in[i] = __bfloat2float(input[i]);
+ vec_in[i] = __bfloat2float(__float2bfloat(__bfloat2float(input[i])));
} else if constexpr (std::is_same<T, __half>::value) {
vec_in[i] = __half2float(input[i]);
}
@@ -107,22 +107,23 @@

// This division should be replaced with a multiplication
// by a reciprocal for better performance.
- float scaled_max = block_max / 6.000000000e+00f;
+ // float scaled_max = block_max / 6.000000000e+00f;

+ constexpr float rcp_6f = 1.0f / 6.0f;

+ float scaled_max = 0.0f;
if constexpr (USE_GLOBAL_SCALE) {
-   scaled_max = scaled_max * global_scale[0];
+   scaled_max = block_max * global_scale[0] * rcp_6f;
+ } else {
+   scaled_max = block_max / 6.000000000e+00f;
}

- float clamped_max = clamp(
-     scaled_max, 1.562500000e-02f, 4.480000000e+02f); // Clamp between 0 and 1

- __e4m3 clamped_max_fp8 = __float2e4m3(clamped_max);
+ __e4m3 clamped_max_fp8 = __float2e4m3(scaled_max);

// Convert back from FP8 to float using __e4m32float
- float clamped_max_converted = __e4m32float(clamped_max_fp8);
+ float clamped_max = __e4m32float(clamped_max_fp8);

if constexpr (USE_GLOBAL_SCALE) {
-   clamped_max_converted = clamped_max_converted / global_scale[0];
+   clamped_max = global_scale[0] / clamped_max;
}

// Write out the block scaling factor to global memory.
@@ -165,8 +166,8 @@ __device__ void block_quantize_to_nvfp4(
Array<float, ITEMS_PER_THREAD, ITEMS_PER_THREAD> clamped_vals;
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
-   float scaled_val = vec_in[i] / clamped_max_converted;
-   clamped_vals[i] = clamp(scaled_val, -6.000000000e+00f, 6.000000000e+00f);
+   // float scaled_val = vec_in[i] / clamped_max;
+   clamped_vals[i] = vec_in[i] * clamped_max;
}

Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals;
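
For reference, a NumPy sketch of the arithmetic on the USE_GLOBAL_SCALE path after this change. The ml_dtypes e4m3/e2m1 casts are an assumption used to emulate the hardware conversions; this is illustrative, not the kernel.

```python
import numpy as np
import ml_dtypes  # assumed to provide float8_e4m3fn and float4_e2m1fn

def block_quantize_ref(x, global_scale, block_size=16):
    """Per-block e4m3 scales and e2m1 values for a 1-D float array."""
    blocks = np.asarray(x, dtype=np.float32).reshape(-1, block_size)
    block_max = np.abs(blocks).max(axis=1)
    # scaled_max = block_max * global_scale * rcp_6f (6.0 is the e2m1 max magnitude)
    scaled_max = block_max * global_scale * (1.0 / 6.0)
    scales_e4m3 = scaled_max.astype(ml_dtypes.float8_e4m3fn)
    decoded = scales_e4m3.astype(np.float32)  # the __e4m32float round-trip
    # clamped_max = global_scale / decoded, then vec_in * clamped_max
    factor = global_scale / decoded  # note: an all-zero block would divide by zero here
    quantized = (blocks * factor[:, None]).astype(ml_dtypes.float4_e2m1fn)
    return quantized, scales_e4m3
```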