Skip to content

Commit 4e2d75a

Browse files
authored
Merge pull request #45 from trulyspinach/fix-reshape
Fix ONNX frontend Reshape behavior
2 parents 47b4443 + 5a53b6d commit 4e2d75a

File tree

4 files changed

+121
-5
lines changed

4 files changed

+121
-5
lines changed

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -975,11 +975,39 @@ class OnnxFrontend(
975975
throw new CompilerException("Scalar sizes must match for reshape")
976976

977977
val indexesAndOffsetsPairs =
978-
for (i <- 0 until inputDims.sizeScalars) yield {
979-
val (inputIndex, inputOffset) = inputDims.vectorIndexOffsetAt(i)
980-
val (outputIndex, outputOffset) = outputDims.vectorIndexOffsetAt(i)
981-
((outputIndex, inputIndex), (outputOffset, inputOffset))
982-
}
978+
if (inputDims.order == 2 && outputDims.order == 4) {
979+
// when reshaping from 1D to 4D, a NCHW to NHWC permute(transpose)
980+
// need to be performed to match ONNX's behavior.
981+
982+
val dimMap = (3 to 0 by -1).map(shape.takeRight(_).fold(1)(_ * _))
983+
val outShape = Seq(shape(0), shape(2), shape(3), shape(1))
984+
val dimMapOut = (3 to 0 by -1).map(outShape.takeRight(_).fold(1)(_ * _))
985+
986+
for (
987+
n <- 0 until shape(0);
988+
c <- 0 until shape(1);
989+
h <- 0 until shape(2);
990+
w <- 0 until shape(3)
991+
) yield {
992+
993+
val from = Array(n, c, h, w).zipWithIndex
994+
.map { case (e, i) => dimMap(i) * e }
995+
.fold(0)(_ + _)
996+
997+
val to = Array(n, h, w, c).zipWithIndex
998+
.map { case (e, i) => dimMapOut(i) * e }
999+
.fold(0)(_ + _)
1000+
1001+
val (inputIndex, inputOffset) = inputDims.vectorIndexOffsetAt(from)
1002+
val (outputIndex, outputOffset) = outputDims.vectorIndexOffsetAt(to)
1003+
((outputIndex, inputIndex), (outputOffset, inputOffset))
1004+
}
1005+
} else
1006+
for (i <- 0 until inputDims.sizeScalars) yield {
1007+
val (inputIndex, inputOffset) = inputDims.vectorIndexOffsetAt(i)
1008+
val (outputIndex, outputOffset) = outputDims.vectorIndexOffsetAt(i)
1009+
((outputIndex, inputIndex), (outputOffset, inputOffset))
1010+
}
9831011

9841012
val groupedByOffsetPairs = indexesAndOffsetsPairs
9851013
.groupBy(_._1)

tools/test/src/tools/CompilerSpec.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,23 @@ class CompilerSpec extends FlatSpec {
12511251
localDepth = 256,
12521252
)
12531253

1254+
it should "Compile ONNX Reshape from 1D(NCHW) to 4D(NHWC)" in {
1255+
val name = "reshape_1d_4d"
1256+
val options = CompilerOptions(
1257+
arch = Conv2DTiny4x4Architecure,
1258+
printSummary = true
1259+
)
1260+
1261+
Compiler.compile(
1262+
name,
1263+
s"${Models}/reshape_1d_4d.onnx",
1264+
List("output"),
1265+
options
1266+
)
1267+
1268+
GoldenProcessorHelper.test(name, inputBatchSize = options.inputBatchSize)
1269+
}
1270+
12541271
it should "Compile TF Conv2D (VALID padding) 3x3x4 image with 2x2x4x4 kernel" in {
12551272
val name = "conv2d_4x4_valid"
12561273
val options = CompilerOptions(

tools/test/src/tools/GoldenProcessorHelper.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ object GoldenProcessorHelper {
9595
ResNet.prepareInputStream(dataType, arraySize, count)
9696
else if (modelName.startsWith("resnet50v2"))
9797
ResNet50.prepareInputStream(dataType, arraySize, count)
98+
else if (modelName.startsWith("reshape_1d_4d"))
99+
Reshape.prepareInputStream(dataType, arraySize, count)
98100
else if (yoloPattern.findFirstIn(modelName).isDefined) {
99101
val yoloPattern(yoloSize) = modelName
100102
TinyYolo(yoloSize.toInt, onnx = modelName.endsWith("onnx"))
@@ -174,6 +176,8 @@ object GoldenProcessorHelper {
174176
ResNet.assertOutput(dataType, arraySize, bytes, count)
175177
else if (modelName.startsWith("resnet50v2"))
176178
ResNet50.assertOutput(dataType, arraySize, bytes, count)
179+
else if (modelName.startsWith("reshape_1d_4d"))
180+
Reshape.assertOutput(dataType, arraySize, bytes, count)
177181
else if (yoloPattern.findFirstIn(modelName).isDefined) {
178182
val yoloPattern(yoloSize) = modelName
179183
TinyYolo(yoloSize.toInt, onnx = modelName.endsWith("onnx"))

tools/test/src/tools/Reshape.scala

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* SPDX-License-Identifier: Apache-2.0 */
2+
/* Copyright © 2019-2022 Tensil AI Company */
3+
4+
package tensil.tools
5+
6+
import java.io._
7+
import scala.reflect.ClassTag
8+
import tensil.tools.golden.{Processor, ExecutiveTraceContext}
9+
import scala.collection.mutable
10+
import tensil.ArchitectureDataType
11+
12+
object Reshape {
13+
def prepareInputStream(
14+
dataType: ArchitectureDataType,
15+
arraySize: Int,
16+
count: Int = 1
17+
): InputStream =
18+
new ByteArrayInputStream(prepareInputBytes(dataType, arraySize, count))
19+
20+
def prepareInputBytes(
21+
dataType: ArchitectureDataType,
22+
arraySize: Int,
23+
count: Int = 1
24+
): Array[Byte] = {
25+
val inputPrep = new ByteArrayOutputStream()
26+
val inputPrepDataStream = new DataOutputStream(inputPrep)
27+
28+
val seq = (1 to 8).map(_.toFloat).toArray.grouped(arraySize)
29+
for (s <- seq)
30+
Util.writeArgs(dataType, inputPrepDataStream, arraySize, s: _*)
31+
32+
inputPrep.toByteArray()
33+
}
34+
35+
def assertOutput(
36+
dataType: ArchitectureDataType,
37+
arraySize: Int,
38+
bytes: Array[Byte],
39+
count: Int = 1
40+
): Unit = {
41+
val rmse = new RMSE()
42+
43+
val output =
44+
new DataInputStream(new ByteArrayInputStream(bytes))
45+
46+
val outputSize = arraySize * 4
47+
48+
val result = Util
49+
.readResult(dataType, output, arraySize, outputSize)
50+
.toArray
51+
52+
val golden = Golden
53+
.grouped(2)
54+
.map(_ ++ Array.fill(arraySize - 2)(0f))
55+
.flatten
56+
.toArray
57+
58+
for (i <- 0 until outputSize)
59+
rmse.addSample(result(i), golden(i))
60+
61+
assert(rmse.compute < dataType.error)
62+
}
63+
64+
private val Golden = Seq(
65+
1f, 5f, 2f, 6f, 3f, 7f, 4f, 8f
66+
)
67+
}

0 commit comments

Comments
 (0)