@@ -840,7 +840,10 @@ def _quantize_dbias_impl(
840840 # It is faster to use 1x quantization for tensor scaling and NVFP4_2D_SCALING
841841 is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability () < 100 )
842842 force_1x_quantization = (
843- (quantizer .scaling_mode .is_tensor_scaling () or quantizer .scaling_mode == ScalingMode .NVFP4_2D_SCALING )
843+ (
844+ quantizer .scaling_mode .is_tensor_scaling ()
845+ or quantizer .scaling_mode == ScalingMode .NVFP4_2D_SCALING
846+ )
844847 and quantizer .q_layout .is_rowwise_colwise
845848 and is_1x_kernel_supported
846849 )
@@ -896,7 +899,10 @@ def _quantize_dbias_impl(
896899 rowwise_casted_output , (* range (flatten_axis , x .ndim ), * range (flatten_axis ))
897900 )
898901
899- if quantizer .scaling_mode == ScalingMode .NVFP4_2D_SCALING and quantizer .q_layout .is_rowwise_colwise :
902+ if (
903+ quantizer .scaling_mode == ScalingMode .NVFP4_2D_SCALING
904+ and quantizer .q_layout .is_rowwise_colwise
905+ ):
900906 assert q_layout .is_rowwise_only
901907 # Quantizer requires 2x quantization, but we are using 1x quantization
902908 # for performance reasons, so we need to generate the colwise data in JAX
@@ -909,24 +915,27 @@ def _quantize_dbias_impl(
909915 flatten_axis = (flatten_axis + x .ndim ) % x .ndim
910916 ## Split the dim before the flatten_axis to (its size / block_size, block_size)
911917 colwise_scale_inv = rowwise_scale_inv .reshape (
912- * scale_shape [:flatten_axis - 1 ],
918+ * scale_shape [: flatten_axis - 1 ],
913919 int (scale_shape [flatten_axis - 1 ] / 16 ),
914- 16 , # <-- block_dim
920+ 16 , # <-- block_dim
915921 * scale_shape [flatten_axis :],
916922 )
917923 # now flatten_axis = flatten_axis + 1
918- colwise_scale_inv = jnp .transpose (colwise_scale_inv ,
919- (* range (flatten_axis + 1 , colwise_scale_inv .ndim ),
920- flatten_axis , # <-- block_dim after transpose
921- * range (0 , flatten_axis )),
922- )
924+ colwise_scale_inv = jnp .transpose (
925+ colwise_scale_inv ,
926+ (
927+ * range (flatten_axis + 1 , colwise_scale_inv .ndim ),
928+ flatten_axis , # <-- block_dim after transpose
929+ * range (0 , flatten_axis ),
930+ ),
931+ )
923932 block_dim = colwise_scale_inv .ndim - flatten_axis - 1
924933 assert block_dim >= 1
925934 # Merge the block_dim back
926935 colwise_scale_inv = colwise_scale_inv .reshape (
927- * colwise_scale_inv .shape [:block_dim - 1 ],
936+ * colwise_scale_inv .shape [: block_dim - 1 ],
928937 - 1 ,
929- * colwise_scale_inv .shape [block_dim + 1 :],
938+ * colwise_scale_inv .shape [block_dim + 1 :],
930939 )
931940 quantizer .update (updated_amax )
932941 if quantizer .scaling_mode .is_nvfp4_scaling and is_dbias :
0 commit comments