diff --git a/oneflow/core/autograd/gradient_funcs/nll.cpp b/oneflow/core/autograd/gradient_funcs/nll.cpp
index 76a946dd4b0..713b5190b53 100644
--- a/oneflow/core/autograd/gradient_funcs/nll.cpp
+++ b/oneflow/core/autograd/gradient_funcs/nll.cpp
@@ -24,6 +24,7 @@ namespace one {
 struct NLLCaptureState : public AutoGradCaptureState {
   bool requires_grad = false;
   int64_t ignore_index = -100;
+  std::string reduction = "none";
 };
 
 class NLLGradFunction : public OpExprGradFunction<NLLCaptureState> {
@@ -53,11 +54,13 @@ Maybe<void> NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& inputs,
   ComposedAttrMap composed_attrs(attrs, base_attrs_);
   ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>("ignore_index"));
+  ctx->reduction = JUST(composed_attrs.GetAttr<std::string>("reduction"));
   ctx->SaveTensorForBackward(input);                      // input
   ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1)));  // target
   if (inputs.size() == 3) {
     ctx->SaveTensorForBackward(inputs[2]);  // weight
   }
+  ctx->SaveTensorForBackward(outputs[3]);  // total_weight
   return Maybe<void>::Ok();
 }
 
@@ -65,25 +68,31 @@ Maybe<void> NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
                                    TensorTuple* in_grads) const {
   if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
 
-  CHECK_EQ_OR_RETURN(out_grads.size(), 2);  // NOLINT(maybe-need-error-msg)
+  CHECK_EQ_OR_RETURN(out_grads.size(), 4);  // NOLINT(maybe-need-error-msg)
   CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2)
       << Error::RuntimeError()
       << "The number of saved tensors is expected to be greater than or equal to 2, but got "
       << ctx->SavedTensors().size();
-  const auto& out_grad = out_grads[0];
+
+  // outputs: output, out_weight, reduced_out, total_weight
+  const auto& out_grad = out_grads[0];          // for reduction="none"
+  const auto& reduced_out_grad = out_grads[2];  // for reduction="mean"/"sum"
   const auto& input = ctx->SavedTensors()[0];
   const auto& target = ctx->SavedTensors()[1];
 
   in_grads->resize(ctx->SavedTensors().size());
 
-  if (ctx->SavedTensors().size() == 2) {
+  if (ctx->SavedTensors().size() == 3) {  // no weight
+    const auto& total_weight = ctx->SavedTensors()[2];
     JUST(VectorAt(*in_grads, 0)) =
-        JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index));
-  } else {
-    // has weight
-    auto weight = JUST(VectorAt(ctx->SavedTensors(), 2));
+        JUST(functional::NLLGrad(out_grad, reduced_out_grad, input, target, total_weight, NullOpt,
+                                 ctx->ignore_index, ctx->reduction));
+  } else if (ctx->SavedTensors().size() == 4) {  // has weight
+    const auto& weight = ctx->SavedTensors()[2];
+    const auto& total_weight = ctx->SavedTensors()[3];
     JUST(VectorAt(*in_grads, 0)) =
-        JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index));
+        JUST(functional::NLLGrad(out_grad, reduced_out_grad, input, target, total_weight, weight,
+                                 ctx->ignore_index, ctx->reduction));
   }
 
   return Maybe<void>::Ok();
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 5c6d8148ac1..ac454fd8ba7 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -1355,7 +1355,7 @@
   bind_python: True
 
 - name: "nll_grad"
-  signature: "Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad"
+  signature: 'Tensor(Tensor out_grad, Tensor reduced_out_grad, Tensor input, Tensor target, Tensor total_weight, Tensor weight=None, Int64 ignore_index, String reduction="none") => NLLGrad'
   bind_python: False
 
 - name: "binary_cross_entropy_loss"
diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp
index aef7ef62a3b..9c1a4b89cee 100644
--- a/oneflow/core/functional/impl/array_functor.cpp
+++ b/oneflow/core/functional/impl/array_functor.cpp
@@ -588,7 +588,27 @@ class ArgWhereFunctor {
                                 const Symbol<DType>& dtype) const {
     auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype");
     attrs.SetAllAttrs(dtype->data_type());
+#ifdef WITH_NPU
+    auto device_type = DeviceType::kCPU;
+    if (x->is_global()) {
+      device_type = JUST(x->parallel_desc())->device_type();
+    } else {
+      device_type = JUST(x->device())->enum_type();
+    }
+    if (device_type == DeviceType::kNPU) {
+      // NOTE: use cpu argwhere when device="npu"
+      auto cpu_tensor = JUST(one::functional::To(x, "cpu"));
+      auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {cpu_tensor}, attrs));
+      for (int i = 0; i < result->size(); ++i) {
+        (*result)[i] = JUST(one::functional::To((*result)[i], "npu"));
+      }
+      return result;
+    } else {
+      return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
+    }
+#else
     return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
+#endif  // WITH_NPU
   }
 
  private:
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index f2d1a8a029b..4c40f75dad0 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -930,9 +930,9 @@ class SkipLayerNormFunctor {
               std::tuple(has_skip, has_gamma, has_beta, has_bias), op_expr));
           }  // has_bias
-        }    // has_beta
-      }      // has_gamma
-    }        // has_skip
+        }  // has_beta
+      }  // has_gamma
+    }  // has_skip
   }
 
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
@@ -1170,8 +1170,8 @@ class SkipRMSNormFunctor {
           ops_.insert(std::pair<std::tuple<bool, bool, bool>, std::shared_ptr<OpExpr>>(
               std::tuple(has_weight, has_skip, has_bias), op_expr));
         }  // has_bias
-      }    // has_skip
-    }      // has_weight
+      }  // has_skip
+    }  // has_weight
   }
 
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
@@ -1477,7 +1477,7 @@ class MaxUnpoolNDFunctor {
                          .Input("x")
                          .Input("indices")
                          .Output("y")
-                         .Build())){};
+                         .Build())) {};
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                            const std::shared_ptr<one::Tensor>& indices,
                            const std::vector<int32_t>& kernel_size,
@@ -1819,6 +1819,8 @@ class NLLLossFunctor {
                          .Input("target")
                          .Output("output")
                          .Output("out_weight")
+                         .Output("reduced_out")
+                         .Output("total_weight")
                          .Build());
 
     op_weight_ = CHECK_JUST(one::OpBuilder("nll")
@@ -1827,6 +1829,8 @@
                                 .Input("weight")
                                 .Output("output")
                                 .Output("out_weight")
+                                .Output("reduced_out")
+                                .Output("total_weight")
                                 .Build());
   }
@@ -1875,8 +1879,10 @@
       target_ = target;
     }
 
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
-    attrs.SetAllAttrs(ignore_index);
+    // auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
+    // attrs.SetAllAttrs(ignore_index);
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
+    attrs.SetAllAttrs(ignore_index, reduction);
 
     std::shared_ptr<TensorTuple> nll_result;
     if (weight) {
@@ -1890,9 +1896,14 @@
     if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); }
 
     if (reduction == "none") { return output; }
-
+#ifdef WITH_NPU
+    // npu device
+    if (!input->is_cpu()) {
+      auto reduced_out = JUST(VectorAt(*nll_result, 2));
+      if (reduction == "sum" || reduction == "mean") { return reduced_out; }
+    }
+#endif  // WITH_NPU
     auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));
-
     if (reduction == "sum") { return sum; }
 
     auto total_weight =
@@ -1915,6 +1926,8 @@ class CrossEntropyFunctor {
                              .Input("target")
                              .Output("output")
                              .Output("out_weight")
+                             .Output("reduced_out")
+                             .Output("total_weight")
                              .Build());
 
     op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll")
@@ -1923,6 +1936,8 @@
                                     .Input("weight")
                                     .Output("output")
                                     .Output("out_weight")
+                                    .Output("reduced_out")
+                                    .Output("total_weight")
                                     .Build());
   }
 
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
@@ -1959,8 +1974,8 @@
     const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1));
 
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
-    attrs.SetAllAttrs(ignore_index);
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
+    attrs.SetAllAttrs(ignore_index, reduction);
 
     std::shared_ptr<TensorTuple> nll_result;
     if (weight) {
@@ -1974,6 +1989,13 @@
     output = JUST(functional::Reshape(output, *target_shape));
     if (reduction == "none") { return output; }
 
+#ifdef WITH_NPU
+    if (!input->is_cpu()) {
+      auto reduced_out = JUST(VectorAt(*nll_result, 2));
+      // auto total_weight = JUST(VectorAt(*nll_result, 3));
+      if (reduction == "sum" || reduction == "mean") { return reduced_out; }
+    }
+#endif  // WITH_NPU
     auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));
     if (reduction == "sum") { return sum; }
diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp
index e833c98d2ae..02dd4bba427 100644
--- a/oneflow/core/functional/impl/nn_grad_functor.cpp
+++ b/oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -451,32 +451,43 @@ class NLLGradFunctor {
   NLLGradFunctor() {
     op_ = CHECK_JUST(one::OpBuilder("nll_grad")
                          .Input("out_grad")
+                         .Input("reduced_out_grad")
                          .Input("input")
                          .Input("target")
+                         .Input("total_weight")
                          .Output("in_grad")
                          .Build());
 
     op_weight_ = CHECK_JUST(one::OpBuilder("nll_grad")
                                 .Input("out_grad")
+                                .Input("reduced_out_grad")
                                 .Input("input")
                                 .Input("target")
+                                .Input("total_weight")
                                 .Input("weight")
                                 .Output("in_grad")
                                 .Build());
   }
 
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_grad,
+                           const std::shared_ptr<one::Tensor>& reduced_out_grad,
                            const std::shared_ptr<one::Tensor>& input,
                            const std::shared_ptr<one::Tensor>& target,
-                           const Optional<one::Tensor>& weight, const int64_t ignore_index) const {
-    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
-    attrs.SetAllAttrs(ignore_index);
-
+                           const std::shared_ptr<one::Tensor>& total_weight,
+                           const Optional<one::Tensor>& weight, const int64_t ignore_index,
+                           const std::string& reduction) const {
+    // auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
+    // attrs.SetAllAttrs(ignore_index);
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
+    attrs.SetAllAttrs(ignore_index, reduction);
     if (weight) {
       return OpInterpUtil::Dispatch<Tensor>(
-          *op_weight_, {out_grad, input, target, JUST(JUST(weight)->detach())}, attrs);
+          *op_weight_,
+          {out_grad, reduced_out_grad, input, target, total_weight, JUST(JUST(weight)->detach())},
+          attrs);
     } else {
-      return OpInterpUtil::Dispatch<Tensor>(*op_, {out_grad, input, target}, attrs);
+      return OpInterpUtil::Dispatch<Tensor>(
+          *op_, {out_grad, reduced_out_grad, input, target, total_weight}, attrs);
     }
   }
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index 7532cfe441e..f364235dc9b 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -6106,10 +6106,13 @@ def OneFlow_NLLOp : OneFlow_BaseOp<"nll", [NoMemoryEffect, DeclareOpInterfaceMet
   );
   let output = (outs
     OneFlow_Tensor:$output,
-    OneFlow_Tensor:$out_weight
+    OneFlow_Tensor:$out_weight,
+    OneFlow_Tensor:$reduced_out,
+    OneFlow_Tensor:$total_weight
   );
   let attrs = (ins
-    DefaultValuedAttr<SI64Attr, "-100">:$ignore_index
+    DefaultValuedAttr<SI64Attr, "-100">:$ignore_index,
+    DefaultValuedAttr<StrAttr, "\"none\"">:$reduction
   );
   let has_data_type_infer_fn = 1;
   let has_logical_tensor_desc_infer_fn = 1;
@@ -6120,15 +6123,18 @@
 def OneFlow_NLLGradOp : OneFlow_BaseOp<"nll_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> {
   let input = (ins
     OneFlow_Tensor:$out_grad,
+    OneFlow_Tensor:$reduced_out_grad,
     OneFlow_Tensor:$input,
     OneFlow_Tensor:$target,
+    OneFlow_Tensor:$total_weight,
     Optional<OneFlow_Tensor>:$weight
   );
   let output = (outs
     OneFlow_Tensor:$in_grad
   );
   let attrs = (ins
-    DefaultValuedAttr<SI64Attr, "-100">:$ignore_index
+    DefaultValuedAttr<SI64Attr, "-100">:$ignore_index,
+    DefaultValuedAttr<StrAttr, "\"none\"">:$reduction
   );
   let has_data_type_infer_fn = 1;
   let has_logical_tensor_desc_infer_fn = 1;
diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp
index 7f9035539b7..63c3d8d55ff 100644
--- a/oneflow/user/ops/nll_op.cpp
+++ b/oneflow/user/ops/nll_op.cpp
@@ -31,6 +31,8 @@ namespace oneflow {
   ctx->SetOutputDType("output", 0, input_dtype);
   ctx->SetOutputDType("out_weight", 0, input_dtype);
+  ctx->SetOutputDType("reduced_out", 0, input_dtype);
+  ctx->SetOutputDType("total_weight", 0, input_dtype);
 
   return Maybe<void>::Ok();
 }
@@ -69,6 +71,14 @@ namespace oneflow {
   out_weight_desc->set_is_dynamic(is_dynamic);
   out_weight_desc->set_shape(Shape({N}));
 
+  user_op::TensorDesc* reduced_out_desc = ctx->MutOutputTensorDesc("reduced_out", 0);
+  reduced_out_desc->set_is_dynamic(false);
+  reduced_out_desc->set_shape(Shape({}));
+
+  user_op::TensorDesc* total_weight_desc = ctx->MutOutputTensorDesc("total_weight", 0);
+  total_weight_desc->set_is_dynamic(false);
+  total_weight_desc->set_shape(Shape({}));
+
   return Maybe<void>::Ok();
 }
@@ -78,7 +88,9 @@ namespace oneflow {
                       .Split(user_op::OpArg("input", 0), 0)
                       .Split(user_op::OpArg("target", 0), 0)
                       .Split(user_op::OpArg("output", 0), 0)
-                      .Split(user_op::OpArg("out_weight", 0), 0);
+                      .Split(user_op::OpArg("out_weight", 0), 0)
+                      .Broadcast(user_op::OpArg("reduced_out", 0))
+                      .Broadcast(user_op::OpArg("total_weight", 0));
   if (ctx->user_op_conf().has_input("weight", 0)) {
     builder1.Broadcast(user_op::OpArg("weight", 0));
   }
@@ -89,8 +101,10 @@ namespace oneflow {
   auto builder2 = ctx->NewBuilder()
                       .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1)
                       .Broadcast(user_op::OpArg("target", 0))
-                      .PartialSum(user_op::OpArg("output", 0))
-                      .PartialSum(user_op::OpArg("out_weight", 0));
+                      .Broadcast(user_op::OpArg("output", 0))
+                      .Broadcast(user_op::OpArg("out_weight", 0))
+                      .Broadcast(user_op::OpArg("reduced_out", 0))
+                      .Broadcast(user_op::OpArg("total_weight", 0));
   if (ctx->user_op_conf().has_input("weight", 0)) {
     builder2.Split(user_op::OpArg("weight", 0), 0);
   }
@@ -172,6 +186,8 @@ namespace oneflow {
                       .Split(user_op::OpArg("input", 0), 0)
                       .Split(user_op::OpArg("target", 0), 0)
                       .Split(user_op::OpArg("out_grad", 0), 0)
+                      .Broadcast(user_op::OpArg("reduced_out_grad", 0))
+                      .Broadcast(user_op::OpArg("total_weight", 0))
                       .Split(user_op::OpArg("in_grad", 0), 0);
   if (ctx->user_op_conf().has_input("weight", 0)) {
     builder1.Broadcast(user_op::OpArg("weight", 0));
@@ -184,6 +200,8 @@ namespace oneflow {
                       .Split(user_op::OpArg("input", 0), shape.NumAxes() - 1)
                       .Broadcast(user_op::OpArg("target", 0))
                       .Broadcast(user_op::OpArg("out_grad", 0))
+                      .Broadcast(user_op::OpArg("reduced_out_grad", 0))
+                      .Broadcast(user_op::OpArg("total_weight", 0))
                       .Split(user_op::OpArg("in_grad", 0), shape.NumAxes() - 1);
   if (ctx->user_op_conf().has_input("weight", 0)) {
     builder2.Split(user_op::OpArg("weight", 0), 0);
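
For context, a hedged smoke-test sketch (not part of the patch) of how the reduction-aware `nll`/`nll_grad` path can be exercised end to end through OneFlow's public Python API. The use of `oneflow.nn.CrossEntropyLoss` and the tensor shapes are illustrative assumptions, not code introduced by this diff; on NPU builds the new `reduced_out` output is returned directly for `sum`/`mean`, which is transparent at this level.

```python
import oneflow as flow

# Illustrative assumption: drive the nll/nll_grad ops through the public API
# for each reduction mode and check that gradients still flow.
target = flow.randint(0, 5, (8,))  # class indices in [0, 5)

for reduction in ("none", "sum", "mean"):
    logits = flow.randn(8, 5, requires_grad=True)  # (N, C)
    loss = flow.nn.CrossEntropyLoss(reduction=reduction)(logits, target)
    loss.sum().backward()  # sum() makes the "none" case a scalar too
    print(reduction, loss.shape, logits.grad.shape)  # grad shape stays (8, 5)
```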