6
6
#include < type_traits>
7
7
#include < vector>
8
8
9
+ #include " core/common/gsl.h"
9
10
#include " core/common/type_utils.h"
10
11
#include " core/graph/graph.h"
11
12
#include " core/framework/framework_common.h"
@@ -195,13 +196,14 @@ class ModelTestBuilder {
195
196
return &graph_.GetOrCreateNodeArg (name, &type_proto);
196
197
}
197
198
198
- template <typename T>
199
- NodeArg* MakeInitializer (const std::vector<int64_t >& shape, const std::vector<T>& data) {
199
+ // Makes an initializer from the provided shape, element type, and raw_data bytes.
200
+ NodeArg* MakeInitializer (gsl::span<const int64_t > shape, ONNX_NAMESPACE::TensorProto_DataType elem_type,
201
+ gsl::span<const std::byte> raw_data) {
200
202
std::string name = graph_.GenerateNodeArgName (" constant" );
201
203
ONNX_NAMESPACE::TensorProto tensor_proto;
202
204
tensor_proto.set_name (name);
203
- tensor_proto.set_data_type (utils::ToTensorProtoElementType<T>() );
204
- tensor_proto.set_raw_data (data .data (), data .size () * sizeof (T ));
205
+ tensor_proto.set_data_type (elem_type );
206
+ tensor_proto.set_raw_data (raw_data .data (), raw_data .size ());
205
207
206
208
for (auto & dim : shape) {
207
209
tensor_proto.add_dims (dim);
@@ -212,6 +214,12 @@ class ModelTestBuilder {
212
214
return &graph_.GetOrCreateNodeArg (name, nullptr );
213
215
}
214
216
217
+ template <typename T>
218
+ NodeArg* MakeInitializer (const std::vector<int64_t >& shape, const std::vector<T>& data) {
219
+ gsl::span<const std::byte> raw_data = ReinterpretAsSpan<const std::byte, const T>(data);
220
+ return MakeInitializer (shape, utils::ToTensorProtoElementType<T>(), raw_data);
221
+ }
222
+
215
223
// Special handle for std::vector<bool>.
216
224
NodeArg* MakeInitializerBool (const std::vector<int64_t >& shape, const std::vector<bool >& data) {
217
225
std::string name = graph_.GenerateNodeArgName (" constant" );
@@ -342,6 +350,57 @@ class ModelTestBuilder {
342
350
return AddNode (" QuantizeLinear" , input_args, {output_arg}, domain, attributes);
343
351
}
344
352
353
+ static std::vector<std::byte> GetZeroPointBytes (int64_t zero_point, ONNX_NAMESPACE::TensorProto_DataType type) {
354
+ switch (type) {
355
+ case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
356
+ int8_t val = static_cast <int8_t >(zero_point);
357
+ auto span = ReinterpretAsSpan<const std::byte, const int8_t >(gsl::make_span (&val, 1 ));
358
+ return std::vector<std::byte>(span.begin (), span.end ());
359
+ }
360
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
361
+ uint8_t val = static_cast <uint8_t >(zero_point);
362
+ auto span = ReinterpretAsSpan<const std::byte, const uint8_t >(gsl::make_span (&val, 1 ));
363
+ return std::vector<std::byte>(span.begin (), span.end ());
364
+ }
365
+ case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
366
+ int16_t val = static_cast <int16_t >(zero_point);
367
+ auto span = ReinterpretAsSpan<const std::byte, const int16_t >(gsl::make_span (&val, 1 ));
368
+ return std::vector<std::byte>(span.begin (), span.end ());
369
+ }
370
+ case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
371
+ uint16_t val = static_cast <uint16_t >(zero_point);
372
+ auto span = ReinterpretAsSpan<const std::byte, const uint16_t >(gsl::make_span (&val, 1 ));
373
+ return std::vector<std::byte>(span.begin (), span.end ());
374
+ }
375
+ case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
376
+ int32_t val = static_cast <int32_t >(zero_point);
377
+ auto span = ReinterpretAsSpan<const std::byte, const int32_t >(gsl::make_span (&val, 1 ));
378
+ return std::vector<std::byte>(span.begin (), span.end ());
379
+ }
380
+ default :
381
+ ORT_THROW (" Unhandled zero-point type " , type, " ." );
382
+ }
383
+ }
384
+
385
+ // Adds a Q node with a runtime configurable zero-point type.
386
+ // Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
387
+ Node& AddQuantizeLinearNode (NodeArg* input_arg,
388
+ float input_scale,
389
+ int64_t input_zero_point,
390
+ ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
391
+ NodeArg* output_arg,
392
+ bool use_ms_domain = false ) {
393
+ std::vector<NodeArg*> input_args;
394
+ input_args.push_back (input_arg);
395
+ input_args.push_back (MakeScalarInitializer<float >(input_scale));
396
+
397
+ std::vector<std::byte> zp_bytes = GetZeroPointBytes (input_zero_point, zero_point_type);
398
+ input_args.push_back (MakeInitializer ({}, zero_point_type, zp_bytes));
399
+
400
+ std::string domain = use_ms_domain ? kMSDomain : " " ;
401
+ return AddNode (" QuantizeLinear" , input_args, {output_arg}, domain);
402
+ }
403
+
345
404
template <typename T>
346
405
typename std::enable_if<IsTypeDequantLinearCompatible<T>::value, Node&>::type
347
406
AddDequantizeLinearNode (NodeArg* input_arg,
@@ -400,6 +459,25 @@ class ModelTestBuilder {
400
459
return AddNode (" DequantizeLinear" , input_args, {output_arg}, domain, attributes);
401
460
}
402
461
462
+ // Adds a DQ node with a runtime configurable zero-point type.
463
+ // Takes in an int64_t zero_point value, which is large enough to represent all ONNX zero-point types.
464
+ Node& AddDequantizeLinearNode (NodeArg* input_arg,
465
+ float input_scale,
466
+ int64_t input_zero_point,
467
+ ONNX_NAMESPACE::TensorProto_DataType zero_point_type,
468
+ NodeArg* output_arg,
469
+ bool use_ms_domain = false ) {
470
+ std::vector<NodeArg*> input_args;
471
+ input_args.push_back (input_arg);
472
+ input_args.push_back (MakeScalarInitializer<float >(input_scale));
473
+
474
+ std::vector<std::byte> zp_bytes = GetZeroPointBytes (input_zero_point, zero_point_type);
475
+ input_args.push_back (MakeInitializer ({}, zero_point_type, zp_bytes));
476
+
477
+ std::string domain = use_ms_domain ? kMSDomain : " " ;
478
+ return AddNode (" DequantizeLinear" , input_args, {output_arg}, domain);
479
+ }
480
+
403
481
template <typename TWeight>
404
482
Node& AddQLinearConvNode (NodeArg* input_arg,
405
483
float input_scale,
0 commit comments