Skip to content

Commit 9245002

Browse files
authored
Fix issue with transposing shape in CMSIS-NN batch matmul (tensorflow#2741)
BUG=tensorflow#2740
1 parent 4bb78c7 commit 9245002

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ inline TfLiteStatus PopulateEvalData(
121121
RuntimeShape tmp_r = SwapRowColumnDims(*rhs_shape);
122122
rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData());
123123
}
124-
if (!params->adj_x) {
124+
// ReferenceOps and CMSIS-NN have different requirements for when the
125+
// lhs shape should be transposed, so we have to treat float differently.
126+
if (!params->adj_x && original_lhs_input->type == kTfLiteFloat32) {
127+
RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape);
128+
lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
129+
} else if (params->adj_x && original_lhs_input->type != kTfLiteFloat32) {
125130
RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape);
126131
lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData());
127132
}

0 commit comments

Comments
 (0)