[DML EP] Update DirectML Helper files (#21709)
### Description
Since 1.18.2 uses DML 1.15.1, the corresponding DirectML helper files need to be updated as well. This is not needed for 1.18.1, which is still on DML 1.14.0.



### Motivation and Context

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
4 people authored Aug 12, 2024
1 parent f4f4953 commit 9691af1
Showing 10 changed files with 226 additions and 62 deletions.
8 changes: 6 additions & 2 deletions docs/OperatorKernels.md
@@ -969,6 +969,7 @@ Do not modify directly.*
|||12+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Col2Im|*in* input:**T**<br> *in* image_shape:**tensor(int64)**<br> *in* block_shape:**tensor(int64)**<br> *out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Concat|*in* inputs:**T**<br> *out* concat_result:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||4+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1129,7 +1130,8 @@ Do not modify directly.*
|PRelu|*in* X:**T**<br> *in* slope:**T**<br> *out* Y:**T**|16+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||7+|**T** = tensor(float), tensor(float16)|
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *in* axes:**Tind**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* constant_value:**T**<br> *out* output:**T**<br><br>or<br><br>*in* data:**T**<br> *out* output:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||2+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1197,7 +1199,9 @@ Do not modify directly.*
|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Resize|*in* X:**T**<br> *in* scales:**tensor(float)**<br> *out* Y:**T**<br><br>or<br><br>*in* X:**T1**<br> *in* roi:**T2**<br> *in* scales:**tensor(float)**<br> *in* sizes:**tensor(int64)**<br> *out* Y:**T1**|19+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||18+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||11+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1,4 +1,4 @@
//---------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// This file is automatically generated. Please do not edit it directly.
@@ -241,6 +241,7 @@ DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value
{"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH},
{"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH},
{"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2},
{"DML_OPERATOR_RESAMPLE3", DML_OPERATOR_RESAMPLE3},
{"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1},
{"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1},
{"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION},
@@ -250,6 +251,9 @@ DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value
{"DML_OPERATOR_MULTIHEAD_ATTENTION1", DML_OPERATOR_MULTIHEAD_ATTENTION1},
{"DML_OPERATOR_QUANTIZE", DML_OPERATOR_QUANTIZE},
{"DML_OPERATOR_DEQUANTIZE", DML_OPERATOR_DEQUANTIZE},
{"DML_OPERATOR_ROI_ALIGN_GRAD", DML_OPERATOR_ROI_ALIGN_GRAD},
{"DML_OPERATOR_FOLD", DML_OPERATOR_FOLD},
{"DML_OPERATOR_UNFOLD", DML_OPERATOR_UNFOLD},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
@@ -369,6 +373,7 @@ DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value)
{"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE},
{"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION},
{"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC},
{"DML_PADDING_MODE_WRAP", DML_PADDING_MODE_WRAP},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
@@ -454,6 +459,7 @@ DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value
{"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1},
{"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2},
{"DML_FEATURE_LEVEL_6_3", DML_FEATURE_LEVEL_6_3},
{"DML_FEATURE_LEVEL_6_4", DML_FEATURE_LEVEL_6_4},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
@@ -29,6 +29,9 @@ union ActivationOperatorDescUnion
DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_DESC thresholdedRelu;
DML_ACTIVATION_SHRINK_OPERATOR_DESC shrink;
DML_ACTIVATION_GELU_OPERATOR_DESC gelu;
DML_ACTIVATION_SWISH_OPERATOR_DESC swish;
DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC hardSwish;
DML_ELEMENT_WISE_CLIP_OPERATOR_DESC clip;
};

struct ActivationOperatorDesc
@@ -46,7 +49,7 @@ struct ActivationOperatorDesc
case DML_OPERATOR_ACTIVATION_CELU: return { activationType, &params.celu };
case DML_OPERATOR_ACTIVATION_HARDMAX: return { activationType, &params.hardmax };
case DML_OPERATOR_ACTIVATION_HARDMAX1: return { activationType, &params.hardmax1 };
case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return { activationType, &params.hardSigmoid };
case DML_OPERATOR_ACTIVATION_IDENTITY: return { activationType, &params.identity };
case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return { activationType, &params.leakyRelu };
case DML_OPERATOR_ACTIVATION_LINEAR: return { activationType, &params.linear };
@@ -66,6 +69,9 @@ struct ActivationOperatorDesc
case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return { activationType, &params.thresholdedRelu };
case DML_OPERATOR_ACTIVATION_SHRINK: return { activationType, &params.shrink };
case DML_OPERATOR_ACTIVATION_GELU: return { activationType, &params.gelu };
case DML_OPERATOR_ACTIVATION_SWISH: return { activationType, &params.swish };
case DML_OPERATOR_ACTIVATION_HARD_SWISH: return { activationType, &params.hardSwish };
case DML_OPERATOR_ELEMENT_WISE_CLIP: return { activationType, &params.clip };
default:
ORT_THROW_HR(E_INVALIDARG);
return { activationType, &params.relu };
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

@@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
static constexpr auto ValueCount = 178;
static constexpr size_t ActivationFunctionCount = 26;
};

@@ -62,7 +62,7 @@ struct EnumTraits<DML_CONVOLUTION_DIRECTION>
template <>
struct EnumTraits<DML_PADDING_MODE>
{
static constexpr auto ValueCount = 5;
};

template <>
@@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
template <>
struct EnumTraits<DML_FEATURE_LEVEL>
{
static constexpr auto ValueCount = 15;
};

template <>
@@ -1023,6 +1023,12 @@ struct OperatorDescTraits<DML_RESAMPLE2_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE2;
};

template <>
struct OperatorDescTraits<DML_RESAMPLE3_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_RESAMPLE3;
};

template <>
struct OperatorDescTraits<DML_RESAMPLE_GRAD1_OPERATOR_DESC>
{
@@ -1053,6 +1059,18 @@ struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};

template <>
struct OperatorDescTraits<DML_FOLD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_FOLD;
};

template <>
struct OperatorDescTraits<DML_UNFOLD_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_UNFOLD;
};

template <>
struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC>
{
@@ -2073,6 +2091,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE2>
using DescType = DML_RESAMPLE2_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE3>
{
using DescType = DML_RESAMPLE3_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_RESAMPLE_GRAD1>
{
@@ -2103,6 +2127,18 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGE
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_FOLD>
{
using DescType = DML_FOLD_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_UNFOLD>
{
using DescType = DML_UNFOLD_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2>
{
@@ -2575,6 +2611,8 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RESAMPLE2:
return std::invoke(std::forward<Visitor>(visitor), DML_RESAMPLE2_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RESAMPLE3:
return std::invoke(std::forward<Visitor>(visitor), DML_RESAMPLE3_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_RESAMPLE_GRAD1:
return std::invoke(std::forward<Visitor>(visitor), DML_RESAMPLE_GRAD1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_DIAGONAL_MATRIX1:
@@ -2585,6 +2623,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_FOLD:
return std::invoke(std::forward<Visitor>(visitor), DML_FOLD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_UNFOLD:
return std::invoke(std::forward<Visitor>(visitor), DML_UNFOLD_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2:
return std::invoke(std::forward<Visitor>(visitor), DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MULTIHEAD_ATTENTION1:
@@ -2650,7 +2692,6 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
}
}


namespace StringifyHelpers
{
template <typename T>
@@ -2871,6 +2912,7 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH";
case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH";
case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
case DML_OPERATOR_RESAMPLE3: return "DML_OPERATOR_RESAMPLE3";
case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
@@ -2880,6 +2922,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_MULTIHEAD_ATTENTION1: return "DML_OPERATOR_MULTIHEAD_ATTENTION1";
case DML_OPERATOR_QUANTIZE: return "DML_OPERATOR_QUANTIZE";
case DML_OPERATOR_DEQUANTIZE: return "DML_OPERATOR_DEQUANTIZE";
case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD";
case DML_OPERATOR_FOLD: return "DML_OPERATOR_FOLD";
case DML_OPERATOR_UNFOLD: return "DML_OPERATOR_UNFOLD";
default:
assert(false);
return "<unknown>";
@@ -2971,6 +3016,7 @@ inline gsl::czstring ToString(DML_PADDING_MODE value)
case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE";
case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION";
case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC";
case DML_PADDING_MODE_WRAP: return "DML_PADDING_MODE_WRAP";
default:
assert(false);
return "<unknown>";
@@ -3036,6 +3082,7 @@ inline gsl::czstring ToString(DML_FEATURE_LEVEL value)
case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1";
case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2";
case DML_FEATURE_LEVEL_6_3: return "DML_FEATURE_LEVEL_6_3";
case DML_FEATURE_LEVEL_6_4: return "DML_FEATURE_LEVEL_6_4";
default:
assert(false);
return "<unknown>";
