[CPU EP] Int4 support for QuantizeLinear, DequantizeLinear, and Transpose #20362

Merged May 31, 2024 (79 commits)

Commits
40da679
Update include/framework/ with int4
adrianlizarraga Apr 17, 2024
e3e8a6b
Update onnxruntime_c_api.h with int4 type
adrianlizarraga Apr 17, 2024
5e01e0f
Update cpu_contrib_kernels.cc with int4 Q/DQ
adrianlizarraga Apr 17, 2024
44e0e02
Update framework/data_types.cc with int4 types
adrianlizarraga Apr 17, 2024
ce03eb2
Update onnxruntime map type info with int4 types
adrianlizarraga Apr 17, 2024
e159007
Update Tensor methods to calc int4 tensor data size
adrianlizarraga Apr 17, 2024
46c3d0d
Update function to map tensor_proto int4 to onnxruntime enum
adrianlizarraga Apr 17, 2024
583dae1
Update tensorprotoutils to handle int4 protobufs
adrianlizarraga Apr 17, 2024
0009a47
Add functions to map Int4x2 to an onnxruntime tensor type enum
adrianlizarraga Apr 17, 2024
d11a3d4
Update com.microsoft.DequantizeLinear schema to support int4 types fo…
adrianlizarraga Apr 17, 2024
208c403
Add option to disable int4 type in Conv and MatMul qdq node group sel…
adrianlizarraga Apr 17, 2024
f91ae69
Add DequantizeLinear with int4 support (missing block quant)
adrianlizarraga Apr 18, 2024
7323793
update transpose helper to support int4
adrianlizarraga Apr 18, 2024
8c79905
Update provider bridge with int4 apis
adrianlizarraga Apr 18, 2024
fc695ca
Update quantizer tool with int4
adrianlizarraga Apr 18, 2024
eeacb78
Remove duplicate enum
adrianlizarraga Apr 18, 2024
6f9da04
Remove MatMulSelector constructor arg
adrianlizarraga Apr 18, 2024
7e8c458
Remove unnecessary explicit template instantiation
adrianlizarraga Apr 18, 2024
3b7ed5f
Add static_cast
adrianlizarraga Apr 18, 2024
c7086a5
Add temporary CPU EP Int4 test (qdq conv)
adrianlizarraga Apr 18, 2024
34dfa17
Update operator docs
adrianlizarraga Apr 18, 2024
e7bec9c
Update testing version of tensorprotoutils with int4 helpers
adrianlizarraga Apr 18, 2024
cd8912e
Run lintrunner
adrianlizarraga Apr 18, 2024
4cf3a75
Fix api to create int4 ort value
adrianlizarraga Apr 18, 2024
ca785c2
Wrap long lines in tensorprotoutils
adrianlizarraga Apr 18, 2024
f87e785
Add operator unit tests for Dequant int4/uint4
adrianlizarraga Apr 18, 2024
d028f2f
Remove comments
adrianlizarraga Apr 18, 2024
24cc617
Add QuantizeLinear int4 impl
adrianlizarraga Apr 18, 2024
f35b09e
Update operator docs
adrianlizarraga Apr 18, 2024
e33f198
Disable potentially bugged onnx tests
adrianlizarraga Apr 18, 2024
10f28aa
Add TODO username
adrianlizarraga Apr 18, 2024
de1ded4
Fix warning as error and clean up
adrianlizarraga Apr 19, 2024
378718e
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga Apr 20, 2024
746312b
Mlas kernels to quantize int4 (not blocked). Missing powerpc
adrianlizarraga Apr 20, 2024
2935f79
branchless update of 4-bit element
adrianlizarraga Apr 21, 2024
807537c
more branchless update of int4 lane
adrianlizarraga Apr 21, 2024
a36a128
Fix cast warning as error
adrianlizarraga Apr 21, 2024
bc44557
Remove decrement of N
adrianlizarraga Apr 22, 2024
f40992d
Clean up Int4x2 class
adrianlizarraga Apr 22, 2024
d0e17e2
Github linter fixes
adrianlizarraga Apr 22, 2024
6568d48
Remove temporary unittest
adrianlizarraga Apr 22, 2024
b8d5869
Case statement missing :
adrianlizarraga Apr 22, 2024
67b8cff
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga May 2, 2024
934a063
Add powerpc int4 quant kernel
adrianlizarraga May 2, 2024
4853853
Try to exclude MLAS C++ code from Github's cpplint workflow. MLAS has…
adrianlizarraga May 3, 2024
5119384
Remove backslash from cpplint flags
adrianlizarraga May 3, 2024
279a50d
Template on sign instead of unpacked type
adrianlizarraga May 3, 2024
488117f
Use typename
adrianlizarraga May 3, 2024
4c862bd
Add utils to compute tensor storage size and num elements for sub-byt…
adrianlizarraga May 7, 2024
27c554e
Add more uses of Tensor::CalcTensorStorageSize()
adrianlizarraga May 7, 2024
e2ac5d2
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga May 7, 2024
2e8c3b9
Add new PrimitiveDataTypeBase methods to provider api
adrianlizarraga May 7, 2024
f26d885
Remove SparseTensor registrations for int4 types
adrianlizarraga May 7, 2024
f3fdc2e
Support Transpose int4
adrianlizarraga May 7, 2024
8adbb4a
Revert to default types for older transpose opsets
adrianlizarraga May 7, 2024
0ac8427
Update op docs
adrianlizarraga May 7, 2024
da80e3a
Add comments
adrianlizarraga May 7, 2024
f72c3d5
Exclude TRT from int4 transpose test
adrianlizarraga May 7, 2024
6605512
Test C API for creating int4 OrtValues
adrianlizarraga May 7, 2024
ea96d09
Add comment to qmath macro for defining the int4 quantization functions
adrianlizarraga May 7, 2024
aa029ba
Clean up tensorprotoutils macros
adrianlizarraga May 7, 2024
2b9f53a
Use CalcNumInt4Pairs()
adrianlizarraga May 7, 2024
b89e6e9
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga May 7, 2024
84c6c72
Add comment reference to ONNX PR that fixes int4 q/dq onnx node tests
adrianlizarraga May 7, 2024
7bee6f1
Add another use of CalcNumInt4Pairs() in base_tester.h
adrianlizarraga May 7, 2024
1435222
Add another use of CalcNumInt4Pairs() in cpu quantize_linear tests
adrianlizarraga May 7, 2024
5d8b029
Temporarily disable the block_size attribute for Q/DQ ops
adrianlizarraga May 7, 2024
ca2a1c5
Disable QDQ fusions for int4
adrianlizarraga May 7, 2024
0efc352
Add typename
adrianlizarraga May 7, 2024
7452329
Add python quantization unit test for int4 qdq
adrianlizarraga May 7, 2024
4b4df98
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga May 7, 2024
8eec173
Merge latest main branch
adrianlizarraga May 29, 2024
6c06bfb
Review comments
adrianlizarraga May 30, 2024
09c11c6
Save one instruction in MlasSetInt4Element()
adrianlizarraga May 30, 2024
12d7d0e
Use workaround to ensure quant tool stores negative INT4 weights pack…
adrianlizarraga May 30, 2024
43c7bf1
Add int4 qdq quantization tool test
adrianlizarraga May 30, 2024
d4a05b7
Check opset when using int4 types with quant tool
adrianlizarraga May 30, 2024
27301e6
Check opset version when creating qdq config for QNN
adrianlizarraga May 30, 2024
8eea1c1
Merge branch 'main' into adrianl/dq-transpose-int4
adrianlizarraga May 30, 2024
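Several commits above (2935f79, 807537c, 09c11c6) mention a branchless update of a packed 4-bit lane. As a rough illustration of that technique only — not the actual MlasSetInt4Element() implementation — a packed nibble can be written without branching on its position:

```cpp
#include <cstddef>
#include <cstdint>

// Sets the 4-bit element at `index` in a packed buffer. Assumes (for
// illustration) that even indices occupy the low nibble of each byte.
inline void SetInt4Element(uint8_t* data, size_t index, uint8_t value) {
  const size_t byte_idx = index >> 1;       // two 4-bit elements per byte
  const unsigned shift = (index & 1) << 2;  // 0 for even index, 4 for odd
  const uint8_t mask = static_cast<uint8_t>(0xF << shift);
  data[byte_idx] = static_cast<uint8_t>((data[byte_idx] & ~mask) |
                                        ((value & 0xF) << shift));
}
```

Both the lane selection and the read-modify-write are unconditional, which is what lets the quantization kernels avoid a per-element branch.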
5 changes: 4 additions & 1 deletion .github/workflows/lint.yml
@@ -93,7 +93,10 @@ jobs:
github_token: ${{ secrets.github_token }}
reporter: github-pr-check
level: warning
-flags: --linelength=120 --exclude=java/src/main/native/*.c
+flags: --linelength=120
+  --exclude=java/src/main/native/*.c
+  --exclude=onnxruntime/core/mlas/inc/*
+  --exclude=onnxruntime/core/mlas/lib/*
filter: "-runtime/references"

lint-js:
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
@@ -1373,7 +1373,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Type Constraints

<dl>
-<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32)</dt>
+<dt><tt>T1</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int32), tensor(int4), tensor(uint4)</dt>
<dd>Constrain 'x' and 'x_zero_point' to 8-bit integer tensors, 16-bit integer tensors, or 32-bit signed integer tensors.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'y', 'x_scale' to float tensors.</dd>
@@ -4832,7 +4832,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dl>
<dt><tt>T1</tt> : tensor(float16), tensor(float)</dt>
<dd>Constrain 'x', 'y_scale' to float tensors.</dd>
-<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16)</dt>
+<dt><tt>T2</tt> : tensor(int8), tensor(uint8), tensor(int16), tensor(uint16), tensor(int4), tensor(uint4)</dt>
<dd>Constrain 'y_zero_point' and 'y' to 8-bit and 16-bit integer tensors.</dd>
</dl>

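For context, the int4 and uint4 additions above follow the standard ONNX affine Q/DQ semantics, saturating to [-8, 7] for int4 and [0, 15] for uint4. A minimal per-element sketch of that math (not code from this PR):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// QuantizeLinear for one signed int4 element: round-to-nearest-even,
// add the zero point, then saturate to the 4-bit signed range.
inline int8_t QuantizeInt4(float x, float scale, int8_t zero_point) {
  const float q = std::nearbyintf(x / scale) + static_cast<float>(zero_point);
  return static_cast<int8_t>(std::clamp(q, -8.0f, 7.0f));
}

// DequantizeLinear for one element: y = (x - zero_point) * scale.
inline float DequantizeInt4(int8_t x, float scale, int8_t zero_point) {
  return static_cast<float>(x - zero_point) * scale;
}
```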
10 changes: 5 additions & 5 deletions docs/OperatorKernels.md
@@ -87,7 +87,7 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
-|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
+|DequantizeLinear|*in* x:**T**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *out* y:**tensor(float)**<br><br>or<br><br>*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|21+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||[19, 20]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int32), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)|
@@ -259,7 +259,7 @@ Do not modify directly.*
|||[7, 11]|**T** = tensor(double), tensor(float)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**<br> *in* a_scale:**TS**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**TS**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**TS**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**<br><br>or<br><br>*in* a:**T1**<br> *in* a_scale:**tensor(float)**<br> *in* a_zero_point:**T1**<br> *in* b:**T2**<br> *in* b_scale:**tensor(float)**<br> *in* b_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**<br><br>or<br><br>*in* x:**T1**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|||[19, 20]|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)|
|||[13, 18]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|||[10, 12]|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
@@ -418,7 +418,7 @@ Do not modify directly.*
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float)|
|||[1, 9]|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float)|
-|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Transpose|*in* data:**T**<br> *out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
@@ -468,7 +468,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
-|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint8)<br/> **T2** = tensor(float)|
+|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
@@ -504,7 +504,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
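One consequence visible in the Transpose row above: 4-bit elements are not byte-addressable, so transposing packed int4 data cannot be done by moving bytes; each element has to be unpacked from its source nibble and repacked at its destination nibble. A simplified 2-D sketch of the idea (helper names are illustrative; the PR's transpose helper generalizes this to arbitrary permutations):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reads the 4-bit element at `index` (even index = low nibble, assumed layout).
inline uint8_t GetInt4(const uint8_t* data, size_t index) {
  return (data[index >> 1] >> ((index & 1) << 2)) & 0xF;
}

// Transposes a rows x cols matrix of packed 4-bit elements.
std::vector<uint8_t> TransposeInt4(const uint8_t* src, size_t rows, size_t cols) {
  std::vector<uint8_t> dst((rows * cols + 1) / 2, 0);  // ceil(n/2) bytes
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      const size_t dst_idx = c * rows + r;
      const unsigned shift = (dst_idx & 1) << 2;
      dst[dst_idx >> 1] |= static_cast<uint8_t>(GetInt4(src, r * cols + c) << shift);
    }
  }
  return dst;
}
```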
57 changes: 42 additions & 15 deletions include/onnxruntime/core/framework/data_types.h
@@ -15,6 +15,7 @@
#include "core/framework/endian.h"
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/framework/int4.h"
#include "core/graph/onnx_protobuf.h"
#include "core/framework/to_tensor_proto_element_type.h"

@@ -280,7 +281,8 @@ struct IsAnyOf<T, H, Tail...> {
template <typename T>
struct IsTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, uint16_t, int16_t,
int32_t, int64_t, std::string, bool, MLFloat16,
-double, uint32_t, uint64_t, BFloat16
+double, uint32_t, uint64_t, BFloat16,
+Int4x2, UInt4x2
#if !defined(DISABLE_FLOAT8_TYPES)
,
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
@@ -917,7 +919,8 @@ class OpaqueType : public NonTensorType<T> {
* Base class for primitive Tensor contained types
*
* \details This class contains an integer constant that can be
-* used for input data type dispatching
+* used for input data type dispatching. This class also stores the number of subelements per size units.
+* Example: For int4, the size unit is 1 byte and the number of subelements is 2.
*
*/
class PrimitiveDataTypeBase : public DataTypeImpl {
@@ -934,12 +937,21 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
return data_type_;
}

+int32_t GetNumSubElems() const {
+return num_sub_elems_;
+}
+
+bool HasSubElems() const {
+return num_sub_elems_ > 1;
+}

protected:
-PrimitiveDataTypeBase(size_t size, int32_t data_type)
-: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type} {}
+PrimitiveDataTypeBase(size_t size, int32_t data_type, int32_t num_sub_elems)
+: DataTypeImpl{GeneralType::kPrimitive, size}, data_type_{data_type}, num_sub_elems_{num_sub_elems} {}

private:
const int32_t data_type_;
+const int32_t num_sub_elems_;  // > 1 for subbyte primitives, 1 for normal primitives.
};

/**
@@ -965,9 +977,9 @@ class PrimitiveDataType : public PrimitiveDataTypeBase {
}

private:
-PrimitiveDataType()
+explicit PrimitiveDataType(int32_t num_sub_elems)
: PrimitiveDataTypeBase{sizeof(T),
-utils::ToTensorProtoElementType<T>()} {
+utils::ToTensorProtoElementType<T>(), num_sub_elems} {
}
};

@@ -1074,15 +1086,30 @@ inline const PrimitiveDataTypeBase* DataTypeImpl::AsPrimitiveDataType() const {
return SequenceTensorType<ELEM_TYPE>::Type(); \
}

-#define ORT_REGISTER_PRIM_TYPE(TYPE) \
-template <> \
-MLDataType PrimitiveDataType<TYPE>::Type() { \
-static PrimitiveDataType<TYPE> prim_data_type; \
-return &prim_data_type; \
-} \
-template <> \
-MLDataType DataTypeImpl::GetType<TYPE>() { \
-return PrimitiveDataType<TYPE>::Type(); \
+#define ORT_REGISTER_PRIM_TYPE(TYPE) \
+template <> \
+MLDataType PrimitiveDataType<TYPE>::Type() { \
+static PrimitiveDataType<TYPE> prim_data_type(1); \
+return &prim_data_type; \
+} \
+template <> \
+MLDataType DataTypeImpl::GetType<TYPE>() { \
+return PrimitiveDataType<TYPE>::Type(); \
}

+// Registers a subbyte primitive.
+// Examples:
+// - Int4x2 stores 2 packed 4-bit elements in 1 byte: ORT_*_SUBBYTE_TYPE(Int4x2, 2)
+// - [not supported] Int3x8 could store 8 packed 3-bit elements in 3 bytes: ORT_*_SUBBYTE_TYPE(Int3x8, 8)
+#define ORT_REGISTER_PRIM_SUBBYTE_TYPE(TYPE, NUM_SUB_ELEMS) \
+template <> \
+MLDataType PrimitiveDataType<TYPE>::Type() { \
+static PrimitiveDataType<TYPE> prim_data_type(NUM_SUB_ELEMS); \
+return &prim_data_type; \
+} \
+template <> \
+MLDataType DataTypeImpl::GetType<TYPE>() { \
+return PrimitiveDataType<TYPE>::Type(); \
+}

#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \
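The num_sub_elems bookkeeping added above is what drives the storage-size helpers referenced in the commit list (Tensor::CalcTensorStorageSize(), CalcNumInt4Pairs()): the byte size of a sub-byte tensor is a ceiling division of the element count by the number of sub-elements per size unit. A sketch of that relationship (the function name and signature here are illustrative, not the actual Tensor API):

```cpp
#include <cstddef>

// Bytes needed for `num_elems` elements when `num_sub_elems` of them share
// one size unit of `unit_size` bytes. For Int4x2, unit_size == 1 and
// num_sub_elems == 2, so a 15-element tensor needs (15 + 1) / 2 = 8 bytes.
inline size_t CalcStorageSize(size_t num_elems, size_t unit_size, size_t num_sub_elems) {
  const size_t num_units = (num_elems + num_sub_elems - 1) / num_sub_elems;
  return num_units * unit_size;
}
```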
24 changes: 24 additions & 0 deletions include/onnxruntime/core/framework/data_types_internal.h
@@ -93,6 +93,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
+case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+function<Int4x2>(__VA_ARGS__); \
+break; \
+case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+function<UInt4x2>(__VA_ARGS__); \
+break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -153,6 +159,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: \
retval = function<Float8E5M2FNUZ>(__VA_ARGS__); \
break; \
+case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+retval = function<Int4x2>(__VA_ARGS__); \
+break; \
+case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+retval = function<UInt4x2>(__VA_ARGS__); \
+break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -203,6 +215,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
function<BFloat16>(__VA_ARGS__); \
break; \
+case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+function<Int4x2>(__VA_ARGS__); \
+break; \
+case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+function<UInt4x2>(__VA_ARGS__); \
+break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
@@ -251,6 +269,12 @@ namespace utils {
case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: \
retval = function<BFloat16>(__VA_ARGS__); \
break; \
+case ONNX_NAMESPACE::TensorProto_DataType_INT4: \
+retval = function<Int4x2>(__VA_ARGS__); \
+break; \
+case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \
+retval = function<UInt4x2>(__VA_ARGS__); \
+break; \
default: \
ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \
}
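The new INT4/UINT4 cases above let the type-dispatch macros in this header instantiate a caller's function template for Int4x2/UInt4x2 instead of falling through to ORT_ENFORCE. A hedged usage sketch (assuming the surrounding DispatchOnTensorType macro, whose expansion is the switch shown above):

```cpp
#include <cstdint>
#include <iostream>

#include "core/framework/data_types_internal.h"

// A function template for the dispatcher to instantiate. With the new
// cases, an INT4 tensor type resolves to T = Int4x2 (one packed byte).
template <typename T>
void ReportStorageUnit() {
  std::cout << "storage unit: " << sizeof(T) << " byte(s)\n";
}

void Example(int32_t tensor_type) {
  // Expands to a switch like the one above; INT4/UINT4 now dispatch
  // instead of triggering ORT_ENFORCE(false, "Unknown tensor type...").
  DispatchOnTensorType(tensor_type, ReportStorageUnit);
}
```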