Skip to content

Commit e1062f4

Browse files
authored
Merge pull request #53 from tensil-ai/peter/sc-446/support-input-shape-and-onnx-frontend-improvements
2 parents d49f0d4 + cf685f7 commit e1062f4

File tree

8 files changed

+377
-113
lines changed

8 files changed

+377
-113
lines changed

sim/test/src/zynq/tcu/AXIWrapperTCUSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import tensil.{
2525
}
2626
import tensil.data.InstructionReader
2727

28-
import tensil.tools.{Compiler, CompilerOptions}
28+
import tensil.tools.{Compiler, CompilerOptions, CompilerInputShapes}
2929
import tensil.tools.compiler.MemoryAddressHelper
3030
import tensil.{InstructionLayout}
3131

@@ -100,7 +100,7 @@ class AXIWrapperTCUSpec extends FunUnitSpec {
100100
// compiler parameters
101101
val options = CompilerOptions(
102102
arch = arch,
103-
inputBatchSize = batchSize
103+
inputShapes = CompilerInputShapes.mkWithBatchSize(batchSize),
104104
)
105105

106106
// setup compiler input/output streams
@@ -225,7 +225,7 @@ class AXIWrapperTCUSpec extends FunUnitSpec {
225225
// compiler parameters
226226
val options = CompilerOptions(
227227
arch = arch,
228-
inputBatchSize = batchSize,
228+
inputShapes = CompilerInputShapes.mkWithBatchSize(batchSize),
229229
//printProgramFileName = Some(s"sim_resnet20v2_cifar.tasm"),
230230
)
231231

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package tensil.tools

/**
  * Construction and parsing helpers for `CompilerInputShapes`: a map from
  * optional input name to a shape whose dimensions may each be unspecified.
  * The `None` key holds the default shape applied to inputs not listed by
  * name.
  */
object CompilerInputShapes {

  /**
    * Parses input shape specifications of the form
    * `<name> [<dim>, ...], ...`, e.g. `"x [1, 224, 224, 3], [1]"`.
    *
    * The name before a bracket group is optional; an empty name produces the
    * default (`None`-keyed) shape. A dimension may be left empty to mean
    * "unspecified" (to be deduced from the model), e.g. `"[1, , , 3]"`.
    */
  def parse(shapeStrings: String): CompilerInputShapes = {
    shapeStrings
      // Split on a closing bracket followed by the comma that separates
      // shape specs. NOTE: was "]\\S*," (non-whitespace), which rejected
      // separators written with a space before the comma ("] ,") and could
      // silently swallow stray text after "]"; \s* accepts optional
      // whitespace only.
      .split("]\\s*,")
      .map(_.trim())
      .map(shapeString => {
        // "name [dims" -> ("name", "dims"); for an unnamed shape "[dims"
        // the leading element is the empty string.
        val Seq(nameString, dimsString) =
          shapeString.split('[').map(_.trim()).toSeq

        // Dimensions: empty entries mean "unspecified" (None). The last
        // shape spec still carries its trailing "]", hence split(']')(0).
        val shape = dimsString
          .split(']')(0)
          .split(',')
          .map(_.trim())
          .map(dimString =>
            if (dimString.isEmpty())
              None
            else
              Some(Integer.parseInt(dimString))
          )
          .toSeq

        // Empty name means this is the default shape for unnamed inputs.
        val name =
          if (nameString.isEmpty())
            None
          else
            Some(nameString)

        (name, shape)
      })
      .toMap
  }

  /** Shapes map that specifies only the batch (first) dimension of the
    * default shape, leaving the rest to be deduced from the model. */
  def mkWithBatchSize(batchSize: Int): CompilerInputShapes =
    Map(
      None -> Seq(Some(batchSize))
    )
}

tools/src/tensil/tools/CompilerOptions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ case class TracepointCondition(
1212

1313
case class CompilerOptions(
1414
arch: Architecture,
15-
inputBatchSize: Int = 1,
15+
inputShapes: CompilerInputShapes = CompilerInputShapes.mkWithBatchSize(1),
1616
printSummary: Boolean = false,
1717
printLayersSummary: Boolean = false,
1818
printSchedulerSummary: Boolean = false,

tools/src/tensil/tools/Main.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ case class Args(
1313
archFile: File = new File("."),
1414
modelFile: File = new File("."),
1515
outputNodes: Seq[String] = Seq("Identity"),
16-
inputBatchSize: Int = 1,
16+
inputShapes: String = "[1]",
1717
verbose: Boolean = false,
1818
summary: Boolean = false,
1919
layersSummary: Boolean = false,
@@ -41,14 +41,16 @@ object Main extends App {
4141
.text("Tensil architecture descrition (.tarch) file")
4242

4343
opt[Seq[String]]('o', "output")
44-
.valueName("<names>")
44+
.valueName("<name>, ...")
4545
.action((x, c) => c.copy(outputNodes = x))
4646
.text("Optional list of output nodes, defaults to \"Identity\"")
4747

48-
opt[Int]('b', "batch")
49-
.valueName("<integer>")
50-
.action((x, c) => c.copy(inputBatchSize = x))
51-
.text("Optional size of input batch, defaults to 1")
48+
opt[String]('i', "input-shapes")
49+
.valueName("<name> [<dim>, ...], ...")
50+
.action((x, c) => c.copy(inputShapes = x))
51+
.text(
52+
"Optional input shapes, defaults to \"[1]\" (batch size of 1). The shape without <name> is a default for inputs that were not listed by name"
53+
)
5254

5355
opt[Boolean]('v', "verbose")
5456
.valueName("true|false")
@@ -96,7 +98,7 @@ object Main extends App {
9698

9799
val options = CompilerOptions(
98100
arch = arch,
99-
inputBatchSize = args.inputBatchSize,
101+
inputShapes = CompilerInputShapes.parse(args.inputShapes),
100102
printProgress = args.verbose,
101103
printSummary = args.summary,
102104
printLayersSummary = args.layersSummary,

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ import onnx.onnx.{NodeProto, ModelProto, TensorProto, ValueInfoProto}
1212
import _root_.tensil.tools.{
1313
CompilerException,
1414
TracepointCondition,
15-
CompilerOptions
15+
CompilerOptions,
16+
CompilerInputShapesHelper
1617
}
1718
import _root_.tensil.tools.data.{Shape, TensorData}
1819
import _root_.tensil.tools.util
@@ -579,11 +580,10 @@ class OnnxFrontend(
579580

580581
private def emitInput(context: EmitContext): EmitResult = {
581582
for ((name, valueInfoProto) <- inputValueInfoProtos) {
582-
val shape = Shape(
583+
val modelInputShape =
583584
valueInfoProto.`type`.get.value.tensorType.get.shape.get.dim
584-
.map(_.value.dimValue.get.toInt)
585-
.toArray
586-
)
585+
.map(_.value.dimValue.map(_.toInt))
586+
val shape = options.inputShapes.deduceInputShape(name, modelInputShape)
587587

588588
val consumers = inputNodeNames(name)
589589

@@ -1844,6 +1844,9 @@ class OnnxFrontend(
18441844
scheduler: Scheduler,
18451845
conv2DProto: NodeProto
18461846
): MemoryObject = {
1847+
val autoPadAttr = getAttr(conv2DProto, "auto_pad")
1848+
val autoPad = autoPadAttr.map(_.s.get.toStringUtf8())
1849+
18471850
val padsAttr = getAttr(conv2DProto, "pads")
18481851

18491852
val pads = if (padsAttr.isDefined) {
@@ -1880,15 +1883,6 @@ class OnnxFrontend(
18801883
)
18811884
}
18821885

1883-
val (paddingTop, paddingLeft, paddingBottom, paddingRight) =
1884-
pads.map(_.toInt) match {
1885-
case Seq(t, l, b, r) => (t, l, b, r)
1886-
case _ =>
1887-
throw new CompilerException(
1888-
s"Unsupported pads [${pads.mkString(", ")}]"
1889-
)
1890-
}
1891-
18921886
context.mm.addPendingConst(
18931887
conv2DProto.input(1),
18941888
getTensorData(tensorProtos(conv2DProto.input(1)))
@@ -1907,6 +1901,45 @@ class OnnxFrontend(
19071901
else None,
19081902
)
19091903

1904+
val (paddingTop, paddingLeft, paddingBottom, paddingRight) =
1905+
pads.map(_.toInt) match {
1906+
case Seq(t, l, b, r) =>
1907+
val paddingWidth =
1908+
(weights.dims.width.toDouble - 1) / 2
1909+
val paddingHeight =
1910+
(weights.dims.height.toDouble - 1) / 2
1911+
1912+
autoPad match {
1913+
case Some("SAME_UPPER") =>
1914+
(
1915+
Math.floor(paddingHeight).toInt,
1916+
Math.floor(paddingWidth).toInt,
1917+
Math.ceil(paddingHeight).toInt,
1918+
Math.ceil(paddingWidth).toInt
1919+
)
1920+
1921+
case Some("SAME_LOWER") =>
1922+
(
1923+
Math.ceil(paddingHeight).toInt,
1924+
Math.ceil(paddingWidth).toInt,
1925+
Math.floor(paddingHeight).toInt,
1926+
Math.floor(paddingWidth).toInt
1927+
)
1928+
1929+
case None | Some("NOTSET") => (t, l, b, r)
1930+
case Some(v) =>
1931+
throw new CompilerException(
1932+
s"Unsupported auto_pad attribute $v"
1933+
)
1934+
1935+
}
1936+
1937+
case _ =>
1938+
throw new CompilerException(
1939+
s"Unsupported pads [${pads.mkString(", ")}]"
1940+
)
1941+
}
1942+
19101943
val inputVars =
19111944
context.mm.consumeObject(conv2DProto.input(0), Seq(conv2DProto.name.get))
19121945

@@ -2471,7 +2504,15 @@ class OnnxFrontend(
24712504
val input1Name =
24722505
if (addProto.input(0) == input0Temp.name) addProto.input(1)
24732506
else addProto.input(0)
2474-
val input1Vars =
2507+
2508+
val input1Vars = if (tensorProtos.isDefinedAt(input1Name)) {
2509+
context.mm.addPendingConst(
2510+
input1Name,
2511+
getTensorData(tensorProtos(input1Name))
2512+
)
2513+
2514+
context.mm.getOrEmitConstObject(input1Name)
2515+
} else
24752516
context.mm.consumeObject(input1Name, Seq(addProto.name.get))
24762517

24772518
scheduler.emitAdd(

tools/src/tensil/tools/compiler/TfFrontend.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ import org.tensorflow.framework.types.DataType
1414
import _root_.tensil.tools.{
1515
CompilerException,
1616
TracepointCondition,
17-
CompilerOptions
17+
CompilerOptions,
18+
CompilerInputShapesHelper
1819
}
1920
import _root_.tensil.tools.data.{Shape, TensorData}
2021
import _root_.tensil.tools.util
@@ -575,16 +576,12 @@ class TfFrontend(
575576
context: EmitContext,
576577
placeholderDef: NodeDef
577578
): EmitResult = {
578-
val shape = util.getShape(placeholderDef)
579-
579+
val modelInputShape = util
580+
.getShape(placeholderDef)
581+
.map(dim => if (dim >= 0) Some(dim) else None)
582+
.toSeq
580583
val placeholderShape =
581-
if (shape(0) == -1)
582-
Shape(
583-
options.inputBatchSize +: shape
584-
.takeRight(shape.size - 1)
585-
.toArray
586-
)
587-
else shape
584+
options.inputShapes.deduceInputShape(placeholderDef.name, modelInputShape)
588585

589586
val placeholderDims = VarsDimensions(placeholderShape)
590587

tools/src/tensil/tools/package.scala

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,58 @@
44
package tensil
55

66
import tensil.tools.compiler.{MemoryAddress, MemoryObject}
7+
import tensil.tools.data.Shape
78

89
package object tools {
  type TracepointsMap     = Map[MemoryAddress, List[MemoryObject]]
  type CompilerSourceType = String

  // A shape is a sequence of dimensions; None marks a dimension that is
  // unspecified and must be deduced from the model.
  type CompilerInputShape = Seq[Option[Int]]

  // Shapes keyed by optional input name; the None key holds the default
  // shape for inputs that are not listed by name.
  type CompilerInputShapes = Map[Option[String], CompilerInputShape]

  /** Pretty-prints a shape, rendering unspecified dimensions as "?". */
  implicit class CompilerInputShapeHelper(val inputShape: CompilerInputShape) {
    override def toString() =
      s"[${inputShape.map(v => if (v.isDefined) v.get.toString else "?").mkString(", ")}]"
  }

  implicit class CompilerInputShapesHelper(
      val inputShapes: CompilerInputShapes
  ) {
    // First dimension of the first shape. Partial: assumes at least one
    // shape exists and its first dimension is specified — TODO confirm
    // callers guarantee this.
    def batchSize = inputShapes.head._2(0).get

    /**
      * Combines the user-specified shape for `name` (falling back to the
      * default `None`-keyed shape) with the shape found in the model,
      * dimension by dimension. A user-specified dimension takes precedence
      * but must agree with the model when the model also fixes it; otherwise
      * the model's dimension is used.
      *
      * Throws `CompilerException` when the two shapes conflict or when a
      * dimension is specified by neither.
      */
    def deduceInputShape(
        name: String,
        modelInputShape: CompilerInputShape
    ): Shape = {
      // Loop-invariant: resolve the user-specified shape for this input
      // once, instead of per dimension.
      val optionsInputShape =
        inputShapes.getOrElse(Some(name), inputShapes(None))

      Shape(
        modelInputShape.zipWithIndex
          .map({
            case (modelDim, i) =>
              if (
                optionsInputShape
                  .isDefinedAt(i) && optionsInputShape(i).isDefined
              ) {
                val optionsDim = optionsInputShape(i).get

                // The user may repeat a dimension the model already fixes,
                // but the two must agree.
                if (modelDim.isDefined && modelDim.get != optionsDim)
                  throw new CompilerException(
                    s"Specified input shape for $name ${CompilerInputShapeHelper(optionsInputShape)} is incompatible with model shape ${CompilerInputShapeHelper(modelInputShape)}"
                  )

                optionsDim
              } else {
                if (modelDim.isDefined)
                  modelDim.get
                else
                  throw new CompilerException(
                    s"Specified input shape for $name ${CompilerInputShapeHelper(optionsInputShape)} has unspecified dimensions in model shape ${CompilerInputShapeHelper(modelInputShape)}"
                  )
              }
          })
          .toArray
      )
    }
  }
}

0 commit comments

Comments
 (0)