@@ -9,7 +9,11 @@ import scala.collection.mutable

 import onnx.onnx.{NodeProto, ModelProto, TensorProto, ValueInfoProto}

-import _root_.tensil.tools.{CompilerException, TracepointCondition, CompilerOptions}
+import _root_.tensil.tools.{
+  CompilerException,
+  TracepointCondition,
+  CompilerOptions
+}
 import _root_.tensil.tools.data.{Shape, TensorData}
 import _root_.tensil.tools.util
 import _root_.tensil.{TablePrinter, Architecture}
@@ -231,7 +235,7 @@ class OnnxFrontend(
       case Nil => emitters
       case nodeProto :: remainingProtos =>
         nodeProto.opType.get match {
-          case "Gemm" | "Conv" =>
+          case "MatMul" | "Gemm" | "Conv" =>
             rewriteLayer(remainingProtos, nodeProto, emitters)
           case "Reshape" =>
             rewriteSimple(remainingProtos, emitReshape(_, nodeProto), emitters)
@@ -312,7 +316,7 @@ class OnnxFrontend(
    * This function takes `layerStepOps`, which describes
    * the pattern to which we expect nodes to adhere in order
    * to form a layer. The initial and the only required node is
-   * matched in `recursiveRewrite` to be either `Gemm` or
+   * matched in `recursiveRewrite` to be either `MatMul`, `Gemm` or
    * `Conv`. This node is followed by "layer step operations"
    * where each step can optionally be one of the operations
    * included in the set.
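As an illustrative aside (not part of the commit): a layer in this rewrite scheme is a head node (`MatMul`, `Gemm`, or `Conv`) followed by zero or more optional step nodes. Below is a minimal sketch of such a sequence, assuming the scalapb-generated `onnx.onnx.NodeProto` case class with `Option` fields (consistent with the `.opType.get` and `.name.get` calls in this file) and assuming, for illustration only, that `Relu` is among the configured layer step operations.

    // Hypothetical two-node sequence forming one layer: a MatMul head
    // followed by a Relu step. Node and tensor names are made up.
    val matMulNode = NodeProto(
      name = Some("fc1"),
      opType = Some("MatMul"),
      input = Seq("x", "fc1_weights"),
      output = Seq("fc1_out")
    )
    val reluNode = NodeProto(
      name = Some("fc1_relu"),
      opType = Some("Relu"),
      input = Seq("fc1_out"),
      output = Seq("y")
    )
    // recursiveRewrite matches matMulNode as the layer head (per the dispatch
    // shown above) and rewriteLayer then tries to consume reluNode as an
    // optional layer step, so both nodes are emitted as a single layer.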
@@ -383,7 +387,8 @@ class OnnxFrontend(
   private var layerIndex = 0

   private def startLayer(nodeProtos: Seq[NodeProto]): Scheduler = {
-    val name = s"LAYER $layerIndex (${nodeProtos.map(_.name.get).mkString(",")})"
+    val name =
+      s"LAYER $layerIndex (${nodeProtos.map(_.name.get).mkString(",")})"

     layerIndex += 1

@@ -441,7 +446,13 @@ class OnnxFrontend(
     )

     val matMulTemp =
-      if (nodeProto.opType.get == "Gemm")
+      if (nodeProto.opType.get == "MatMul")
+        emitLayerMatMul(
+          context,
+          scheduler,
+          nodeProto
+        )
+      else if (nodeProto.opType.get == "Gemm")
         emitLayerGemm(
           context,
           scheduler,
@@ -2042,6 +2053,55 @@ class OnnxFrontend(
     outputTemp
   }

+  private def emitLayerMatMul(
+      context: EmitContext,
+      scheduler: Scheduler,
+      matMulProto: NodeProto
+  ): MemoryObject = {
+    context.mm.addPendingConst(
+      matMulProto.input(1),
+      getTensorData(tensorProtos(matMulProto.input(1)))
+    )
+
+    val (weights, bias) =
+      context.mm.getOrEmitWeightsAndBiasObjects(
+        matMulProto.input(1),
+        None
+      )
+
+    val inputVars =
+      context.mm.consumeObject(matMulProto.input(0), Seq(matMulProto.name.get))
+
+    val outputTemp =
+      context.mm.allocateTempObject(
+        matMulProto.output(0),
+        VarsDimensions(
+          inputVars.dims.number,
+          weights.dims.width
+        )
+      )
+
+    val pairs = Seq(
+      MemoryOptionalInputOutputObjects(Some(inputVars), outputTemp)
+    )
+
+    scheduler.emitMatMul(
+      weights,
+      bias,
+      pairs
+    )
+
+    if (graphPrinter.isDefined)
+      graphPrinter.get.printOp(
+        matMulProto,
+        Seq(outputTemp),
+        Seq(inputVars),
+        Seq(("Weights", weights))
+      )
+
+    outputTemp
+  }
+
   private def emitLayerPool(
       context: EmitContext,
       scheduler: Scheduler,
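A note on the bias handling in the new `emitLayerMatMul`: ONNX `MatMul` computes Y = A * B and takes no bias input, whereas `Gemm` computes Y = alpha * A' * B' + beta * C with an optional bias tensor C. That is presumably why the MatMul path passes `None` as the bias argument to `getOrEmitWeightsAndBiasObjects`, while the Gemm path can supply one. A minimal sketch of the two node shapes, with hypothetical tensor names:

    // ONNX MatMul: two inputs (activations, weights), no bias.
    val matMul = NodeProto(
      opType = Some("MatMul"),
      input = Seq("A", "B"),
      output = Seq("Y")
    )

    // ONNX Gemm: the optional third input C acts as the bias.
    val gemm = NodeProto(
      opType = Some("Gemm"),
      input = Seq("A", "B", "C"),
      output = Seq("Y")
    )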