25 changes: 17 additions & 8 deletions oneflow/core/autograd/gradient_funcs/nll.cpp
@@ -24,6 +24,7 @@ namespace one {
struct NLLCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
int64_t ignore_index = -100;
std::string reduction = "none";
};

class NLLGradFunction : public OpExprGradFunction<NLLCaptureState> {
@@ -53,37 +54,45 @@ Maybe<void> NLLGradFunction::Capture(NLLCaptureState* ctx, const TensorTuple& in

ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->ignore_index = JUST(composed_attrs.GetAttr<int64_t>("ignore_index"));
ctx->reduction = JUST(composed_attrs.GetAttr<std::string>("reduction"));
ctx->SaveTensorForBackward(input); // input
ctx->SaveTensorForBackward(JUST(VectorAt(inputs, 1))); // target
if (inputs.size() == 3) {
ctx->SaveTensorForBackward(inputs[2]); // weight
}
ctx->SaveTensorForBackward(outputs[3]); // total_weight
return Maybe<void>::Ok();
}

Maybe<void> NLLGradFunction::Apply(const NLLCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

CHECK_EQ_OR_RETURN(out_grads.size(), 2); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(out_grads.size(), 4); // NOLINT(maybe-need-error-msg)
CHECK_GE_OR_RETURN(ctx->SavedTensors().size(), 2)
<< Error::RuntimeError()
<< "The number of saved tensors is expected to be greater than or equal to 2, but got "
<< ctx->SavedTensors().size();
const auto& out_grad = out_grads[0];

// outputs: output, out_weight, reduced_out, total_weight
const auto& out_grad = out_grads[0]; // for reduction="none"
const auto& reduced_out_grad = out_grads[2]; // for reduction="mean"/"sum"
const auto& input = ctx->SavedTensors()[0];
const auto& target = ctx->SavedTensors()[1];

in_grads->resize(ctx->SavedTensors().size());

if (ctx->SavedTensors().size() == 2) {
if (ctx->SavedTensors().size() == 3) { // no weight
const auto& total_weight = ctx->SavedTensors()[2];
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::NLLGrad(out_grad, input, target, NullOpt, ctx->ignore_index));
} else {
// has weight
auto weight = JUST(VectorAt(ctx->SavedTensors(), 2));
JUST(functional::NLLGrad(out_grad, reduced_out_grad, input, target, total_weight, NullOpt,
ctx->ignore_index, ctx->reduction));
} else if (ctx->SavedTensors().size() == 4) { // has weight
const auto& weight = ctx->SavedTensors()[2];
const auto& total_weight = ctx->SavedTensors()[3];
JUST(VectorAt(*in_grads, 0)) =
JUST(functional::NLLGrad(out_grad, input, target, weight, ctx->ignore_index));
JUST(functional::NLLGrad(out_grad, reduced_out_grad, input, target, total_weight, weight,
ctx->ignore_index, ctx->reduction));
}

return Maybe<void>::Ok();
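For review context: with this change Capture stores the reduction string and appends total_weight, so the saved tensors are assumed to be {input, target, total_weight} without a weight and {input, target, weight, total_weight} with one. A minimal sketch of a defensive check documenting that assumption (not part of the diff; names follow the surrounding code):

// Sketch only: assumed SavedTensors layout after this change.
//   without weight: {input, target, total_weight}          -> size() == 3
//   with weight   : {input, target, weight, total_weight}  -> size() == 4
const size_t num_saved = ctx->SavedTensors().size();
CHECK_OR_RETURN(num_saved == 3 || num_saved == 4)
    << Error::RuntimeError() << "expected 3 or 4 saved tensors, but got " << num_saved;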
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -1355,7 +1355,7 @@
bind_python: True

- name: "nll_grad"
signature: "Tensor(Tensor out_grad, Tensor input, Tensor target, Tensor weight=None, Int64 ignore_index) => NLLGrad"
signature: 'Tensor(Tensor out_grad, Tensor reduced_out_grad, Tensor input, Tensor target, Tensor total_weight, Tensor weight=None, Int64 ignore_index, String reduction="none") => NLLGrad'
bind_python: False

- name: "binary_cross_entropy_loss"
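The updated signature is C++-only (bind_python: False). A hedged sketch of a matching call site, with placeholder argument values:

// Sketch only: arguments mirror the new yaml signature above.
auto in_grad = JUST(functional::NLLGrad(out_grad, reduced_out_grad, input, target,
                                        total_weight, /*weight=*/NullOpt,
                                        /*ignore_index=*/-100, /*reduction=*/"mean"));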
20 changes: 20 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -588,7 +588,27 @@ class ArgWhereFunctor {
const Symbol<DType>& dtype) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("dtype");
attrs.SetAllAttrs(dtype->data_type());
#ifdef WITH_NPU
auto device_type = DeviceType::kCPU;
if (x->is_global()) {
device_type = JUST(x->parallel_desc())->device_type();
} else {
device_type = JUST(x->device())->enum_type();
}
if (device_type == DeviceType::kNPU) {
// NOTE: use cpu argwhere when device="npu"
auto cpu_tensor = JUST(one::functional::To(x, "cpu"));
auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {cpu_tensor}, attrs));
for (int i = 0; i < result->size(); ++i) {
(*result)[i] = JUST(one::functional::To((*result)[i], "npu"));
}
return result;
} else {
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
}
#else
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x}, attrs);
#endif // WITH_NPU
}

private:
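The ArgWhere change follows a host-fallback pattern: when the tensor lives on an NPU, copy it to CPU, dispatch the op there, then copy every result back. A generalized sketch of that pattern, with a hypothetical helper name (the diff inlines this logic):

// Hypothetical helper illustrating the host-fallback pattern used above.
Maybe<TensorTuple> DispatchOnCpu(const OpExpr& op, const std::shared_ptr<Tensor>& x,
                                 const AttrMap& attrs, const std::string& original_device) {
  auto cpu_input = JUST(one::functional::To(x, "cpu"));  // copy input to host
  auto result = JUST(OpInterpUtil::Dispatch<TensorTuple>(op, {cpu_input}, attrs));
  for (int i = 0; i < result->size(); ++i) {  // copy every output back to the original device
    (*result)[i] = JUST(one::functional::To((*result)[i], original_device));
  }
  return result;
}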
46 changes: 34 additions & 12 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -930,9 +930,9 @@ class SkipLayerNormFunctor {
std::tuple<bool, bool, bool, bool>(has_skip, has_gamma, has_beta, has_bias),
op_expr));
} // has_bias
} // has_beta
} // has_gamma
} // has_skip
} // has_beta
} // has_gamma
} // has_skip
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
@@ -1170,8 +1170,8 @@ class SkipRMSNormFunctor {
ops_.insert(std::pair<std::tuple<bool, bool, bool>, std::shared_ptr<OpExpr>>(
std::tuple<bool, bool, bool>(has_weight, has_skip, has_bias), op_expr));
} // has_bias
} // has_skip
} // has_weight
} // has_skip
} // has_weight
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
@@ -1477,7 +1477,7 @@ class MaxUnpoolNDFunctor {
.Input("x")
.Input("indices")
.Output("y")
.Build())){};
.Build())) {};
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& indices,
const std::vector<int32_t>& kernel_size,
@@ -1819,6 +1819,8 @@ class NLLLossFunctor {
.Input("target")
.Output("output")
.Output("out_weight")
.Output("reduced_out")
.Output("total_weight")
.Build());

op_weight_ = CHECK_JUST(one::OpBuilder("nll")
@@ -1827,6 +1829,8 @@
.Input("weight")
.Output("output")
.Output("out_weight")
.Output("reduced_out")
.Output("total_weight")
.Build());
}

@@ -1875,8 +1879,10 @@ class NLLLossFunctor {
target_ = target;
}

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
attrs.SetAllAttrs(ignore_index);
// auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
// attrs.SetAllAttrs(ignore_index);
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
attrs.SetAllAttrs(ignore_index, reduction);

std::shared_ptr<TensorTuple> nll_result;
if (weight) {
@@ -1890,9 +1896,14 @@
if (K > 2) { output = JUST(functional::Reshape(output, *target_shape)); }

if (reduction == "none") { return output; }

#ifdef WITH_NPU
// npu device
if (!input->is_cpu()) {
auto reduced_out = JUST(VectorAt(*nll_result, 2));
if (reduction == "sum" || reduction == "mean") { return reduced_out; }
}
#endif // WITH_NPU
auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));

if (reduction == "sum") { return sum; }

auto total_weight =
@@ -1915,6 +1926,8 @@ class CrossEntropyFunctor {
.Input("target")
.Output("output")
.Output("out_weight")
.Output("reduced_out")
.Output("total_weight")
.Build());

op_nll_weight_ = CHECK_JUST(one::OpBuilder("nll")
@@ -1923,6 +1936,8 @@
.Input("weight")
.Output("output")
.Output("out_weight")
.Output("reduced_out")
.Output("total_weight")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
@@ -1959,8 +1974,8 @@ class CrossEntropyFunctor {

const auto target_ = JUST(functional::Flatten(target, 0, target->shape()->NumAxes() - 1));

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
attrs.SetAllAttrs(ignore_index);
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
attrs.SetAllAttrs(ignore_index, reduction);

std::shared_ptr<TensorTuple> nll_result;
if (weight) {
@@ -1974,6 +1989,13 @@
output = JUST(functional::Reshape(output, *target_shape));
if (reduction == "none") { return output; }

#ifdef WITH_NPU
if (!input->is_cpu()) {
auto reduced_out = JUST(VectorAt(*nll_result, 2));
// auto total_weight = JUST(VectorAt(*nll_result, 3));
if (reduction == "sum" || reduction == "mean") { return reduced_out; }
}
#endif // WITH_NPU
auto sum = JUST(functional::ReduceSum(output, {}, false, NullOpt));
if (reduction == "sum") { return sum; }

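Taken together, the forward nll op now returns {output, out_weight, reduced_out, total_weight}, and the loss functors choose between a device-side and a host-side reduction. A hedged sketch of that selection (hypothetical helper; the real code inlines it in NLLLossFunctor/CrossEntropyFunctor and only takes the device-reduced branch under WITH_NPU):

// Sketch only: nll_result = {output, out_weight, reduced_out, total_weight}.
Maybe<Tensor> SelectNllResult(const std::shared_ptr<TensorTuple>& nll_result,
                              const std::string& reduction, bool reduced_on_device) {
  if (reduction == "none") { return JUST(VectorAt(*nll_result, 0)); }  // per-element loss
  if (reduced_on_device) { return JUST(VectorAt(*nll_result, 2)); }    // device-side reduced_out
  auto sum = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 0)), {}, false, NullOpt));
  if (reduction == "sum") { return sum; }
  auto total = JUST(functional::ReduceSum(JUST(VectorAt(*nll_result, 1)), {}, false, NullOpt));
  return functional::Div(sum, total);  // "mean": normalize by the summed sample weights
}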
23 changes: 17 additions & 6 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -451,32 +451,43 @@ class NLLGradFunctor {
NLLGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("nll_grad")
.Input("out_grad")
.Input("reduced_out_grad")
.Input("input")
.Input("target")
.Input("total_weight")
.Output("in_grad")
.Build());

op_weight_ = CHECK_JUST(one::OpBuilder("nll_grad")
.Input("out_grad")
.Input("reduced_out_grad")
.Input("input")
.Input("target")
.Input("total_weight")
.Input("weight")
.Output("in_grad")
.Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& out_grad,
const std::shared_ptr<one::Tensor>& reduced_out_grad,
const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& target,
const Optional<one::Tensor>& weight, const int64_t ignore_index) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
attrs.SetAllAttrs(ignore_index);

const std::shared_ptr<one::Tensor>& total_weight,
const Optional<one::Tensor>& weight, const int64_t ignore_index,
const std::string& reduction) const {
// auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index");
// attrs.SetAllAttrs(ignore_index);
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("ignore_index", "reduction");
attrs.SetAllAttrs(ignore_index, reduction);
if (weight) {
return OpInterpUtil::Dispatch<one::Tensor>(
*op_weight_, {out_grad, input, target, JUST(JUST(weight)->detach())}, attrs);
*op_weight_,
{out_grad, reduced_out_grad, input, target, total_weight, JUST(JUST(weight)->detach())},
attrs);
} else {
return OpInterpUtil::Dispatch<one::Tensor>(*op_, {out_grad, input, target}, attrs);
return OpInterpUtil::Dispatch<one::Tensor>(
*op_, {out_grad, reduced_out_grad, input, target, total_weight}, attrs);
}
}

12 changes: 9 additions & 3 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -6106,10 +6106,13 @@ def OneFlow_NLLOp : OneFlow_BaseOp<"nll", [NoMemoryEffect, DeclareOpInterfaceMet
);
let output = (outs
OneFlow_Tensor:$output,
OneFlow_Tensor:$out_weight
OneFlow_Tensor:$out_weight,
OneFlow_Tensor:$reduced_out,
OneFlow_Tensor:$total_weight
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$ignore_index
DefaultValuedAttr<SI64Attr, "0">:$ignore_index,
DefaultValuedAttr<StrAttr, "\"none\"">:$reduction
);
let has_data_type_infer_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
@@ -6120,15 +6123,18 @@
def OneFlow_NLLGradOp : OneFlow_BaseOp<"nll_grad", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$out_grad,
OneFlow_Tensor:$reduced_out_grad,
OneFlow_Tensor:$input,
OneFlow_Tensor:$target,
OneFlow_Tensor:$total_weight,
Optional<OneFlow_Tensor>:$weight
);
let output = (outs
OneFlow_Tensor:$in_grad
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$ignore_index
DefaultValuedAttr<SI64Attr, "0">:$ignore_index,
DefaultValuedAttr<StrAttr, "\"none\"">:$reduction
);
let has_data_type_infer_fn = 1;
let has_logical_tensor_desc_infer_fn = 1;
24 changes: 21 additions & 3 deletions oneflow/user/ops/nll_op.cpp
@@ -31,6 +31,8 @@ namespace oneflow {

ctx->SetOutputDType("output", 0, input_dtype);
ctx->SetOutputDType("out_weight", 0, input_dtype);
ctx->SetOutputDType("reduced_out", 0, input_dtype);
ctx->SetOutputDType("total_weight", 0, input_dtype);

return Maybe<void>::Ok();
}
@@ -69,6 +71,14 @@ namespace oneflow {
out_weight_desc->set_is_dynamic(is_dynamic);
out_weight_desc->set_shape(Shape({N}));

user_op::TensorDesc* reduced_out_desc = ctx->MutOutputTensorDesc("reduced_out", 0);
reduced_out_desc->set_is_dynamic(false);
reduced_out_desc->set_shape(Shape({}));

user_op::TensorDesc* total_weight_desc = ctx->MutOutputTensorDesc("total_weight", 0);
total_weight_desc->set_is_dynamic(false);
total_weight_desc->set_shape(Shape({}));

return Maybe<void>::Ok();
}

@@ -78,7 +88,9 @@ namespace oneflow {
.Split(user_op::OpArg("input", 0), 0)
.Split(user_op::OpArg("target", 0), 0)
.Split(user_op::OpArg("output", 0), 0)
.Split(user_op::OpArg("out_weight", 0), 0);
.Split(user_op::OpArg("out_weight", 0), 0)
.Broadcast(user_op::OpArg("reduced_out", 0))
.Broadcast(user_op::OpArg("total_weight", 0));
if (ctx->user_op_conf().has_input("weight", 0)) {
builder1.Broadcast(user_op::OpArg("weight", 0));
}
@@ -89,8 +101,10 @@ namespace oneflow {
auto builder2 = ctx->NewBuilder()
.Split(user_op::OpArg("input", 0), shape.NumAxes() - 1)
.Broadcast(user_op::OpArg("target", 0))
.PartialSum(user_op::OpArg("output", 0))
.PartialSum(user_op::OpArg("out_weight", 0));
.Broadcast(user_op::OpArg("output", 0))
.Broadcast(user_op::OpArg("out_weight", 0))
.Broadcast(user_op::OpArg("reduced_out", 0))
.Broadcast(user_op::OpArg("total_weight", 0));
if (ctx->user_op_conf().has_input("weight", 0)) {
builder2.Split(user_op::OpArg("weight", 0), 0);
}
Expand Down Expand Up @@ -172,6 +186,8 @@ namespace oneflow {
.Split(user_op::OpArg("input", 0), 0)
.Split(user_op::OpArg("target", 0), 0)
.Split(user_op::OpArg("out_grad", 0), 0)
.Broadcast(user_op::OpArg("reduced_out_grad", 0))
.Broadcast(user_op::OpArg("total_weight", 0))
.Split(user_op::OpArg("in_grad", 0), 0);
if (ctx->user_op_conf().has_input("weight", 0)) {
builder1.Broadcast(user_op::OpArg("weight", 0));
@@ -184,6 +200,8 @@ namespace oneflow {
.Split(user_op::OpArg("input", 0), shape.NumAxes() - 1)
.Broadcast(user_op::OpArg("target", 0))
.Broadcast(user_op::OpArg("out_grad", 0))
.Broadcast(user_op::OpArg("reduced_out_grad", 0))
.Broadcast(user_op::OpArg("total_weight", 0))
.Split(user_op::OpArg("in_grad", 0), shape.NumAxes() - 1);
if (ctx->user_op_conf().has_input("weight", 0)) {
builder2.Split(user_op::OpArg("weight", 0), 0);