Skip to content

Commit

Permalink
add bf16 for Tile CUDA executor (#20854)
Browse files Browse the repository at this point in the history
### Description
add bf16 for Tile CUDA executor



### Motivation and Context
required change to support phimm model for ORT training
  • Loading branch information
frank-dong-ms authored Jun 17, 2024
1 parent 0babc33 commit 8aa2667
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ Do not modify directly.*
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|ThresholdedRelu|*in* X:**T**<br> *out* Y:**T**|10+|**T** = tensor(double), tensor(float), tensor(float16)|
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/tile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ ONNX_OPERATOR_KERNEL_EX(
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<MLFloat16>()})
DataTypeImpl::GetTensorType<MLFloat16>(),
DataTypeImpl::GetTensorType<BFloat16>()})
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int64_t>()),
Tile);

Expand Down

0 comments on commit 8aa2667

Please sign in to comment.