Skip to content

Commit

Permalink
DML functions always returning a value (#9485)
Browse files Browse the repository at this point in the history
* Always return a value
* @fdwr advice added
  • Loading branch information
gineshidalgo99 authored Oct 27, 2021
1 parent a2b3e6b commit 2d44bd5
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ onnx::OpSchema::FormalParameterOption AbiCustomRegistry::ConvertFormalParameterO

default:
THROW_HR(E_NOTIMPL);
return onnx::OpSchema::FormalParameterOption::Single;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
case DML_TENSOR_DATA_TYPE_INT64: return MLOperatorTensorDataType::Int64;
case DML_TENSOR_DATA_TYPE_FLOAT64: return MLOperatorTensorDataType::Double;

default: ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
default:
ML_INVALID_ARGUMENT("Unknown DML_TENSOR_DATA_TYPE.");
return MLOperatorTensorDataType::Undefined;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ struct ActivationOperatorDesc
case DML_OPERATOR_ACTIVATION_TANH: return { activationType, &params.tanh };
case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return { activationType, &params.thresholdedRelu };
case DML_OPERATOR_ACTIVATION_SHRINK: return { activationType, &params.shrink };
default: THROW_HR(E_INVALIDARG);
default:
THROW_HR(E_INVALIDARG);
return { activationType, &params.relu };
}
}
};
Expand Down Expand Up @@ -206,9 +208,9 @@ class StackAllocator

~DynamicBucket()
{
if (data)
if (this->data)
{
(void)VirtualFree(data, 0, MEM_RELEASE);
(void)VirtualFree(this->data, 0, MEM_RELEASE);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2159,6 +2159,7 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args

default:
THROW_HR(E_INVALIDARG);
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA;
case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA;

default: THROW_HR(E_INVALIDARG);
default:
THROW_HR(E_INVALIDARG);
return DML_ACTIVATION_RELU_OPERATOR_SCHEMA;
}
}

Expand Down Expand Up @@ -2052,7 +2054,11 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_SHRINK_OPERATOR_DESC*>(opDesc.Desc)));
default: THROW_HR(E_INVALIDARG);
default:
THROW_HR(E_INVALIDARG);
return AbstractOperatorDesc(
&DML_ACTIVATION_RELU_OPERATOR_SCHEMA,
GetFields(*static_cast<const DML_ACTIVATION_RELU_OPERATOR_DESC*>(opDesc.Desc)));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace Dml::GraphDescBuilder

assert(false);
THROW_HR(E_UNEXPECTED);
return node.OutputDefs()[0]->Name();
}

GraphDesc BuildGraphDesc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ size_t AttributeValue::ElementCount() const {
// The type is validated when default attributes are registered
assert(false);
THROW_HR(E_FAIL);
return 0;
}
}

Expand Down Expand Up @@ -238,6 +239,7 @@ ::MLOperatorTensorDataType ToMLTensorDataType(onnxruntime::MLDataType type) {
ML_TENSOR_TYPE_CASE(onnxruntime::MLFloat16);

THROW_HR(E_NOTIMPL);
return MLOperatorTensorDataType::Undefined;
}

#undef ML_TENSOR_TYPE_CASE
Expand All @@ -264,6 +266,7 @@ onnxruntime::MLDataType ToTensorDataType(::MLOperatorTensorDataType type) {
ML_TENSOR_TYPE_CASE(onnxruntime::MLFloat16);

THROW_HR(E_NOTIMPL);
return onnxruntime::DataTypeImpl::GetTensorType<float>();
}

::MLOperatorTensorDataType ToMLTensorDataType(onnx::TensorProto_DataType type) {
Expand Down Expand Up @@ -315,6 +318,7 @@ ::MLOperatorTensorDataType ToMLTensorDataType(onnx::TensorProto_DataType type) {

default:
THROW_HR(E_NOTIMPL);
return MLOperatorTensorDataType::Undefined;
}
}

Expand Down Expand Up @@ -389,6 +393,7 @@ std::string ToTypeString(MLOperatorEdgeDescription desc) {

default:
THROW_HR(E_NOTIMPL);
return "";
}
}

Expand Down Expand Up @@ -594,7 +599,7 @@ HRESULT OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetAttributeHelper(
uint32_t elementByteSize,
void* value) const {
using elementType_t = typename MLAttributeTypeTraits<T>::Type;
static_assert(!typename MLAttributeTypeTraits<T>::IsArray, "This function only works for simple non-array types.");
static_assert(!MLAttributeTypeTraits<T>::IsArray, "This function only works for simple non-array types.");
ML_CHECK_BOOL(sizeof(elementType_t) == elementByteSize);
THROW_IF_NOT_OK(m_impl->template GetAttr<elementType_t>(name, static_cast<elementType_t*>(value)));
return S_OK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
if (direction == AttrValue::DirectionBidirectional) { return DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL; }

ML_INVALID_ARGUMENT("Unsupported direction"); // throws
return DML_RECURRENT_NETWORK_DIRECTION_FORWARD;
}

void InitActivationDescs(const MLOperatorKernelCreationContext& kernelInfo, _Out_ std::vector<DML_OPERATOR_DESC>& descs, gsl::span<const std::string> defaultActivations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ namespace Dml
return *index;
}
ML_INVALID_ARGUMENT("Unknown interpolation mode");
return (DML_INTERPOLATION_MODE)0;
}

DML_DEPTH_SPACE_ORDER MapStringToDepthSpaceMode(std::string_view mode)
Expand All @@ -450,6 +451,7 @@ namespace Dml
return *index;
}
ML_INVALID_ARGUMENT("Unknown depth/space order");
return (DML_DEPTH_SPACE_ORDER)0;
}

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ inline size_t GetByteSizeFromMlDataType(MLOperatorTensorDataType tensorDataType)
case MLOperatorTensorDataType::Complex64: return 8;
case MLOperatorTensorDataType::Complex128: return 16;
case MLOperatorTensorDataType::Undefined:
default: THROW_HR(E_INVALIDARG);
default:
THROW_HR(E_INVALIDARG);
return 0;
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ namespace OperatorHelper
case MLOperatorTensorDataType::Complex64: return static_cast<int64_t>(*reinterpret_cast<const float*>(p)); // Read the real component.
case MLOperatorTensorDataType::Complex128: return static_cast<int64_t>(*reinterpret_cast<const double*>(p)); // Read the real component.
case MLOperatorTensorDataType::Undefined:
default: ML_INVALID_ARGUMENT("Unknown MLOperatorTensorDataType.");
default:
ML_INVALID_ARGUMENT("Unknown MLOperatorTensorDataType.");
return 0;
};
}

Expand All @@ -187,7 +189,9 @@ namespace OperatorHelper
case MLOperatorTensorDataType::Complex64: return static_cast<double>(*reinterpret_cast<const float*>(p)); // Read the real component.
case MLOperatorTensorDataType::Complex128: return static_cast<double>(*reinterpret_cast<const double*>(p)); // Read the real component.
case MLOperatorTensorDataType::Undefined:
default: ML_INVALID_ARGUMENT("Unknown MLOperatorTensorDataType.");
default:
ML_INVALID_ARGUMENT("Unknown MLOperatorTensorDataType.");
return 0.0;
};
}

Expand Down

0 comments on commit 2d44bd5

Please sign in to comment.