@@ -32,6 +32,16 @@ class OnnxFrontend(
     graphStream: Option[OutputStream],
     options: CompilerOptions
 ) extends Frontend {
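+  // Validate the model's ONNX opset up front: e.g. an opset 11 model fails
+  // fast here instead of failing later on an unsupported operator.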
+  val opsetVersion = modelProto.opsetImport(0).version.get
+
+  val MinOpsetVersion = 9
+  val MaxOpsetVersion = 10
+
+  if (opsetVersion < MinOpsetVersion || opsetVersion > MaxOpsetVersion)
+    throw new CompilerException(
+      s"ONNX opset ${opsetVersion} is not supported. Supported range is [${MinOpsetVersion}, ${MaxOpsetVersion}]."
+    )
+
   private object VarsDimensions {
     def apply(
         number: Int,
@@ -89,13 +99,29 @@ class OnnxFrontend(
     )
 
     def apply(height: Int, width: Int): MemoryDimensions =
-      MemoryDimensions(
-        arraySize = arch.arraySize,
-        "HW",
-        "HW",
-        isWeights = true,
-        dimensions = Vector(height, width)
-      )
+      apply(height, width, false)
+
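+    // New overload: when transpose is set, one of the two layout strings is
+    // swapped to "WH", which presumably tells the memory manager to lay the
+    // 2D constant out transposed. Used below for Gemm's transB weights.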
+    def apply(
+        height: Int,
+        width: Int,
+        transpose: Boolean
+    ): MemoryDimensions =
+      if (transpose)
+        MemoryDimensions(
+          arraySize = arch.arraySize,
+          "WH",
+          "HW",
+          isWeights = true,
+          dimensions = Vector(height, width)
+        )
+      else
+        MemoryDimensions(
+          arraySize = arch.arraySize,
+          "HW",
+          "HW",
+          isWeights = true,
+          dimensions = Vector(height, width)
+        )
 
     def apply(
         channelsOut: Int,
@@ -111,21 +137,27 @@ class OnnxFrontend(
         dimensions = Vector(channelsOut, channelsIn, height, width)
       )
 
-    def apply(shape: Shape): MemoryDimensions =
-      if (shape.size == 1)
-        ConstsDimensions(shape(0))
-      else if (shape.size == 2)
-        ConstsDimensions(shape(0), shape(1))
-      else if (shape.size == 4)
-        ConstsDimensions(shape(0), shape(1), shape(2), shape(3))
-      else
+    def apply(shape: Shape, transpose: Boolean): MemoryDimensions =
+      if (transpose && shape.size != 2)
         throw new CompilerException(
-          s"Consts tensor shape of ${shape} is not supported"
+          s"Transposing consts is supported for 2D tensors only"
         )
+      else {
+        if (shape.size == 1)
+          ConstsDimensions(shape(0))
+        else if (shape.size == 2)
+          ConstsDimensions(shape(0), shape(1), transpose)
+        else if (shape.size == 4)
+          ConstsDimensions(shape(0), shape(1), shape(2), shape(3))
+        else
+          throw new CompilerException(
+            s"Consts tensor shape of ${shape} is not supported"
+          )
+      }
   }
 
-  def mkConstsDimensions(shape: Shape): MemoryDimensions =
-    ConstsDimensions(shape)
+  def mkConstsDimensions(shape: Shape, transpose: Boolean): MemoryDimensions =
+    ConstsDimensions(shape, transpose)
 
   private val nodeProtos = mutable.Map.empty[String, NodeProto]
   private val tensorProtos = mutable.Map.empty[String, TensorProto]
@@ -346,6 +378,8 @@ class OnnxFrontend(
           rewriteLayer(remainingProtos, nodeProto, emitters)
         case "Reshape" =>
           rewriteSimple(remainingProtos, emitReshape(_, nodeProto), emitters)
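+        // Flatten is lowered through the same path as Reshape (see emitFlatten)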
+        case "Flatten" =>
+          rewriteSimple(remainingProtos, emitFlatten(_, nodeProto), emitters)
         case "Split" =>
           rewriteSimple(remainingProtos, emitSplit(_, nodeProto), emitters)
         case "Concat" =>
@@ -1007,17 +1041,54 @@ class OnnxFrontend(
       context: EmitContext,
       reshapeProto: NodeProto
   ): Unit = {
-    val inputVars =
-      context.mm
-        .consumeObject(reshapeProto.input(0), Seq(reshapeProto.name.get))
-    val inputDims = inputVars.dims
-
     val shape = getTensorData(tensorProtos(reshapeProto.input(1)))
       .asInstanceOf[TensorData[Long]]
       .as1D
       .map(_.toInt)
       .toArray
 
+    val inputVars =
+      context.mm
+        .consumeObject(reshapeProto.input(0), Seq(reshapeProto.name.get))
+
+    doEmitReshape(context, reshapeProto, inputVars, shape)
+  }
+
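+  // ONNX Flatten collapses dims [0, axis) into the first output dimension and
+  // dims [axis, rank) into the second, so it reduces to a 2D reshape.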
+  private def emitFlatten(
+      context: EmitContext,
+      flattenProto: NodeProto
+  ): Unit = {
+    val axisAttr = getAttr(flattenProto, "axis").get
+
+    require(axisAttr.`type`.get.isInt)
+
+    val axis = axisAttr.i.get.toInt
+
+    val inputVars =
+      context.mm
+        .consumeObject(flattenProto.input(0), Seq(flattenProto.name.get))
+
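+    // E.g. input dims (N, C, H, W) with axis = 1 flatten to (N, C*H*W),
+    // while axis = 0 yields (1, N*C*H*W).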
+    val shape =
+      if (axis == 0) Array(1, inputVars.dims.modelDimensions.product)
+      else
+        Array(
+          inputVars.dims.modelDimensions.slice(0, axis).product,
+          inputVars.dims.modelDimensions
+            .slice(axis, inputVars.dims.order)
+            .product
+        )
+
+    doEmitReshape(context, flattenProto, inputVars, shape)
+  }
+
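+  // Shared tail of emitReshape and emitFlatten: materializes nodeProto's
+  // output with the requested shape, inferring at most one -1 dimension.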
+  private def doEmitReshape(
+      context: EmitContext,
+      nodeProto: NodeProto,
+      inputVars: MemoryObject,
+      shape: Array[Int]
+  ): Unit = {
+    val inputDims = inputVars.dims
+
     var pixelDims = VarsDimensions(1, arch.arraySize)
     val outputDims = VarsDimensions(Shape(if (shape.exists(_ == -1)) {
       val inferred = inputDims.sizeScalars / shape.filter(_ != -1).product
@@ -1082,7 +1153,7 @@ class OnnxFrontend(
     val outputNames =
       if (groupedByOffsetPairs.size > 1) {
         val adjustedOutputTemp = context.mm.allocateTempObject(
-          reshapeProto.output(0),
+          nodeProto.output(0),
           outputDims
         )
 
@@ -1139,16 +1210,16 @@ class OnnxFrontend(
         Seq(inputVars.name)
 
     val outputVars = context.mm.blendObjects(
-      reshapeProto.output(0),
+      nodeProto.output(0),
       outputDims,
-      findInterLayerOutputs(context, reshapeProto.output(0), None),
+      findInterLayerOutputs(context, nodeProto.output(0), None),
       outputNames,
       outputAddresses
     )
 
     if (context.graphPrinter.isDefined)
       context.graphPrinter.get.printOp(
-        reshapeProto,
+        nodeProto,
         Seq(outputVars),
         Seq(inputVars)
       )
@@ -2111,6 +2182,24 @@ class OnnxFrontend(
       context: EmitContext,
       matMulProto: NodeProto
   ): MemoryObject = {
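+    // Per the ONNX spec, Gemm's optional transA/transB attributes default
+    // to 0 (no transpose) when absent.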
+    val transAAttr = getAttr(matMulProto, "transA")
+    val transBAttr = getAttr(matMulProto, "transB")
+
+    val transA = if (transAAttr.isDefined) {
+      require(transAAttr.get.`type`.get.isInt)
+      transAAttr.get.i.get.toInt
+    } else 0
+
+    val transB = if (transBAttr.isDefined) {
+      require(transBAttr.get.`type`.get.isInt)
+      transBAttr.get.i.get.toInt
+    } else 0
+
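+    // Only the constant weights (input B) can be stored transposed;
+    // transposing the activations (input A) is rejected.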
+    if (transA != 0)
+      throw new CompilerException(
+        s"Gemm with transposed input A is not supported"
+      )
+
     context.mm.addPendingConst(
       matMulProto.input(1),
       getTensorData(tensorProtos(matMulProto.input(1)))
@@ -2127,6 +2216,7 @@ class OnnxFrontend(
       matMulProto.input(1),
       if (matMulProto.input.isDefinedAt(2)) Some(matMulProto.input(2))
       else None,
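+      // transB != 0 requests the transposed ("WH") consts layout above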
+      transB != 0
     )
 
     val inputVars =