Skip to content

Commit

Permalink
[hannk] augment L2NormOp to allow specifying axis (#6335)
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson authored Oct 20, 2021
1 parent d80bb23 commit c3641b6
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 6 deletions.
3 changes: 2 additions & 1 deletion apps/hannk/delegate/hannk_delegate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,8 @@ class HannkDelegateKernel final {
// Builds a hannk L2NormalizationOp from a TFLite L2_NORMALIZATION node.
// The node has exactly one input and one output tensor.
// NOTE: the diff rendering left both the pre- and post-commit return
// statements in the text; only the post-commit form (with the axis
// argument) is kept here, since the old two-arg call no longer matches
// the L2NormalizationOp constructor.
OpPtr BuildL2Normalization(TfLiteContext *context, TfLiteNode *node) {
    auto input = GetTensorById(context, node->inputs->data[0]);
    auto output = GetTensorById(context, node->outputs->data[0]);
    const int axis = 0;  // In TFLite, normalization is always against the first axis.
    return make_op<L2NormalizationOp>(input, output, axis);
}

OpPtr BuildUnary(TfLiteContext *context, TfLiteNode *node, UnaryOp::Operator type) {
Expand Down
8 changes: 8 additions & 0 deletions apps/hannk/halide/normalizations_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ class L2Normalization : public Generator<L2Normalization> {
.update()
.atomic()
.vectorize(rx, vector_size);

// Normally we'd expect both buffers to be planar, but in unusual
// cases, Hannk can transpose the buffers (to normalize along another
// dimension), so for those cases, we'll just fall back to less-efficient
// code.
input_.dim(0).set_stride(Expr());
output_.dim(0).set_stride(Expr());
output_.specialize(input_.dim(0).stride() == 1 && output_.dim(0).stride() == 1);
}
};

Expand Down
17 changes: 15 additions & 2 deletions apps/hannk/interpreter/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,10 +1137,23 @@ void L2NormalizationOp::execute() {
const TensorPtr &in = input();
const TensorPtr &out = output();

// Negative values for axis_ must be normalized by the parser
assert(axis_ >= 0 && axis_ < in->rank());

if (in->type() == halide_type_of<uint8_t>() &&
out->type() == halide_type_of<uint8_t>()) {
const auto &in_buf = in->buffer();
const auto &out_buf = out->buffer();
// Make local copies in case we need to transpose them
HalideBuffer<void> in_buf = in->buffer();
HalideBuffer<void> out_buf = out->buffer();

// TODO: we currently assume that the axis-is-0 case is the most common
// and most important, and optimize for it; the other cases, we just transpose,
// which currently requires less-efficient specializations in the Halide code.
// Revisit if this proves too slow in practice.
if (axis_ != 0) {
in_buf.transpose(0, axis_);
out_buf.transpose(0, axis_);
}

const int input_zero = in->quantization().uniform_zero();
assert(input_zero >= 0 && input_zero <= 255);
Expand Down
6 changes: 4 additions & 2 deletions apps/hannk/interpreter/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,11 @@ class GatherOp : public Op {
};

class L2NormalizationOp : public Op {
const int axis_;

public:
L2NormalizationOp(const TensorPtr &input, const TensorPtr &output)
: Op({input}, {output}) {
L2NormalizationOp(const TensorPtr &input, const TensorPtr &output, int axis)
: Op({input}, {output}), axis_(axis) {
}

void accept(OpVisitor *v) override;
Expand Down
3 changes: 2 additions & 1 deletion apps/hannk/tflite/tflite_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ class Parser {
// Parses a TFLite L2_NORMALIZATION operator from the flatbuffer model
// into a hannk L2NormalizationOp (one input tensor, one output tensor).
// NOTE: the diff rendering left both the pre- and post-commit return
// statements in the text; only the post-commit form (with the axis
// argument) is kept here, matching the updated L2NormalizationOp
// constructor.
OpPtr parse_l2_normalization(const tflite::Operator *op) {
    TensorPtr input = tensors_[op->inputs()->Get(0)];
    TensorPtr output = tensors_[op->outputs()->Get(0)];
    const int axis = 0;  // In TFLite, normalization is always against the first axis.
    return make_op<L2NormalizationOp>(input, output, axis);
}

OpPtr parse_reduction(const tflite::Operator *op, ReductionOp::Operator reduction_op) {
Expand Down

0 comments on commit c3641b6

Please sign in to comment.