Changes to support TNLRV3 fine-tuning (#4639)
* Added ReduceLogSumExp gradient and a corresponding test
* Fixed a type mismatch when calling the cuDNN reduce kernel
* Fixed the Python frontend to remove redundant states so the returned state dict matches the PyTorch state dict
Tixxx authored Jul 30, 2020
1 parent d8f3e46 commit f90a2d4
Showing 7 changed files with 128 additions and 136 deletions.
12 changes: 9 additions & 3 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -395,10 +395,12 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   }
 
   CudnnReduceDescriptor reduce_desc;
-  if (std::is_same<T, MLFloat16>::value)
+  if (std::is_same<T, MLFloat16>::value) {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
-  else
+  } else {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
+  }
+
   const auto one = Consts<CudaT>::One;
   const auto zero = Consts<CudaT>::Zero;
   CudnnTensor input_tensor;
@@ -437,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   } else {
     // Reduce max -- Max/Min will output indices data
     CudnnReduceDescriptor reduce_max_desc;
-    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
+    cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
+    if((std::is_same<T, MLFloat16>::value)) {
+      cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
+    }
+    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
     size_t indices_bytes_max = 0;
     CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc,
                                                        input_tensor, output_tensor, &indices_bytes_max));
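Per the commit message, this hunk fixes a type mismatch in the Max/Min path: for MLFloat16 inputs the max-reduction descriptor now uses CUDNN_DATA_FLOAT, matching what the hunk above already does for the main reduce descriptor. As a side illustration of why half-precision reductions are generally accumulated in float, here is a standalone NumPy sketch (not ONNX Runtime or cuDNN code):

```python
import numpy as np

x = np.full(100_000, 1.0, dtype=np.float16)

# Naive float16 accumulation: once the running sum reaches 2048, adding 1.0
# is lost to rounding, so the result ends up far from the true sum of 100000.
acc = np.float16(0.0)
for v in x:
    acc = np.float16(acc + v)
print(acc)                      # 2048.0

# Accumulating in float32 recovers the exact result.
print(x.sum(dtype=np.float32))  # 100000.0
```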
34 changes: 34 additions & 0 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -905,6 +905,40 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
return result;
}

// Reference computation: PyTorch's logsumexp_backward.
// Forward: O(0) = log(reduceSum(exp(x_i)))
// Gradient: dx_i = dy * exp(x_i) / reduceSum(exp(x_j))
// Self_Sub_Result = I(0) - O(0) = x_i - log(reduceSum(exp(x_j))) = log(exp(x_i) / reduceSum(exp(x_j))),
// so dx_i = exp(Self_Sub_Result) * dy.
// The gradient computation reuses the forward op's input and output, which makes it a recomputation candidate.
IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
bool keepdims = true;
if (attributes.find("keepdims") != attributes.end() &&
attributes.at("keepdims").has_i()) {
keepdims = static_cast<bool>(attributes.at("keepdims").i());
}

ArgDef grad = GO(0);
if (!keepdims && attributes.find("axes") != attributes.end()) {
std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
grad = IA("Unsqueezed_Grad");
result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));

result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
}
else {
result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
}

result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));

result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));

return result;
}

IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
std::vector<NodeDef> result;
auto attributes = SrcNodeAttributes();
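The GetReduceLogSumExpGradient builder added above implements dX = exp(I(0) - O(0)) * dY, unsqueezing O(0) and dY over the reduced axes when keepdims=0. Here is a small NumPy sketch that checks this formula against a finite-difference estimate (illustration only, using keepdims=1 so no unsqueeze is needed; names are local to the snippet and unrelated to the gradient-checker test added later in this commit):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 2))
axes = (0, 2)

# Forward: y = ReduceLogSumExp(x) over `axes`, keepdims = 1.
y = np.log(np.exp(x).sum(axis=axes, keepdims=True))

# Incoming gradient dL/dy, same shape as y.
dy = rng.standard_normal(y.shape)

# Analytic gradient, mirroring the builder: dx = exp(x - y) * dy,
# with y and dy broadcasting back over the reduced axes.
dx = np.exp(x - y) * dy

# Finite-difference check of a single input element.
eps = 1e-5
i = (1, 2, 0)
xp, xm = x.copy(), x.copy()
xp[i] += eps
xm[i] -= eps
yp = np.log(np.exp(xp).sum(axis=axes, keepdims=True))
ym = np.log(np.exp(xm).sum(axis=axes, keepdims=True))
dx_num = ((yp - ym) * dy).sum() / (2 * eps)  # chain rule through L = sum(dy * y)

print(abs(dx[i] - dx_num))  # tiny, at round-off level
```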
1 change: 1 addition & 0 deletions orttraining/orttraining/core/graph/gradient_builder.h
@@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
DECLARE_GRADIENT_BUILDER(GetDivGradient)
DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
DECLARE_GRADIENT_BUILDER(GetPowGradient)
DECLARE_GRADIENT_BUILDER(GetConcatGradient)
DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
Expand Down
@@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
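These two one-line changes (the declaration in gradient_builder.h and the registration above) wire the new builder into gradient-graph construction, keyed by the forward op type. A conceptual Python sketch of that name-to-builder lookup (hypothetical names only; ORT's actual mechanism is the C++ GradientBuilderRegistry shown here):

```python
# Hypothetical sketch of an op-type -> gradient-builder lookup table; this is
# not the ONNX Runtime API, just the pattern the registration lines express.
gradient_builders = {}

def register_gradient_builder(op_type, builder):
    gradient_builders[op_type] = builder

def get_reduce_log_sum_exp_gradient():
    # Stand-in for the C++ builder: it would emit the Unsqueeze/Sub/Exp/Mul nodes.
    return ["Unsqueeze", "Sub", "Exp", "Mul"]

register_gradient_builder("ReduceLogSumExp", get_reduce_log_sum_exp_gradient)

# When building the backward graph, the builder is found by the node's op type.
print(gradient_builders["ReduceLogSumExp"]())  # ['Unsqueeze', 'Sub', 'Exp', 'Mul']
```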
8 changes: 7 additions & 1 deletion orttraining/orttraining/python/ort_trainer.py
@@ -629,6 +629,8 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
self.world_size = world_size
self.use_mixed_precision = use_mixed_precision

self.original_model_state_keys = list(model.state_dict().keys()) if hasattr(model, 'state_dict') else []

self.session = None
self.device_ = device
self.gradient_accumulation_steps = gradient_accumulation_steps
@@ -773,7 +775,11 @@ def state_dict(self):
             if n.name not in torch_state:
                 torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
 
-        return torch_state
+        # Need to remove redundant initializers and name suffixes to map back to the original torch state names
+        torch_state_to_return = {key: torch_state[key] for key in self.original_model_state_keys if key in torch_state} \
+                                if self.original_model_state_keys \
+                                else torch_state
+        return torch_state_to_return
 
     def load_state_dict(self, state_dict, strict=False):
         # Note: It may happen ONNX model has not yet been initialized
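The session state can contain entries that do not exist in the original torch module (the comment in the diff calls them redundant initializers and suffixed names), so state_dict() now filters its result down to the keys captured from model.state_dict() at construction time. A standalone sketch of that filtering pattern (hypothetical tensors and key names, not the ORTTrainer API):

```python
import torch

# Keys captured from the original module, as the commit does in __init__ via
# self.original_model_state_keys = list(model.state_dict().keys()).
original_keys = ["fc.weight", "fc.bias"]

# State gathered from the training session: the original parameters plus
# extra entries (the names below are made up for illustration).
torch_state = {
    "fc.weight": torch.zeros(2, 2),
    "fc.bias": torch.zeros(2),
    "fc.weight_fp16": torch.zeros(2, 2),  # a suffixed, redundant copy
    "extra_training_state": torch.tensor(0),
}

# Same pattern as the new return statement in state_dict().
filtered = {k: torch_state[k] for k in original_keys if k in torch_state} if original_keys else torch_state
print(sorted(filtered))  # ['fc.bias', 'fc.weight']
```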
@@ -7,6 +7,9 @@

namespace onnxruntime {
namespace test {
using TestDataVector = std::tuple<std::vector<std::vector<TensorInfo>>,            // Input data
                                  std::vector<std::vector<TensorInfo>>,            // Output data
                                  std::vector<std::vector<onnx::AttributeProto>>>; // Attributes

class GradientOpTester : public OpTester {
public:
@@ -39,3 +42,4 @@ class GradientOpTester : public OpTester {
};
} // namespace test
} // namespace onnxruntime

204 changes: 72 additions & 132 deletions orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -38,6 +38,70 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
#define EXPECT_IS_TINY(max_error) \
EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)

static void RunReductionTests(const OpDef& op_def) {

TestDataVector test_data(
// Input X
{
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
{{4, 3, 2}},
},
// Input Y
{
{{1, 1, 1}},
{{}},
{{1, 3, 1}},
{{2}},
{{4, 1, 2}},
{{4, 3}},
{{4, 1, 2}},
{{4}}
},
// Attributes
{
// default
{},
// axes = [0, 1, 2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [0, 2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{0, 2})},
// axes = [0, 1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{0, 1}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [1], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{1}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [2], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{2}),
MakeAttribute("keepdims", int64_t(0))},
// axes = [-2], keepdims = 1
{MakeAttribute("axes", std::vector<int64_t>{-2}),
MakeAttribute("keepdims", int64_t(1))},
// axes = [-2, -1], keepdims = 0
{MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
MakeAttribute("keepdims", int64_t(0))}
});

GradientChecker<float, float, float> gradient_checker;

float max_error;

for (size_t i = 0; i < std::get<0>(test_data).size(); i++) {
max_error = 0;
gradient_checker.ComputeGradientError(op_def, std::get<0>(test_data)[i],
std::get<1>(test_data)[i], &max_error,
std::get<2>(test_data)[i]);
EXPECT_IS_TINY(max_error);
}
}

template <typename T>
void GenerateRandomDataWithOneHot(
std::vector<std::vector<float>>& x_datas,
@@ -426,149 +490,24 @@ TEST(GradientCheckerTest, GemmGrad) {
 }
 
 TEST(GradientCheckerTest, ReduceMeanGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceMean", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // TODO: Fix forward kernel behavior for default axes
-  // default axes, keepdims = 0
-  /*
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-  */
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2, -1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceSum", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-1, -3], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
 }
 
+TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
+  // Attribute axes supports negative values from opset 11.
+  OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
+
+  RunReductionTests(op_def);
+}
+
 #ifndef USE_CUDA
@@ -1960,3 +1899,4 @@ TEST(GradientCheckerTest, ExpandGrad) {
} // namespace onnxruntime

#endif // NDEBUG
