Skip to content

Commit 3d81290

Browse files
Validation for the buffer indices in input, output, intermedite tensors and subgraphs.
Support for DECODE operator (#3162) * Support for DECODE operator @tensorflow/micro Add initial support for DECODE operator. Add reference implementation. Add LUT decompression support. Update op resolvers. Update Makefiles and Bazel BUILD files. Add kernel unit test. bug=fixes #3131 * update copyright * Don't use constructors with global objects (bluepill will not call them). Cleanup unit test. * Support for DECODE operator @tensorflow/micro Additional support for DECODE operator. Add Xtensa optimizations for LUT decompression. Move all Xtensa kernel source references to the Xtensa target makefile. bug=fixes #3150 * Updates to Xtensa makefiles @tensorflow/micro Reorganize Xtensa makefiles such that all references to optimized kernel sources are moved to the Xtensa target makefile. Move hifimini kernel sources to the parent directory, and rename them so they do not interfere with the target overlay mechanism of the root makefile. bug=fixes #3153 * Fix incorrect include path. Fix code style errors. * fix copyright * update generic benchmark op resolver size * Support for DECODE operator @tensorflow/micro Add reference implementation of pruning to DECODE operator. Makefile and Bazel BUILD file changes. Additional unit tests. bug=fixes #3161 * Split decode tests into seperate files. Update pruning code with zero-point checks. Add const-tensor checks. * Add decode_test_helpers.h file. Cleanup tests. Added end-of-line symbols
1 parent cee9550 commit 3d81290

16 files changed

+1265
-150
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ tflm_cc_library(
7979
],
8080
)
8181

