Skip to content

Commit

Permalink
register groupnorm21
Browse files Browse the repository at this point in the history
  • Loading branch information
dtang317 committed Nov 13, 2024
1 parent 6d7603f commit db06cac
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO_MS( 21, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, MatMulNBits, typeNameListTwo, supportedTypeListMatMulNBits, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMatMulNBits)},

// Operators that need to alias an input with an output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,7 @@ using ShapeInferenceHelper_BatchNormalization15 = BatchNormalizationHelper;
using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_GroupNorm21 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = SkipLayerNormHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,4 +498,9 @@ namespace OperatorHelper
static const int sc_sinceVer_DynamicQuantizeMatMul = 1;
} // namespace MsftOperatorSet1

namespace MsftOperatorSet21
{
static const int sc_sinceVer_GroupNorm = 21;
} // namespace MsftOperatorSet21

} // namespace OperatorHelper

0 comments on commit db06cac

Please sign in to comment.