Skip to content

Commit b764dea

Browse files
Fixes per the PR review
Signed-off-by: Oleg Goncharov <[email protected]>
1 parent 9afdba1 commit b764dea

File tree

11 files changed

+350
-380
lines changed

11 files changed

+350
-380
lines changed

transformer_engine/common/cast/dispatch/dequantize.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
2727

2828
switch (input.scaling_mode) {
2929
case NVTE_DELAYED_TENSOR_SCALING: {
30+
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
31+
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
32+
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
3033
fp8::dequantize(input, output, stream);
3134
break;
3235
}

transformer_engine/common/cast/dispatch/gated.cuh

Lines changed: 69 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -28,63 +28,54 @@ void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output,
2828
const Tensor input = *convertNVTETensorCheck(nvte_input);
2929
Tensor *output = convertNVTETensorCheck(nvte_output);
3030

31-
const auto scaling_mode = output->scaling_mode;
32-
if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100()) {
33-
NVTE_ERROR("Not supported by the Arch < 10.0");
34-
}
35-
36-
constexpr bool allow_empty = false;
3731
CheckInputTensor(input, "input");
38-
CheckOutputTensor(*output, "output", allow_empty);
39-
40-
NVTE_CHECK(input.flat_last_dim() % 2 == 0, "Number of columns must be even.");
32+
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
4133

4234
const size_t rows = input.flat_first_dim();
4335
const size_t cols = input.flat_last_dim() / 2;
4436

37+
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
38+
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
39+
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
40+
NVTE_CHECK(output->flat_last_dim() == cols,
41+
"Wrong output shape. Expected (after flattening) [*, ", cols,
42+
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
43+
4544
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
4645
"Either rowwise or columnwise output data need to be allocated.");
4746

48-
bool is_fp8_rowwise_output = true;
49-
bool is_fp8_colwise_output = true;
50-
if (output->has_data()) {
51-
is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype);
52-
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
53-
NVTE_CHECK(output->flat_last_dim() == cols, "Wrong dimension of the output.");
54-
}
55-
if (output->has_columnwise_data()) {
56-
is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype);
57-
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
58-
NVTE_CHECK(output->flat_last_dim() == cols, "Wrong dimension of the output.");
59-
}
60-
61-
const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) &&
62-
is_supported_by_CC_100();
63-
64-
switch (scaling_mode) {
47+
switch (output->scaling_mode) {
6548
case NVTE_DELAYED_TENSOR_SCALING: {
49+
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
6650
if (use_tma_kernels) {
6751
Tensor dummy_tensor; // grad
68-
fp8::cast_gated_dgated_tma<false, ParamOP, ActOP, nullptr>(dummy_tensor, input, output, p,
69-
stream);
52+
fp8::cast_gated_tma</*IS_BWD=*/false, ParamOP, ActOP, nullptr>(dummy_tensor, input, output,
53+
p, stream);
7054
} else {
71-
fp8::cast_gated<ParamOP, ActOP>(input, output, p, stream);
55+
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
7256
}
7357
break;
7458
}
7559
case NVTE_MXFP8_1D_SCALING: {
76-
if (use_tma_kernels) {
77-
Tensor dummy_tensor; // grad
78-
mxfp8::quantize_gated_dgated<false, ParamOP, ActOP, nullptr>(dummy_tensor, input, output, p,
79-
stream);
80-
} else {
81-
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
82-
"by 32, got input of shape ", input.data.shape);
60+
NVTE_CHECK(cols % 32 == 0, "Invalid input shape. Expected the last dimension to be "
61+
"divisible by 32, but got ", cols, ".");
62+
if (output->has_data()) {
63+
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
64+
"The type of the output tensor should be FP8.");
65+
}
66+
if (output->has_columnwise_data()) {
67+
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
68+
"The type of the columnwise output tensor should be FP8.");
8369
}
70+
NVTE_CHECK(is_supported_by_CC_100(),
71+
"Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
72+
Tensor dummy_tensor; // grad
73+
mxfp8::quantize_gated</*IS_BWD=*/false, ParamOP, ActOP, nullptr>
74+
(dummy_tensor, input, output, p, stream);
8475
break;
8576
}
8677
default:
87-
NVTE_ERROR("Not supported scaling mode: " + to_string(scaling_mode) + ".");
78+
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
8879
}
8980
}
9081

@@ -97,68 +88,70 @@ void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_ga
9788
const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
9889
Tensor *output = convertNVTETensorCheck(nvte_output);
9990

100-
const auto scaling_mode = output->scaling_mode;
101-
if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100()) {
102-
NVTE_ERROR("Not supported by the Arch < 10.0");
103-
}
104-
105-
constexpr bool allow_empty = false;
91+
CheckInputTensor(grad, "grad");
10692
CheckInputTensor(gated_input, "gated_input");
107-
CheckOutputTensor(*output, "output", allow_empty);
93+
CheckOutputTensor(*output, "output", /*allow_empty=*/false);
10894

