@@ -28,63 +28,54 @@ void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output,
2828 const Tensor input = *convertNVTETensorCheck (nvte_input);
2929 Tensor *output = convertNVTETensorCheck (nvte_output);
3030
31- const auto scaling_mode = output->scaling_mode ;
32- if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100 ()) {
33- NVTE_ERROR (" Not supported by the Arch < 10.0" );
34- }
35-
36- constexpr bool allow_empty = false ;
3731 CheckInputTensor (input, " input" );
38- CheckOutputTensor (*output, " output" , allow_empty);
39-
40- NVTE_CHECK (input.flat_last_dim () % 2 == 0 , " Number of columns must be even." );
32+ CheckOutputTensor (*output, " output" , /* allow_empty=*/ false );
4133
4234 const size_t rows = input.flat_first_dim ();
4335 const size_t cols = input.flat_last_dim () / 2 ;
4436
37+ NVTE_CHECK (input.flat_last_dim () % 2 == 0 ,
38+ " Wrong input shape. Expected (after flattening) last dimension to be even, " , " got [" ,
39+ input.flat_first_dim (), " , " , input.flat_last_dim (), " ]." );
40+ NVTE_CHECK (output->flat_last_dim () == cols,
41+ " Wrong output shape. Expected (after flattening) [*, " , cols,
42+ " ], got [" , output->flat_first_dim (), " , " , output->flat_last_dim (), " ]." );
43+
4544 NVTE_CHECK (output->has_data () || output->has_columnwise_data (),
4645 " Either rowwise or columnwise output data need to be allocated." );
4746
48- bool is_fp8_rowwise_output = true ;
49- bool is_fp8_colwise_output = true ;
50- if (output->has_data ()) {
51- is_fp8_rowwise_output = is_fp8_dtype (output->data .dtype );
52- NVTE_CHECK (output->flat_first_dim () == rows, " Wrong dimension of the output." );
53- NVTE_CHECK (output->flat_last_dim () == cols, " Wrong dimension of the output." );
54- }
55- if (output->has_columnwise_data ()) {
56- is_fp8_colwise_output = is_fp8_dtype (output->columnwise_data .dtype );
57- NVTE_CHECK (output->flat_first_dim () == rows, " Wrong dimension of the output." );
58- NVTE_CHECK (output->flat_last_dim () == cols, " Wrong dimension of the output." );
59- }
60-
61- const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0 ) &&
62- is_supported_by_CC_100 ();
63-
64- switch (scaling_mode) {
47+ switch (output->scaling_mode ) {
6548 case NVTE_DELAYED_TENSOR_SCALING: {
49+ const bool use_tma_kernels = (cols % 32 == 0 ) && is_supported_by_CC_100 ();
6650 if (use_tma_kernels) {
6751 Tensor dummy_tensor; // grad
68- fp8::cast_gated_dgated_tma< false , ParamOP, ActOP, nullptr >(dummy_tensor, input, output, p ,
69- stream);
52+ fp8::cast_gated_tma< /* IS_BWD= */ false , ParamOP, ActOP, nullptr >(dummy_tensor, input, output,
53+ p, stream);
7054 } else {
71- fp8::cast_gated <ParamOP, ActOP>(input, output, p, stream);
55+ fp8::cast_gated_fwd <ParamOP, ActOP>(input, output, p, stream);
7256 }
7357 break ;
7458 }
7559 case NVTE_MXFP8_1D_SCALING: {
76- if (use_tma_kernels) {
77- Tensor dummy_tensor; // grad
78- mxfp8::quantize_gated_dgated<false , ParamOP, ActOP, nullptr >(dummy_tensor, input, output, p,
79- stream);
80- } else {
81- NVTE_ERROR (" Invalid input shape. Expected the last dimension to be divisible " ,
82- " by 32, got input of shape " , input.data .shape );
60+ NVTE_CHECK (cols % 32 == 0 , " Invalid input shape. Expected the last dimension to be "
61+ " divisible by 32, but got " , cols, " ." );
62+ if (output->has_data ()) {
63+ NVTE_CHECK (is_fp8_dtype (output->data .dtype ),
64+ " The type of the output tensor should be FP8." );
65+ }
66+ if (output->has_columnwise_data ()) {
67+ NVTE_CHECK (is_fp8_dtype (output->columnwise_data .dtype ),
68+ " The type of the columnwise output tensor should be FP8." );
8369 }
70+ NVTE_CHECK (is_supported_by_CC_100 (),
71+ " Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+" );
72+ Tensor dummy_tensor; // grad
73+ mxfp8::quantize_gated</* IS_BWD=*/ false , ParamOP, ActOP, nullptr >
74+ (dummy_tensor, input, output, p, stream);
8475 break ;
8576 }
8677 default :
87- NVTE_ERROR (" Not supported scaling mode: " + to_string (scaling_mode) + " ." );
78+ NVTE_ERROR (" Not supported scaling mode: " + to_string (output-> scaling_mode ) + " ." );
8879 }
8980}
9081
@@ -97,68 +88,70 @@ void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_ga
9788 const Tensor gated_input = *convertNVTETensorCheck (nvte_gated_input);
9889 Tensor *output = convertNVTETensorCheck (nvte_output);
9990
100- const auto scaling_mode = output->scaling_mode ;
101- if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100 ()) {
102- NVTE_ERROR (" Not supported by the Arch < 10.0" );
103- }
104-
105- constexpr bool allow_empty = false ;
91+ CheckInputTensor (grad, " grad" );
10692 CheckInputTensor (gated_input, " gated_input" );
107- CheckOutputTensor (*output, " output" , allow_empty);
93+ CheckOutputTensor (*output, " output" , /* allow_empty= */ false );
10894
109- NVTE_CHECK (gated_input.flat_last_dim () % 2 == 0 , " Number of columns must be even." );
95+ NVTE_CHECK (gated_input.flat_last_dim () % 2 == 0 , " Number of columns must be even, but got " ,
96+ gated_input.flat_last_dim (), " ." );
11097
11198 const size_t rows = gated_input.flat_first_dim ();
11299 const size_t cols = gated_input.flat_last_dim () / 2 ;
113- const size_t output_cols = 2 * cols;
114100
115- CheckInputTensor (grad, " grad" );
116101 NVTE_CHECK (!is_fp8_dtype (grad.data .dtype ), " Grad input must be in higher precision." );
117102 NVTE_CHECK (grad.data .dtype == gated_input.data .dtype , " Types of both inputs must match." );
118- NVTE_CHECK (grad.flat_first_dim () == rows, " Wrong dimension of the grad input." );
119- NVTE_CHECK (grad.flat_last_dim () == cols, " Wrong dimension of the grad input." );
103+
104+ NVTE_CHECK (grad.flat_first_dim () == rows,
105+ " Wrong Grad shape. Expected first dimension (after flattening) [" , rows,
106+ " , *], got [" , grad.flat_first_dim (), " , " , grad.flat_last_dim (), " ]." );
107+ NVTE_CHECK (grad.flat_last_dim () == cols,
108+ " Wrong Grad shape. Expected last dimension (after flattening) [" , cols,
109+ " , *], got [" , grad.flat_first_dim (), " , " , grad.flat_last_dim (), " ]." );
120110
121111 NVTE_CHECK (output->has_data () || output->has_columnwise_data (),
122112 " Either rowwise or columnwise output data need to be allocated." );
123113
124- bool is_fp8_rowwise_output = true ;
125- bool is_fp8_colwise_output = true ;
126- if (output->has_data ()) {
127- is_fp8_rowwise_output = is_fp8_dtype (output->data .dtype );
128- NVTE_CHECK (output->flat_first_dim () == rows, " Wrong dimension of the output." );
129- NVTE_CHECK (output->flat_last_dim () == output_cols, " Wrong dimension of the output." );
130- }
131- if (output->has_columnwise_data ()) {
132- is_fp8_colwise_output = is_fp8_dtype (output->columnwise_data .dtype );
133- NVTE_CHECK (output->flat_first_dim () == rows, " Wrong dimension of the output." );
134- NVTE_CHECK (output->flat_last_dim () == output_cols, " Wrong dimension of the output." );
135- }
136-
137- const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0 ) &&
138- is_supported_by_CC_100 ();
139-
140- switch (scaling_mode) {
114+ NVTE_CHECK (output->flat_first_dim () == rows,
115+ " Wrong output shape. Expected (after flattening) [" , rows,
116+ " , *], got [" , output->flat_first_dim (), " , " , output->flat_last_dim (), " ]." );
117+ NVTE_CHECK (output->flat_last_dim () == cols * 2 ,
118+ " Wrong output shape. Expected (after flattening) [*, " , cols * 2 ,
119+ " ], got [" , output->flat_first_dim (), " , " , output->flat_last_dim (), " ]." );
120+ NVTE_CHECK (gated_input.data .shape == output->data .shape ,
121+ " Gated input and output shapes must match. Input shape: " , gated_input.data .shape ,
122+ " , output shape: " , output->data .shape , " ." );
123+
124+ switch (output->scaling_mode ) {
141125 case NVTE_DELAYED_TENSOR_SCALING: {
126+ const bool use_tma_kernels = (cols % 32 == 0 ) && is_supported_by_CC_100 ();
142127 if (use_tma_kernels) {
143- fp8::cast_gated_dgated_tma< true , ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
144- stream);
128+ fp8::cast_gated_tma< /* IS_BWD= */ true , ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
129+ stream);
145130 } else {
146- fp8::cast_dgated <ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
131+ fp8::cast_gated_bwd <ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
147132 }
148133 break ;
149134 }
150135 case NVTE_MXFP8_1D_SCALING: {
151- if (use_tma_kernels) {
152- mxfp8::quantize_gated_dgated<true , ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
153- stream);
154- } else {
155- NVTE_ERROR (" Invalid input shape. Expected the last dimension to be divisible " ,
156- " by 32, got input of shape " , gated_input.data .shape );
136+ NVTE_CHECK (cols % 32 == 0 , " Invalid input shape. Expected the last dimension to be "
137+ " divisible by 32, but got " , cols, " ." );
138+ if (output->has_data ()) {
139+ NVTE_CHECK (is_fp8_dtype (output->data .dtype ),
140+ " The type of the output tensor should be FP8." );
141+ }
142+ if (output->has_columnwise_data ()) {
143+ NVTE_CHECK (is_fp8_dtype (output->columnwise_data .dtype ),
144+ " The type of the columnwise output tensor should be FP8." );
157145 }
146+ NVTE_CHECK (is_supported_by_CC_100 (),
147+ " Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+" );
148+
149+ mxfp8::quantize_gated</* IS_BWD=*/ true , ParamOP, ActOP, DActOP>
150+ (grad, gated_input, output, p, stream);
158151 break ;
159152 }
160153 default :
161- NVTE_ERROR (" Not supported scaling mode: " + to_string (scaling_mode) + " ." );
154+ NVTE_ERROR (" Not supported scaling mode: " + to_string (output-> scaling_mode ) + " ." );
162155 }
163156}
164157} // namespace dispatch
0 commit comments