@@ -12,7 +12,8 @@ import onnx.onnx.{NodeProto, ModelProto, TensorProto, ValueInfoProto}
 import _root_.tensil.tools.{
   CompilerException,
   TracepointCondition,
-  CompilerOptions
+  CompilerOptions,
+  CompilerInputShapesHelper
 }
 import _root_.tensil.tools.data.{Shape, TensorData}
 import _root_.tensil.tools.util
@@ -579,11 +580,10 @@ class OnnxFrontend(
 
   private def emitInput(context: EmitContext): EmitResult = {
     for ((name, valueInfoProto) <- inputValueInfoProtos) {
-      val shape = Shape(
+      val modelInputShape =
         valueInfoProto.`type`.get.value.tensorType.get.shape.get.dim
-          .map(_.value.dimValue.get.toInt)
-          .toArray
-      )
+          .map(_.value.dimValue.map(_.toInt))
+      val shape = options.inputShapes.deduceInputShape(name, modelInputShape)
 
       val consumers = inputNodeNames(name)
 
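Note on the change above: each model dim is now kept as an `Option[Int]` (a dim whose `dimValue` is absent is dynamic, e.g. an unspecified batch size), and the shape is completed from the input shapes supplied through the compiler options. A minimal standalone sketch of that deduction idea, using a hypothetical `deduce` helper (the real behavior lives in `CompilerInputShapesHelper` and may differ, e.g. in precedence and error handling):

```scala
// Hypothetical illustration only, not the Tensil API: keep dims the model
// pins down, fill dims the model leaves dynamic from a user-supplied shape.
def deduce(
    modelShape: Seq[Option[Int]],
    userShape: Seq[Option[Int]]
): Array[Int] =
  modelShape.zipAll(userShape, None, None).map {
    case (Some(modelDim), _)   => modelDim // model defines the dim
    case (None, Some(userDim)) => userDim  // user option fills a dynamic dim
    case (None, None) =>
      throw new IllegalArgumentException(
        "dim is dynamic and no input shape was supplied"
      )
  }.toArray
```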
@@ -1844,6 +1844,9 @@ class OnnxFrontend(
       scheduler: Scheduler,
       conv2DProto: NodeProto
   ): MemoryObject = {
+    val autoPadAttr = getAttr(conv2DProto, "auto_pad")
+    val autoPad = autoPadAttr.map(_.s.get.toStringUtf8())
+
     val padsAttr = getAttr(conv2DProto, "pads")
 
     val pads = if (padsAttr.isDefined) {
@@ -1880,15 +1883,6 @@ class OnnxFrontend(
       )
     }
 
-    val (paddingTop, paddingLeft, paddingBottom, paddingRight) =
-      pads.map(_.toInt) match {
-        case Seq(t, l, b, r) => (t, l, b, r)
-        case _ =>
-          throw new CompilerException(
-            s"Unsupported pads [${pads.mkString(", ")}]"
-          )
-      }
-
     context.mm.addPendingConst(
       conv2DProto.input(1),
       getTensorData(tensorProtos(conv2DProto.input(1)))
@@ -1907,6 +1901,45 @@ class OnnxFrontend(
         else None,
       )
 
+    val (paddingTop, paddingLeft, paddingBottom, paddingRight) =
+      pads.map(_.toInt) match {
+        case Seq(t, l, b, r) =>
+          val paddingWidth =
+            (weights.dims.width.toDouble - 1) / 2
+          val paddingHeight =
+            (weights.dims.height.toDouble - 1) / 2
+
+          autoPad match {
+            case Some("SAME_UPPER") =>
+              (
+                Math.floor(paddingHeight).toInt,
+                Math.floor(paddingWidth).toInt,
+                Math.ceil(paddingHeight).toInt,
+                Math.ceil(paddingWidth).toInt
+              )
+
+            case Some("SAME_LOWER") =>
+              (
+                Math.ceil(paddingHeight).toInt,
+                Math.ceil(paddingWidth).toInt,
+                Math.floor(paddingHeight).toInt,
+                Math.floor(paddingWidth).toInt
+              )
+
+            case None | Some("NOTSET") => (t, l, b, r)
+            case Some(v) =>
+              throw new CompilerException(
+                s"Unsupported auto_pad attribute $v"
+              )
+
+          }
+
+        case _ =>
+          throw new CompilerException(
+            s"Unsupported pads [${pads.mkString(", ")}]"
+          )
+      }
+
     val inputVars =
       context.mm.consumeObject(conv2DProto.input(0), Seq(conv2DProto.name.get))
 
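For reference, the SAME_UPPER/SAME_LOWER branches added above split a total padding of kernel − 1 per axis between the two sides (which makes output size equal input size for the stride-1, dilation-1 case), with the odd extra pixel going to the bottom/right for SAME_UPPER and to the top/left for SAME_LOWER. A small standalone sketch of the same arithmetic, separate from the compiler code:

```scala
// Mirrors the floor/ceil split used in the diff above for one axis.
object AutoPadExample extends App {
  def samePadding(kernel: Int, upper: Boolean): (Int, Int) = {
    val half     = (kernel.toDouble - 1) / 2
    val (lo, hi) = (Math.floor(half).toInt, Math.ceil(half).toInt)
    // SAME_UPPER puts the extra row/column at the end, SAME_LOWER at the start.
    if (upper) (lo, hi) else (hi, lo)
  }

  println(samePadding(3, upper = true))  // (1,1): odd kernels pad evenly
  println(samePadding(4, upper = true))  // (1,2): extra pixel at bottom/right
  println(samePadding(4, upper = false)) // (2,1): extra pixel at top/left
}
```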
@@ -2471,7 +2504,15 @@ class OnnxFrontend(
     val input1Name =
       if (addProto.input(0) == input0Temp.name) addProto.input(1)
       else addProto.input(0)
-    val input1Vars =
+
+    val input1Vars = if (tensorProtos.isDefinedAt(input1Name)) {
+      context.mm.addPendingConst(
+        input1Name,
+        getTensorData(tensorProtos(input1Name))
+      )
+
+      context.mm.getOrEmitConstObject(input1Name)
+    } else
       context.mm.consumeObject(input1Name, Seq(addProto.name.get))
 
     scheduler.emitAdd(