
Commit b5dd94a

pre-commit-ci[bot] and phu0ngng authored and committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 38d5be3 commit b5dd94a

File tree

1 file changed: +20 -11 lines


transformer_engine/jax/cpp_extensions/quantization.py

Lines changed: 20 additions & 11 deletions
@@ -840,7 +840,10 @@ def _quantize_dbias_impl(
     # It is faster to use 1x quantization for tensor scaling and 2D NVFP4_1D_SCALING
     is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
     force_1x_quantization = (
-        (quantizer.scaling_mode.is_tensor_scaling() or quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING)
+        (
+            quantizer.scaling_mode.is_tensor_scaling()
+            or quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
+        )
         and quantizer.q_layout.is_rowwise_colwise
         and is_1x_kernel_supported
     )
@@ -896,7 +899,10 @@ def _quantize_dbias_impl(
             rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
         )
 
-        if quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING and quantizer.q_layout.is_rowwise_colwise:
+        if (
+            quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
+            and quantizer.q_layout.is_rowwise_colwise
+        ):
             assert q_layout.is_rowwise_only
             # Quantizer requires 2x quantization, but we are using 1x quantization
             # for performance reasons, so we need to generate the colwise data in JAX
@@ -909,24 +915,27 @@ def _quantize_dbias_impl(
             flatten_axis = (flatten_axis + x.ndim) % x.ndim
             ## Split the dim before the flatten_axis to (its size / block_size, block_size)
             colwise_scale_inv = rowwise_scale_inv.reshape(
-                *scale_shape[:flatten_axis - 1],
+                *scale_shape[: flatten_axis - 1],
                 int(scale_shape[flatten_axis - 1] / 16),
-                16, # <-- block_dim
+                16,  # <-- block_dim
                 *scale_shape[flatten_axis:],
             )
             # now flatten_axis = flatten_axis + 1
-            colwise_scale_inv = jnp.transpose(colwise_scale_inv,
-                                              (*range(flatten_axis + 1, colwise_scale_inv.ndim),
-                                               flatten_axis, # <-- block_dim after transpose
-                                               *range(0, flatten_axis)),
-            )
+            colwise_scale_inv = jnp.transpose(
+                colwise_scale_inv,
+                (
+                    *range(flatten_axis + 1, colwise_scale_inv.ndim),
+                    flatten_axis,  # <-- block_dim after transpose
+                    *range(0, flatten_axis),
+                ),
+            )
             block_dim = colwise_scale_inv.ndim - flatten_axis - 1
             assert block_dim >= 1
             # Merge the block_dim back
             colwise_scale_inv = colwise_scale_inv.reshape(
-                *colwise_scale_inv.shape[:block_dim - 1],
+                *colwise_scale_inv.shape[: block_dim - 1],
                 -1,
-                *colwise_scale_inv.shape[block_dim + 1:],
+                *colwise_scale_inv.shape[block_dim + 1 :],
             )
     quantizer.update(updated_amax)
     if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
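
Note: the reshape/transpose sequence reformatted in the last hunk can be read on its own. The sketch below is a minimal, self-contained illustration of the same axis bookkeeping; the block size of 16 matches the diff, but the (64, 8) scale shape, the flatten_axis value, and the variable values chosen here are assumptions for the example, not values taken from the library. It splits the dim before flatten_axis into 16-wide blocks, rotates the trailing dims and the block dim to the front, then merges the block dim back in.

import jax.numpy as jnp

block_size = 16                  # matches the hard-coded 16 in the diff
flatten_axis = 1                 # assumed: separates leading ("row") dims from trailing dims
scale_shape = (64, 8)            # assumed rowwise scale-inv shape
rowwise_scale_inv = jnp.ones(scale_shape, dtype=jnp.float32)

# Split the dim before flatten_axis into (its size / block_size, block_size).
colwise_scale_inv = rowwise_scale_inv.reshape(
    *scale_shape[: flatten_axis - 1],
    scale_shape[flatten_axis - 1] // block_size,
    block_size,  # <-- block_dim
    *scale_shape[flatten_axis:],
)  # shape: (4, 16, 8)

# The flatten point has conceptually moved one dim to the right (the diff's
# "# now flatten_axis = flatten_axis + 1" comment), so the code keeps the original
# flatten_axis and writes the shift explicitly as "+ 1" below.
# Rotate: trailing dims first, then the block dim, then the leading dims.
colwise_scale_inv = jnp.transpose(
    colwise_scale_inv,
    (
        *range(flatten_axis + 1, colwise_scale_inv.ndim),
        flatten_axis,  # <-- block_dim after transpose
        *range(0, flatten_axis),
    ),
)  # shape: (8, 16, 4)

# Merge the block dim back into the dim just before it.
block_dim = colwise_scale_inv.ndim - flatten_axis - 1
assert block_dim >= 1
colwise_scale_inv = colwise_scale_inv.reshape(
    *colwise_scale_inv.shape[: block_dim - 1],
    -1,
    *colwise_scale_inv.shape[block_dim + 1 :],
)

print(colwise_scale_inv.shape)  # (128, 4)

Under these assumed shapes, the 64-entry leading dim is broken into 4 blocks of 16 and folded behind the 8 trailing scales, giving a (128, 4) array. The exact NVFP4 colwise scale semantics are beyond this sketch, which only shows what the reshape/transpose/reshape sequence does to the axes.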
