Skip to content

Commit

Permalink
Modify DML EP activation functions to be case insensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
dtang317 committed Nov 4, 2024
1 parent bbb38a9 commit 63c2109
Showing 1 changed file with 41 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,65 +123,86 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper

for (size_t i = 0; i < activations.size(); ++i)
{
const std::string& activationName = activations[i];
std::string& activationName = activations[i];
DML_OPERATOR_DESC& desc = descs[i];
ActivationOperatorDescUnion& activationDesc = m_activationDescs[i];
desc.Desc = &activationDesc;

if (activationName == AttrValue::ActivationRelu)
if (ActivationNameCompare(activationName, AttrValue::ActivationRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_RELU;
}
else if (activationName == AttrValue::ActivationLeakyRelu)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationLeakyRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU;
activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationThresholdedRelu)
else if (ActivationNameCompare(activationName, AttrValue::ActivationThresholdedRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU;
activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationTanh)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationTanh))
{
desc.Type = DML_OPERATOR_ACTIVATION_TANH;
}
else if (activationName == AttrValue::ActivationScaledTanh)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationScaledTanh))
{
desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH;
activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type);
activationDesc.scaledTanh.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationSigmoid)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationSigmoid))
{
desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID;
}
else if (activationName == AttrValue::ActivationSigmoidHard)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationSigmoidHard))
{
desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID;
activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type);
activationDesc.hardSigmoid.Beta = NextBeta(desc.Type);
}
else if (activationName == AttrValue::ActivationElu)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationElu))
{
desc.Type = DML_OPERATOR_ACTIVATION_ELU;
activationDesc.elu.Alpha = NextAlpha(desc.Type);
}
else if (activationName == AttrValue::ActivationSoftsign)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationSoftsign))
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN;
}
else if (activationName == AttrValue::ActivationSoftplus)
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationSoftplus))
{
desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS;
}
else if (ActivationNameCompare(activationName, AttrValue::ActivationLeakyRelu))
{
desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU;
}
else
{
ML_INVALID_ARGUMENT("Unsupported activation function");
}
}
}

// Case-insensitively compares an ONNX activation attribute value (e.g. "leakyrelu")
// against a canonical activation name constant (e.g. AttrValue::ActivationLeakyRelu).
//
// Parameters:
//   activationName - activation name as supplied in the model attribute.
//   attrValue      - null-terminated canonical activation name to compare against.
// Returns true when the two strings are equal ignoring ASCII case.
bool ActivationNameCompare(const std::string& activationName, const char* attrValue)
{
    // Different lengths can never match; this also bounds the loop below
    // so indexing attrValue[i] stays within the null-terminated string.
    if (activationName.size() != std::char_traits<char>::length(attrValue))
    {
        return false;
    }

    for (size_t i = 0; i < activationName.size(); ++i)
    {
        // Cast to unsigned char before std::tolower: passing a negative char
        // (possible for bytes >= 0x80 when plain char is signed) is undefined
        // behavior per the C standard, which cctype inherits.
        if (std::tolower(static_cast<unsigned char>(activationName[i])) !=
            std::tolower(static_cast<unsigned char>(attrValue[i])))
        {
            return false;
        }
    }
    return true;
}

void Compute(const MLOperatorKernelContext& kernelContext) override
{
// Assume that enough GPU work has been queued up after the RNN operator that it is worth
Expand Down

0 comments on commit 63c2109

Please sign in to comment.