82+
tflm_cc_library(
83+
name = "decode_test_helpers",
84+
hdrs = [
85+
"decode_test_helpers.h",
86+
],
87+
deps = [
88+
":kernel_runner",
89+
":micro_ops",
90+
"//tensorflow/lite/c:common",
91+
"//tensorflow/lite/micro:test_helpers",
92+
"//tensorflow/lite/micro/testing:micro_test",
93+
],
94+
)
95+
8296
tflm_cc_library(
8397
name = "decompress",
8498
srcs = [
@@ -239,6 +253,7 @@ tflm_kernel_cc_library(
239253
"decode.cc",
240254
"decode_state.cc",
241255
"decode_state_lut.cc",
256+
"decode_state_prune.cc",
242257
"depth_to_space.cc",
243258
"depthwise_conv.cc",
244259
"depthwise_conv_common.cc",
@@ -332,6 +347,7 @@ tflm_kernel_cc_library(
332347
"conv.h",
333348
"decode_state.h",
334349
"decode_state_lut.h",
350+
"decode_state_prune.h",
335351
"depthwise_conv.h",
336352
"dequantize.h",
337353
"ethosu.h",
@@ -648,12 +664,29 @@ tflm_cc_test(
648664
],
649665
)
650666

667+
tflm_cc_test(
668+
name = "decode_state_prune_test",
669+
srcs = [
670+
"decode_state_prune_test.cc",
671+
],
672+
deps = [
673+
":decode_test_helpers",
674+
":kernel_runner",
675+
"//tensorflow/lite/c:common",
676+
"//tensorflow/lite/micro:debug_log",
677+
"//tensorflow/lite/micro:op_resolvers",
678+
"//tensorflow/lite/micro:test_helpers",
679+
"//tensorflow/lite/micro/testing:micro_test",
680+
],
681+
)
682+
651683
tflm_cc_test(
652684
name = "decode_test",
653685
srcs = [
654686
"decode_test.cc",
655687
],
656688
deps = [
689+
":decode_test_helpers",
657690
":kernel_runner",
658691
"//tensorflow/lite/c:common",
659692
"//tensorflow/lite/micro:debug_log",

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
123123
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
124124
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
125125
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
126+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune_test.cc \
126127
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
127128
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
128129
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \

tensorflow/lite/micro/kernels/arc_mli/mli_interface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class MliTensorInterface {
3333
public:
3434
// Make sure that lifetime of MliTensorInterface instance isn't bigger than
3535
// related mli_tensor.
36-
MliTensorInterface(mli_tensor* tensor) : tensor_(tensor) {};
36+
MliTensorInterface(mli_tensor* tensor) : tensor_(tensor){};
3737
MliTensorInterface() = default;
3838
~MliTensorInterface() = default;
3939

tensorflow/lite/micro/kernels/decode.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
6363
break;
6464
}
6565

66+
TF_LITE_ENSURE(context, IsConstantTensor(input));
67+
TF_LITE_ENSURE(context, IsConstantTensor(ancillary));
68+
6669
if (DecodeState::Version(*ancillary) != 1) {
6770
MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
6871
status = kTfLiteError;
@@ -75,6 +78,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7578
dsp = DecodeState::CreateDecodeStateLUT(
7679
context, micro_context->GetAlternateProfiler());
7780
break;
81+
case DecodeState::kDcmTypePrune:
82+
dsp = DecodeState::CreateDecodeStatePrune(
83+
context, micro_context->GetAlternateProfiler());
84+
break;
7885
case DecodeState::kDcmTypeCustom:
7986
MicroPrintf("Custom decode type not yet supported");
8087
break;

tensorflow/lite/micro/kernels/decode_state.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/lite/micro/kernels/decode_state.h"
1717

1818
#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
19+
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
1920
#include "tensorflow/lite/micro/micro_context.h"
2021

2122
namespace tflite {
@@ -33,4 +34,17 @@ DecodeState* DecodeState::CreateDecodeStateLUT(
3334
return dsp;
3435
}
3536

37+
DecodeState* DecodeState::CreateDecodeStatePrune(
38+
const TfLiteContext* context, MicroProfilerInterface* profiler) {
39+
MicroContext* const micro_context = GetMicroContext(context);
40+
void* buffer =
41+
micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune));
42+
if (buffer == nullptr) {
43+
return nullptr;
44+
}
45+
DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler);
46+
47+
return dsp;
48+
}
49+
3650
} // namespace tflite

tensorflow/lite/micro/kernels/decode_state.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class DecodeState {
4343

4444
static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
4545
MicroProfilerInterface* profiler);
46+
static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
47+
MicroProfilerInterface* profiler);
4648

4749
static uint8_t Type(const TfLiteTensor& ancillary) {
4850
return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
@@ -66,6 +68,7 @@ class DecodeState {
6668
// Decode Common Metadata constants
6769
public:
6870
static constexpr uint8_t kDcmTypeLUT = 0;
71+
static constexpr uint8_t kDcmTypePrune = 2;
6972
static constexpr uint8_t kDcmTypeCustom = 127;
7073

7174
static constexpr size_t kDcmSizeInBytes = 16;
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
17+
18+
#include <algorithm>
19+
#include <cstddef>
20+
21+
#include "tensorflow/lite/kernels/internal/compatibility.h"
22+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23+
#include "tensorflow/lite/kernels/kernel_util.h"
24+
#include "tensorflow/lite/micro/micro_context.h"
25+
#include "tensorflow/lite/micro/micro_log.h"
26+
#include "tensorflow/lite/micro/micro_profiler.h"
27+
28+
namespace tflite {
29+
30+
TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input,
31+
const TfLiteTensor& ancillary,
32+
const TfLiteTensor& output) {
33+
const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
34+
if (ancillary_data[kDcmVersionOffset] != 1) {
35+
MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
36+
return kTfLiteError;
37+
}
38+
39+
// resolve num_channels_, use_alternate_axis_, and zero points
40+
if (output.quantization.type == kTfLiteAffineQuantization &&
41+
output.quantization.params != nullptr) {
42+
const TfLiteAffineQuantization* quantization =
43+
reinterpret_cast<TfLiteAffineQuantization*>(output.quantization.params);
44+
num_channels_ = quantization->scale->size;
45+
if ((quantization->quantized_dimension == output.dims->size - 1) &&
46+
num_channels_ > 1) {
47+
use_alternate_axis_ = true;
48+
} else if (quantization->quantized_dimension != 0) {
49+
MicroPrintf("unsupported quantization axis %u",
50+
quantization->quantized_dimension);
51+
return kTfLiteError;
52+
}
53+
54+
TFLITE_DCHECK(num_channels_ ==
55+
static_cast<size_t>(quantization->zero_point->size));
56+
bool has_non_zero_zp =
57+
std::any_of(quantization->zero_point->data,
58+
quantization->zero_point->data + num_channels_,
59+
[](int zp) { return zp != 0; });
60+
61+
if (output.type != kTfLiteInt8) {
62+
// make sure all zero points are 0 (zero)
63+
TF_LITE_ENSURE_MSG(const_cast<TfLiteContext*>(context_),
64+
has_non_zero_zp == false,
65+
"All zero-points must be zero");
66+
}
67+
68+
if (num_channels_ > 1 && has_non_zero_zp) {
69+
// copy zero points
70+
MicroContext* micro_context = GetMicroContext(context_);
71+
const size_t bufsize = num_channels_ * sizeof(*zero_points_);
72+
zero_points_ = static_cast<decltype(zero_points_)>(
73+
micro_context->AllocatePersistentBuffer(bufsize));
74+
if (zero_points_ == nullptr) {
75+
MicroPrintf("unable to allocate zero_points_");
76+
return kTfLiteError;
77+
}
78+
std::copy_n(quantization->zero_point->data, num_channels_, zero_points_);
79+
} else {
80+
single_zero_point_ = quantization->zero_point->data[0];
81+
}
82+
}
83+
84+
compressed_indices_ = GetTensorData<uint8_t>(&input);
85+
count_indices_ = NumElements(&output);
86+
elements_per_channel_ =
87+
use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
88+
value_table_ = &ancillary_data[kDcmSizeInBytes];
89+
90+
return kTfLiteOk;
91+
}
92+
93+
TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input,
94+
const TfLiteEvalTensor& ancillary,
95+
const TfLiteEvalTensor& output) {
96+
void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
97+
TFLITE_DCHECK(buffer != nullptr);
98+
99+
switch (output.type) {
100+
case kTfLiteBool:
101+
DecompressToBuffer<int8_t>(buffer);
102+
break;
103+
case kTfLiteFloat32:
104+
DecompressToBuffer<int32_t>(buffer);
105+
break;
106+
case kTfLiteInt8:
107+
if (num_channels_ > 1 && zero_points_ != nullptr) {
108+
DecompressToBufferPerChannelInt8(buffer);
109+
} else {
110+
DecompressToBuffer<int8_t>(buffer);
111+
}
112+
break;
113+
case kTfLiteInt16:
114+
DecompressToBuffer<int16_t>(buffer);
115+
break;
116+
case kTfLiteInt32:
117+
DecompressToBuffer<int32_t>(buffer);
118+
break;
119+
case kTfLiteInt64:
120+
DecompressToBuffer<int64_t>(buffer);
121+
break;
122+
default:
123+
MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type));
124+
return kTfLiteError;
125+
}
126+
127+
return kTfLiteOk;
128+
}
129+
130+
template <typename T>
131+
void DecodeStatePrune::DecompressToBuffer(void* vp) {
132+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
133+
134+
T* buffer = static_cast<T*>(vp);
135+
const T* value_table = static_cast<const T*>(value_table_);
136+
const size_t max_count = count_indices_;
137+
const uint8_t* const indices = compressed_indices_;
138+
139+
for (size_t index = 0; index < max_count; index++) {
140+
size_t shift = ~index & 0b111;
141+
size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1;
142+
143+
if (is_not_zp) {
144+
*buffer++ = *value_table++;
145+
} else {
146+
*buffer++ = single_zero_point_;
147+
}
148+
}
149+
}
150+
151+
void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) {
152+
TFLITE_DCHECK(zero_points_ != nullptr);
153+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
154+
155+
int8_t* buffer = static_cast<int8_t*>(vp);
156+
size_t current_offset = 0;
157+
const uint8_t* const indices = compressed_indices_;
158+
const int8_t* value_table = static_cast<const int8_t*>(value_table_);
159+
160+
if (use_alternate_axis_) {
161+
const size_t max_channels = num_channels_;
162+
size_t count = count_indices_;
163+
164+
while (count > 0) {
165+
for (size_t channel = 0; channel < max_channels; channel++) {
166+
const int8_t zp = zero_points_[channel];
167+
size_t shift = ~current_offset & 0b111;
168+
size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
169+
170+
if (is_not_zp) {
171+
*buffer++ = *value_table++;
172+
} else {
173+
*buffer++ = zp;
174+
}
175+
current_offset++;
176+
}
177+
count -= max_channels;
178+
}
179+
} else {
180+
const size_t max_count = elements_per_channel_;
181+
182+
for (size_t channel = 0; channel < num_channels_; channel++) {
183+
size_t count = max_count;
184+
const int8_t zp = zero_points_[channel];
185+
186+
while (count-- > 0) {
187+
size_t shift = ~current_offset & 0b111;
188+
size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
189+
190+
if (is_not_zp) {
191+
*buffer++ = *value_table++;
192+
} else {
193+
*buffer++ = zp;
194+
}
195+
current_offset++;
196+
}
197+
}
198+
}
199+
}
200+
201+
template void DecodeStatePrune::DecompressToBuffer<int8_t>(void*);
202+
template void DecodeStatePrune::DecompressToBuffer<int16_t>(void*);
203+
template void DecodeStatePrune::DecompressToBuffer<int32_t>(void*);
204+
template void DecodeStatePrune::DecompressToBuffer<int64_t>(void*);
205+
206+
} // namespace tflite

0 commit comments

Comments
 (0)