Skip to content

Commit d3baeb6

Browse files
authored
Merge pull request #88 from tensil-ai/peter/sc-488/support-mobilenet-with-unknown-batch-size
Support MobileNet with unknown batch size
2 parents 07e35d6 + a23d801 commit d3baeb6

File tree

3 files changed

+152
-1
lines changed

3 files changed

+152
-1
lines changed

tools/src/tensil/tools/compiler/MemoryManager.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,9 @@ class MemoryManager(
8585
def hasPendingFloatConst(name: String) =
8686
pendingFloatConsts.get(name).isDefined
8787

88+
/** Tells whether a pending long constant is registered under `name`. */
def hasPendingLongConst(name: String): Boolean =
  pendingLongConsts.contains(name)
90+
8891
def getPendingIntConst(name: String) = pendingIntConsts(name)
8992
def getPendingLongConst(name: String) = pendingLongConsts(name)
9093
def getPendingFloatConst(name: String) = pendingFloatConsts(name)

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 96 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -449,6 +449,14 @@ class OnnxFrontend(
449449
emitGlobalPool(_, nodeProto),
450450
emitters
451451
)
452+
case "Gather" =>
453+
rewriteSimple(remainingProtos, emitGather(_, nodeProto), emitters)
454+
case "Unsqueeze" =>
455+
rewriteSimple(
456+
remainingProtos,
457+
emitUnsqueeze(_, nodeProto),
458+
emitters
459+
)
452460
case op =>
453461
throw new CompilerException(
454462
s"Unsupported op ${op} (${nodeProto.name.get})"
@@ -917,6 +925,69 @@ class OnnxFrontend(
917925
)
918926
}
919927

928+
/** Folds an ONNX `Gather` node over compile-time constants into a new
  * pending constant.
  *
  * Only the degenerate form is supported: `axis` 0, rank-1 `data` and a
  * scalar `indices`, both looked up as pending long constants. The selected
  * element is registered as a scalar `DT_INT64` pending constant under the
  * node's first output name.
  *
  * @param context     emit context whose memory manager holds the pending
  *                    constants
  * @param gatherProto the `Gather` node being emitted
  * @throws CompilerException when the node is not the supported 1D form or
  *                           the index falls outside of the data
  */
private def emitGather(
    context: EmitContext,
    gatherProto: NodeProto
): Unit = {
  // ONNX declares `axis` as optional with a default of 0, so a node that
  // omits the attribute must not fail on `Option.get`.
  val axis = getAttr(gatherProto, "axis") match {
    case Some(axisAttr) =>
      require(axisAttr.`type`.get.isInt)
      axisAttr.i.get
    case None => 0L
  }

  val data =
    context.mm
      .getPendingLongConst(gatherProto.input(0))
      .asInstanceOf[TensorData[Long]]

  val indices = context.mm
    .getPendingLongConst(gatherProto.input(1))
    .asInstanceOf[TensorData[Long]]

  if (axis != 0 || data.shape.size != 1 || indices.shape.size != 0)
    throw new CompilerException("Only 1D gather is supported")

  val size     = data.shape(0)
  val rawIndex = indices.as1D(0)

  // ONNX accepts indices in [-size, size - 1]; a negative index counts
  // from the end of the gathered axis.
  val index = if (rawIndex < 0) rawIndex + size else rawIndex

  if (index < 0 || index >= size)
    throw new CompilerException("Gather index is outside of data shape")

  context.mm.addPendingConst(
    gatherProto.output(0),
    new TensorData(
      Shape(),
      Seq(data.as1D(index.toInt)),
      org.tensorflow.framework.types.DataType.DT_INT64
    )
  )
}
962+
963+
/** Folds an ONNX `Unsqueeze` node over a compile-time constant into a new
  * pending constant.
  *
  * Only the scalar case is handled: exactly one axis, equal to 0, applied to
  * a rank-0 pending long constant. The value is re-registered with shape
  * `Shape(1)` and type `DT_INT64` under the node's first output name.
  *
  * @param context        emit context whose memory manager holds the pending
  *                       constants
  * @param unsqueezeProto the `Unsqueeze` node being emitted
  * @throws CompilerException when the node is not a scalar unsqueeze on axis 0
  */
private def emitUnsqueeze(
    context: EmitContext,
    unsqueezeProto: NodeProto
): Unit = {
  val attr = getAttr(unsqueezeProto, "axes").get

  require(attr.`type`.get.isInts)

  val unsqueezeAxes = attr.ints

  val scalar = context.mm
    .getPendingLongConst(unsqueezeProto.input(0))
    .asInstanceOf[TensorData[Long]]

  // Same short-circuit order as checking each failure condition in turn:
  // the first axis is only inspected once exactly one axis is known to exist.
  val isScalarOnAxisZero =
    unsqueezeAxes.size == 1 && unsqueezeAxes(0) == 0 && scalar.shape.size == 0

  if (!isScalarOnAxisZero)
    throw new CompilerException("Only scalar unsqueeze is supported")

  context.mm.addPendingConst(
    unsqueezeProto.output(0),
    new TensorData(
      Shape(1),
      scalar.as1D,
      org.tensorflow.framework.types.DataType.DT_INT64
    )
  )
}
990+
920991
private def emitConstant(
921992
context: EmitContext,
922993
constantProto: NodeProto
@@ -1058,7 +1129,10 @@ class OnnxFrontend(
10581129
context: EmitContext,
10591130
reshapeProto: NodeProto
10601131
): Unit = {
1061-
val shape = getTensorData(tensorProtos(reshapeProto.input(1)))
1132+
val shapeInputName = reshapeProto.input(1)
1133+
val shape = (if (tensorProtos.contains(shapeInputName))
1134+
getTensorData(tensorProtos(shapeInputName))
1135+
else context.mm.getPendingLongConst(shapeInputName))
10621136
.asInstanceOf[TensorData[Long]]
10631137
.as1D
10641138
.map(_.toInt)
@@ -1439,6 +1513,27 @@ class OnnxFrontend(
14391513
org.tensorflow.framework.types.DataType.DT_FLOAT
14401514
)
14411515
)
1516+
} else if (
1517+
concatProto.input.forall(name =>
1518+
context.mm.hasPendingLongConst(name) || tensorProtos.contains(name)
1519+
)
1520+
) {
1521+
val output = concatProto.input
1522+
.map(name =>
1523+
(if (context.mm.hasPendingLongConst(name))
1524+
context.mm.getPendingLongConst(name)
1525+
else getTensorData(tensorProtos(name))).as1D
1526+
)
1527+
.flatten
1528+
1529+
context.mm.addPendingConst(
1530+
concatProto.output(0),
1531+
new TensorData(
1532+
Shape(output.size),
1533+
output,
1534+
org.tensorflow.framework.types.DataType.DT_INT64
1535+
)
1536+
)
14421537
} else {
14431538

14441539
if (axis != 1)

tools/test/src/tools/CompilerSpec.scala

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2768,6 +2768,59 @@ class CompilerSpec extends AnyFlatSpec {
27682768
)
27692769
}
27702770

2771+
// Compiles the ONNX MobileNetV2 model for the float32 architecture with an
// input batch size of 3, then runs the emulator test on the result.
it should "Compile ONNX float MobileNetV2 with input batch of 3" taggedAs (Slow) in {
  val modelName = "mobilenetv2_float_onnx"
  val trace     = new ExecutiveTraceContext()
  val compilerOptions = CompilerOptions(
    arch = MobileNetFloat32Architecture,
    inputShapes = CompilerInputShapes.mkWithBatchSize(3),
    printSummary = true
  )
  val modelPath   = s"${Models}/mobilenetv2.onnx"
  val outputNames = List("output")

  Compiler.compile(modelName, modelPath, outputNames, compilerOptions, trace)

  EmulatorHelper.test(
    modelName,
    inputBatchSize = compilerOptions.inputShapes.batchSize,
    traceContext = trace
  )
}
2794+
2795+
// Compiles the ONNX MobileNetV2 model for the FP18BP10 architecture with an
// input batch size of 3, with summaries, graph printing and a tracepoint on
// the DRAM0 "output" enabled, then runs the emulator test on the result.
it should "Compile ONNX fixed18bp10 MobileNetV2 with input batch of 3" taggedAs (Slow) in {
  val modelName = "mobilenetv2_fixed18bp10_onnx"
  val trace     = new ExecutiveTraceContext()
  val compilerOptions = CompilerOptions(
    arch = MobileNetFp18bp10Architecture,
    inputShapes = CompilerInputShapes.mkWithBatchSize(3),
    printSummary = true,
    printLayersSummary = true,
    printGraph = true,
    tracepointConditions = List(
      TracepointCondition(MemoryTag.DRAM0, "output")
    )
  )
  val modelPath   = s"${Models}/mobilenetv2.onnx"
  val outputNames = List("output")

  Compiler.compile(modelName, modelPath, outputNames, compilerOptions, trace)

  EmulatorHelper.test(
    modelName,
    inputBatchSize = compilerOptions.inputShapes.batchSize,
    traceContext = trace
  )
}
2823+
27712824
val SpeechCommandsFp16bp8Architecture = Architecture.mkWithDefaults(
27722825
dataType = ArchitectureDataType.FP16BP8,
27732826
arraySize = 8,

0 commit comments

Comments
 (0)