forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
QTensor.cpp
392 lines (342 loc) · 12.4 KB
/
QTensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/util/irange.h>
#include <cmath>
#include <utility>
namespace at {
namespace native {
// Dynamically quantizes `self` with per-tensor affine qparams derived from the
// input's observed min/max values.
//
// dtype:        QInt8, QUInt8 or Half. For Half the (contiguous) input is just
//               cast to Half and returned; no quantization is performed.
// reduce_range: forwarded to ChooseQuantizationParams, except that it is forced
//               to false when the global quantized engine is QNNPACK.
// Throws (TORCH_CHECK) for any other dtype.
Tensor quantize_per_tensor_dynamic(
    const Tensor& self,
    ScalarType dtype,
    bool reduce_range) {
  TORCH_CHECK(
      (dtype == ScalarType::QInt8 || dtype == ScalarType::QUInt8 ||
       dtype == ScalarType::Half),
      "dtype ", dtype, " not supported");
  auto input_contig = self.contiguous();
  if (dtype == ScalarType::Half) {
    return input_contig.to(ScalarType::Half);
  }
  float x_min = input_contig.min().item<float>();
  float x_max = input_contig.max().item<float>();
  // reduce_range is disabled for the QNNPACK engine.
  if (reduce_range && at::globalContext().qEngine() == at::QEngine::QNNPACK) {
    reduce_range = false;
  }
  int qmin;
  int qmax;
  if (dtype == ScalarType::QInt8) {
    qmin = -128;
    qmax = 127;
  } else {
    // for now, this branch executes for dtype == ScalarType::QUInt8
    // additional cases will be added when quantization support for other dtypes becomes available
    qmin = 0;
    qmax = 255;
  }
  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/qmin,
      /*qmax=*/qmax,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);
  return at::native::quantize_per_tensor(self, q_params.scale, q_params.zero_point, dtype);
}
// Quantizes `self` with one (scale, zero_point) pair shared by all elements,
// producing a per-tensor affine quantized tensor of type `dtype`.
Tensor quantize_per_tensor(
    const Tensor& self,
    double scale,
    int64_t zero_point,
    ScalarType dtype) {
  return make_per_tensor_affine_quantizer(scale, zero_point, dtype)
      ->quantize(self);
}
// Same as quantize_per_tensor, but the scale and zero point are supplied as
// tensors; their scalar values are extracted with item() before quantizing.
Tensor quantize_per_tensor_tensor_qparams(
    const Tensor& self,
    const Tensor& scale,
    const Tensor& zero_point,
    ScalarType dtype) {
  const double scale_value = scale.item().toDouble();
  const int64_t zero_point_value = zero_point.item().toLong();
  auto quantizer =
      make_per_tensor_affine_quantizer(scale_value, zero_point_value, dtype);
  return quantizer->quantize(self);
}
// Quantizes each tensor in `tensors` independently. Tensor i uses the scalar
// scale scales[i] and zero point zero_points[i], and all outputs share `dtype`.
// Returns the quantized tensors in the same order as the input list.
std::vector<Tensor> quantize_per_tensor_list_cpu(
    TensorList tensors,
    const Tensor& scales,
    const Tensor& zero_points,
    ScalarType dtype) {
  std::vector<Tensor> quantized_tensors;
  // The output size is known up front; avoid repeated reallocation.
  quantized_tensors.reserve(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    quantized_tensors.push_back(at::quantize_per_tensor(
        tensors[i],
        scales[i].item<double>(),
        zero_points[i].item<int64_t>(),
        dtype));
  }
  return quantized_tensors;
}
// Quantizes `self` per channel along dimension `axis`, using one scale and one
// zero point per channel taken from `scales` / `zero_points`.
Tensor quantize_per_channel(
    const Tensor& self,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis,
    ScalarType dtype) {
  return make_per_channel_affine_quantizer(scales, zero_points, axis, dtype)
      ->quantize(self);
}
// Dequantize for a non-quantized tensor: just converts it to float32.
Tensor dequantize_cpu_or_cuda(const Tensor& self) {
  const auto target_dtype = at::kFloat;
  return self.to(target_dtype);
}
// Dequantizes a quantized tensor by delegating to the quantizer attached to it.
Tensor dequantize_quantized(const Tensor& self) {
  auto* qimpl = get_qtensorimpl(self);
  auto quantizer = qimpl->quantizer();
  return quantizer->dequantize(self);
}
// Dequantizes every tensor in `tensors` and returns the results in input order.
std::vector<Tensor> dequantize_tensors_quantized_cpu(TensorList tensors) {
  std::vector<Tensor> dequantized_tensors;
  // The output size is known up front; avoid repeated reallocation.
  dequantized_tensors.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    dequantized_tensors.push_back(tensor.dequantize());
  }
  return dequantized_tensors;
}
// Returns the scale of a per-tensor affine quantized tensor.
// Throws (TORCH_CHECK) if the tensor's qscheme is not kPerTensorAffine.
double q_scale_quant(const Tensor& self) {
  const auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
  // Safe downcast: the qscheme check above guarantees the concrete type.
  const auto* per_tensor =
      static_cast<PerTensorAffineQuantizer*>(quantizer.get());
  return per_tensor->scale();
}
// Returns the zero point of a per-tensor affine quantized tensor.
// Throws (TORCH_CHECK) if the tensor's qscheme is not kPerTensorAffine.
int64_t q_zero_point_quant(const Tensor& self) {
  const auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
  // Safe downcast: the qscheme check above guarantees the concrete type.
  const auto* per_tensor =
      static_cast<PerTensorAffineQuantizer*>(quantizer.get());
  return per_tensor->zero_point();
}
// Returns the per-channel scales tensor of a per-channel quantized tensor.
// Accepts kPerChannelAffine and kPerChannelAffineFloatQParams; throws otherwise.
Tensor q_per_channel_scales(const Tensor& self) {
  const auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  const auto* per_channel =
      static_cast<PerChannelAffineQuantizer*>(quantizer.get());
  return per_channel->scales();
}
// Returns the per-channel zero points tensor of a per-channel quantized tensor.
// Accepts kPerChannelAffine and kPerChannelAffineFloatQParams; throws otherwise.
Tensor q_per_channel_zero_points(const Tensor& self) {
  const auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  const auto* per_channel =
      static_cast<PerChannelAffineQuantizer*>(quantizer.get());
  return per_channel->zero_points();
}
// Returns the channel axis of a per-channel quantized tensor.
// Accepts kPerChannelAffine and kPerChannelAffineFloatQParams; throws otherwise.
int64_t q_per_channel_axis(const Tensor& self) {
  const auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine || quantizer->qscheme() == kPerChannelAffineFloatQParams);
  const auto* per_channel =
      static_cast<PerChannelAffineQuantizer*>(quantizer.get());
  return per_channel->axis();
}
// Wraps the raw integer data of `self` in a new per-channel affine quantized
// tensor with the given scales, zero points and axis. The quantized dtype comes
// from toQIntType(self.scalar_type()), and the bytes of `self` (made contiguous)
// are copied verbatim into the result.
Tensor make_per_channel_quantized_tensor_cpu(
    const Tensor& self,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  const auto qint_options =
      self.options().dtype(toQIntType(self.scalar_type()));
  Tensor dst = at::_empty_per_channel_affine_quantized(
      self.sizes(), scales, zero_points, axis, qint_options);
  Tensor self_contig = self.contiguous();
  AT_DISPATCH_QINT_TYPES(
      dst.scalar_type(), "per_channel_affine_qtensor", [&]() {
        const underlying_t* src = self_contig.data_ptr<underlying_t>();
        auto* out = reinterpret_cast<underlying_t*>(dst.data_ptr<scalar_t>());
        // Nothing to copy for an empty tensor.
        if (self.numel() == 0) {
          return;
        }
        memcpy(out, src, self.nbytes());
      });
  return dst;
}
// In-place: points `self` at `storage` with the given offset, sizes and strides.
// set_storage_keep_dtype is used so the tensor keeps its current dtype.
// Returns `self`.
Tensor& set_storage_quantized_(
    Tensor& self,
    Storage storage,
    int64_t storage_offset,
    IntArrayRef sizes,
    IntArrayRef strides) {
  auto* const impl = self.unsafeGetTensorImpl();
  impl->set_storage_keep_dtype(std::move(storage));
  impl->set_storage_offset(storage_offset);
  impl->set_sizes_and_strides(sizes, strides);
  return self;
}
// Returns the quantization scheme reported by the tensor's quantizer.
QScheme qscheme_quant(const Tensor& self) {
  return get_qtensorimpl(self)->quantizer()->qscheme();
}
// Clones a quantized tensor into a freshly allocated quantized tensor that has
// the same qparams, then copies the data.
//
// optional_memory_format defaults to Contiguous. Preserve is approximated by
// self.suggest_memory_format() (see TODO). Only kPerTensorAffine and
// kPerChannelAffine are supported; any other qscheme throws.
Tensor quantized_clone(
    const Tensor& self,
    c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format =
      optional_memory_format.value_or(MemoryFormat::Contiguous);

  // TODO: To support all features of MemoryFormat::Preserve we need to add
  // _empty_affine_quantized_strided function and use it similarly to
  // Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat>
  // optional_memory_format) if (self.is_non_overlapping_and_dense()) ->
  // _empty_affine_quantized_strided
  if (memory_format == MemoryFormat::Preserve) {
    memory_format = self.suggest_memory_format();
  }

  const auto scheme = self.qscheme();
  const auto dst_options = self.options().memory_format(memory_format);

  Tensor dst;
  if (scheme == at::kPerTensorAffine) {
    dst = at::_empty_affine_quantized(
        self.sizes(),
        dst_options,
        self.q_scale(),
        self.q_zero_point(),
        c10::nullopt);
  } else if (scheme == at::kPerChannelAffine) {
    dst = at::_empty_per_channel_affine_quantized(
        self.sizes(),
        self.q_per_channel_scales(),
        self.q_per_channel_zero_points(),
        self.q_per_channel_axis(),
        dst_options,
        c10::nullopt);
  } else {
    TORCH_CHECK(
        false,
        "clone for quantized Tensor only works for "
        "PerTensorAffine and PerChannelAffine qscheme right now");
  }

  at::native::copy_(dst, self, false);
  return dst;
}
// Returns true if both tensors are quantized, have equal quantizers (per the
// quantizer's virtual equalTo), have equal sizes and element sizes, and hold
// byte-identical data (compared in contiguous layout).
// Throws (TORCH_CHECK) if either tensor is not on the CPU.
bool equal_quantized_cpu(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.device().type() == kCPU && other.device().type() == kCPU,
      "quantized_equal is implemented only for the QuantizedCPU backend");

  const bool both_quantized = self.is_quantized() && other.is_quantized();
  if (!both_quantized) {
    return false;
  }

  // Let the concrete Quantizer decide what "equal qparams" means.
  auto self_quantizer = get_qtensorimpl(self)->quantizer();
  auto other_quantizer = get_qtensorimpl(other)->quantizer();
  if (!self_quantizer->equalTo(other_quantizer)) {
    return false;
  }

  // Shape and element width must match before a byte comparison is meaningful.
  if (self.sizes() != other.sizes() ||
      self.element_size() != other.element_size()) {
    return false;
  }

  const auto self_contig = self.contiguous();
  const auto other_contig = other.contiguous();
  const auto nbytes = self.numel() * self.element_size();
  return memcmp(self_contig.data_ptr(), other_contig.data_ptr(), nbytes) == 0;
}
/* Calculate the quantization params for the activation tensor */
// Computes a per-tensor (scale, zero_point) from the input's observed min/max,
// targeting the [0, 255] integer range. reduce_range is forced to false when
// the global quantized engine is QNNPACK.
std::tuple<double, int64_t> _choose_qparams_per_tensor(
    const Tensor& self,
    bool reduce_range) {
  auto input_contig = self.contiguous();
  float x_min = input_contig.min().item<float>();
  float x_max = input_contig.max().item<float>();

  if (reduce_range && at::globalContext().qEngine() == at::QEngine::QNNPACK) {
    reduce_range = false;
  }

  auto q_params = quant_utils::ChooseQuantizationParams(
      /*min=*/x_min,
      /*max=*/x_max,
      /*qmin=*/0,
      /*qmax=*/255,
      /*preserve_sparsity=*/false,
      /*force_scale_power_of_two=*/false,
      /*reduce_range=*/reduce_range);

  return std::make_tuple(q_params.scale, q_params.zero_point);
}
// Fake-quantizes the first `numel` values of `input` into [0, 2^bit_width - 1]
// using the range [xmin, xmax], and returns the L2 norm of the quantization
// error.
//
// xmin and the derived scale are both rounded through at::Half before use.
// q_input (caller-provided, at least `numel` floats) receives the dequantized
// values, i.e. q * scale + xmin.
float calculate_quant_loss(
    const float* input,
    int numel,
    float xmin,
    float xmax,
    float* q_input,
    int bit_width) {
  xmin = static_cast<at::Half>(xmin);
  const float data_range = xmax - xmin;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const float qmax = (1 << bit_width) - 1;

  // A zero-width range falls back to a unit scale.
  float scale = 1.0f;
  if (data_range != 0) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    scale = static_cast<float>(static_cast<at::Half>(data_range / qmax));
  }
  const float inverse_scale = (scale == 0) ? 1.0f : 1.0f / scale;

  // TODO add FBGEMM kernel
  float squared_error = 0.0f;
  for (int idx = 0; idx < numel; ++idx) {
    const float clamped = std::max(
        0.0f,
        std::min<float>(
            std::nearbyint((input[idx] - xmin) * inverse_scale), qmax));
    const float dequantized = clamped * scale + xmin;
    q_input[idx] = dequantized;
    const float diff = input[idx] - dequantized;
    squared_error += diff * diff;
  }
  return std::sqrt(squared_error);
}
/*
Helper function to find the best min/max for a tensor to calculate qparams.
It uses a greedy approach to nudge the min and max and calculate the l2 norm
and tries to minimize the quant error by doing `torch.norm(x-fake_quant(x,s,z))`
Returns the optimized xmax and xmin value of the tensor.
*/
// Greedily shrinks [min, max] of the first `numel` elements of a float tensor
// (in contiguous order) to minimize the L2 fake-quantization error.
//
// input_tensor: float tensor; only its first `numel` elements are considered.
// numel:        number of leading elements to use; must be in (0, input.numel()].
// n_bins:       the initial range is divided into n_bins steps of size stepsize.
// ratio:        the search stops once the range is narrower than
//               n_bins * (1 - ratio) steps.
// bit_width:    quantized bit width used when computing the loss.
// Returns a tuple (xmax, xmin) of one-element tensors. Note the order: max first.
std::tuple<Tensor, Tensor> choose_qparams_optimized(
    const at::Tensor& input_tensor,
    int64_t numel,
    const int64_t n_bins,
    const double ratio,
    int64_t bit_width) {
  TORCH_CHECK(
      numel >= 0 && numel <= input_tensor.numel(),
      "numel is out of the bound of input tensor: numel ", numel,
      ", input_tensor.numel() ", input_tensor.numel());
  // min/max of an empty range is undefined (dereferencing the end iterator).
  TORCH_CHECK(numel > 0, "choose_qparams_optimized requires numel > 0, got ", numel);

  // Use the contiguous data for both the min/max search and the loss, so that
  // "the first numel elements" means the same thing for non-contiguous inputs.
  Tensor input_tensor_contig = input_tensor.contiguous();
  const float* input = input_tensor_contig.data_ptr<float>();
  float xmin = *std::min_element(input, input + numel);
  float xmax = *std::max_element(input, input + numel);

  float stepsize = (xmax - xmin) / n_bins;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  int min_bins = n_bins * (1.0 - (float) ratio);
  std::vector<float> q_input(numel);

  float loss =
      calculate_quant_loss(input, numel, xmin, xmax, q_input.data(), bit_width);
  float best_loss = loss;

  float cur_min = xmin;
  float cur_max = xmax;
  float cur_loss = loss;

  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  float thr = min_bins * stepsize;
  while (cur_min + thr < cur_max) {
    // Candidate 1: shrink the range from the left (raise the minimum).
    float loss1 = calculate_quant_loss(
        input, numel, cur_min + stepsize, cur_max, q_input.data(), bit_width);
    // Candidate 2: shrink the range from the right (lower the maximum).
    float loss2 = calculate_quant_loss(
        input, numel, cur_min, cur_max - stepsize, q_input.data(), bit_width);
    if (cur_loss < loss1 && cur_loss < loss2 && cur_loss < best_loss) {
      // found a local optima
      best_loss = cur_loss;
      xmin = cur_min;
      xmax = cur_max;
    }
    // Greedily take whichever shrink step gives the lower loss.
    if (loss1 < loss2) {
      cur_min = cur_min + stepsize;
      cur_loss = loss1;
    } else {
      cur_max = cur_max - stepsize;
      cur_loss = loss2;
    }
  }

  at::Tensor xmax_tensor = at::empty({1});
  at::Tensor xmin_tensor = at::empty({1});
  xmax_tensor[0] = xmax;
  xmin_tensor[0] = xmin;
  return std::make_tuple(xmax_tensor, xmin_tensor);
}
} // namespace native
} // namespace at