109-
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even.");
95+
NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
96+
gated_input.flat_last_dim(), ".");
11097

11198
const size_t rows = gated_input.flat_first_dim();
11299
const size_t cols = gated_input.flat_last_dim() / 2;
113-
const size_t output_cols = 2 * cols;
114100

115-
CheckInputTensor(grad, "grad");
116101
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
117102
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
118-
NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input.");
119-
NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input.");
103+
104+
NVTE_CHECK(grad.flat_first_dim() == rows,
105+
"Wrong Grad shape. Expected first dimension (after flattening) [", rows,
106+
", *], got [", grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
107+
NVTE_CHECK(grad.flat_last_dim() == cols,
108+
"Wrong Grad shape. Expected last dimension (after flattening) [", cols,
109+
", *], got [", grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
120110

121111
NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
122112
"Either rowwise or columnwise output data need to be allocated.");
123113

124-
bool is_fp8_rowwise_output = true;
125-
bool is_fp8_colwise_output = true;
126-
if (output->has_data()) {
127-
is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype);
128-
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
129-
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
130-
}
131-
if (output->has_columnwise_data()) {
132-
is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype);
133-
NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output.");
134-
NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output.");
135-
}
136-
137-
const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) &&
138-
is_supported_by_CC_100();
139-
140-
switch (scaling_mode) {
114+
NVTE_CHECK(output->flat_first_dim() == rows,
115+
"Wrong output shape. Expected (after flattening) [", rows,
116+
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
117+
NVTE_CHECK(output->flat_last_dim() == cols * 2,
118+
"Wrong output shape. Expected (after flattening) [*, ", cols * 2,
119+
"], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
120+
NVTE_CHECK(gated_input.data.shape == output->data.shape,
121+
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
122+
", output shape: ", output->data.shape, ".");
123+
124+
switch (output->scaling_mode) {
141125
case NVTE_DELAYED_TENSOR_SCALING: {
126+
const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
142127
if (use_tma_kernels) {
143-
fp8::cast_gated_dgated_tma<true, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
144-
stream);
128+
fp8::cast_gated_tma</*IS_BWD=*/true, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
129+
stream);
145130
} else {
146-
fp8::cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
131+
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
147132
}
148133
break;
149134
}
150135
case NVTE_MXFP8_1D_SCALING: {
151-
if (use_tma_kernels) {
152-
mxfp8::quantize_gated_dgated<true, ParamOP, ActOP, DActOP>(grad, gated_input, output, p,
153-
stream);
154-
} else {
155-
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
156-
"by 32, got input of shape ", gated_input.data.shape);
136+
NVTE_CHECK(cols % 32 == 0, "Invalid input shape. Expected the last dimension to be "
137+
"divisible by 32, but got ", cols, ".");
138+
if (output->has_data()) {
139+
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
140+
"The type of the output tensor should be FP8.");
141+
}
142+
if (output->has_columnwise_data()) {
143+
NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
144+
"The type of the columnwise output tensor should be FP8.");
157145
}
146+
NVTE_CHECK(is_supported_by_CC_100(),
147+
"Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
148+
149+
mxfp8::quantize_gated</*IS_BWD=*/true, ParamOP, ActOP, DActOP>
150+
(grad, gated_input, output, p, stream);
158151
break;
159152
}
160153
default:
161-
NVTE_ERROR("Not supported scaling mode: " + to_string(scaling_mode) + ".");
154+
NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
162155
}
163156
}
164157
} // namespace dispatch

transformer_engine/common/cast/dispatch/quantize.cuh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
6666
"Stochastic rounding is only supported for NVFP4 quantization.");
6767
}
6868

69+
NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
70+
"Either rowwise or columnwise output data need to be allocated.");
71+
6972
// Dispatch to quantization kernel depending on data format
7073
switch (output_tensor->scaling_mode) {
7174
case NVTE_DELAYED_TENSOR_SCALING: {
@@ -105,8 +108,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
105108
int32_t rows = input_tensor->flat_first_dim();
106109
int32_t cols = input_tensor->flat_last_dim();
107110
auto dtype = input_tensor->dtype();
108-
bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
109-
output_tensor->has_data();
111+
bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0)
112+
&& (cols % 32 == 0) && output_tensor->has_data();
110113

111114
// Launch NVFP4 quantize kernel
112115
if (use_optimized_kernel) {

transformer_engine/common/cast/fp8/dequantize_fp8.cuh

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ __device__ inline float dequantize_func(float value, const DequantizeParam &para
3333
}
3434

3535
inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
36-
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
37-
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
38-
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
39-
4036
const size_t N = product(input.data.shape);
4137
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
4238
input.data.dtype, IType,

0 commit comments

Comments
 (0)