Commit d49f0d4

Merge pull request #51 from tensil-ai/peter/sc-448/support-matmul-in-onnx-frontend

Support MatMul in ONNX frontend

2 parents: 8f17bf2 + 1bb6fd7

File tree

1 file changed: +65 −5 lines changed

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 65 additions & 5 deletions
@@ -9,7 +9,11 @@ import scala.collection.mutable
 
 import onnx.onnx.{NodeProto, ModelProto, TensorProto, ValueInfoProto}
 
-import _root_.tensil.tools.{CompilerException, TracepointCondition, CompilerOptions}
+import _root_.tensil.tools.{
+  CompilerException,
+  TracepointCondition,
+  CompilerOptions
+}
 import _root_.tensil.tools.data.{Shape, TensorData}
 import _root_.tensil.tools.util
 import _root_.tensil.{TablePrinter, Architecture}
@@ -231,7 +235,7 @@ class OnnxFrontend(
       case Nil => emitters
       case nodeProto :: remainingProtos =>
         nodeProto.opType.get match {
-          case "Gemm" | "Conv" =>
+          case "MatMul" | "Gemm" | "Conv" =>
             rewriteLayer(remainingProtos, nodeProto, emitters)
           case "Reshape" =>
             rewriteSimple(remainingProtos, emitReshape(_, nodeProto), emitters)
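For context, here is a minimal sketch of the node shape this dispatch sees, using the ScalaPB-generated `onnx.onnx.NodeProto` bindings imported above. The node and tensor names are invented for illustration:

```scala
import onnx.onnx.NodeProto

// A hypothetical ONNX MatMul node as the frontend receives it.
// In the proto2-derived bindings, `name` and `opType` are Option[String];
// `input`/`output` hold tensor names from the ONNX graph.
val matMulNode = NodeProto(
  name = Some("matmul_0"),       // hypothetical node name
  opType = Some("MatMul"),
  input = Seq("x", "weights_0"), // input(0) = activations, input(1) = weights
  output = Seq("y")
)

// With this change the node falls into the layer-rewriting branch:
matMulNode.opType.get match {
  case "MatMul" | "Gemm" | "Conv" => // rewriteLayer(...)
  case _                          => // other rewrite paths
}
```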
@@ -312,7 +316,7 @@ class OnnxFrontend(
    * This function takes `layerStepOps`, which describes
    * the pattern to which we expect nodes to adhere in order
    * to form a layer. The initial and the only required node is
-   * matched in `recursiveRewrite` to be either `Gemm` or
+   * matched in `recursiveRewrite` to be either `MatMul`, `Gemm` or
    * `Conv`. This node is followed by "layer step operations"
    * where each step can optionally be one of the operations
    * included in the set.
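To make the pattern concrete, here is a hypothetical two-node sequence, a required `MatMul` followed by an activation, that this scheme could fold into a single layer. Whether a given op is in the step set depends on `layerStepOps`, which is not shown in this diff, so treat the `Relu` step as an assumption:

```scala
import onnx.onnx.NodeProto

// Hypothetical layer: the required initial node (MatMul) followed by
// one optional "layer step operation" (assumed here to include Relu).
val layerNodes = Seq(
  NodeProto(
    name = Some("matmul_0"),
    opType = Some("MatMul"),
    input = Seq("x", "weights_0"),
    output = Seq("matmul_0_out")
  ),
  NodeProto(
    name = Some("relu_0"),
    opType = Some("Relu"),
    input = Seq("matmul_0_out"), // consumes the MatMul output
    output = Seq("y")
  )
)
```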
@@ -383,7 +387,8 @@ class OnnxFrontend(
   private var layerIndex = 0
 
   private def startLayer(nodeProtos: Seq[NodeProto]): Scheduler = {
-    val name = s"LAYER $layerIndex (${nodeProtos.map(_.name.get).mkString(",")})"
+    val name =
+      s"LAYER $layerIndex (${nodeProtos.map(_.name.get).mkString(",")})"
 
     layerIndex += 1
 
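As a quick illustration of the reformatted interpolation, the snippet below reproduces just the naming logic outside the class, with hypothetical node names and the counter fixed at 0:

```scala
import onnx.onnx.NodeProto

val layerIndex = 0 // stands in for the private counter
val nodeProtos = Seq(
  NodeProto(name = Some("matmul_0")),
  NodeProto(name = Some("relu_0"))
)
val name =
  s"LAYER $layerIndex (${nodeProtos.map(_.name.get).mkString(",")})"
// name == "LAYER 0 (matmul_0,relu_0)"
```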
@@ -441,7 +446,13 @@ class OnnxFrontend(
       )
 
       val matMulTemp =
-        if (nodeProto.opType.get == "Gemm")
+        if (nodeProto.opType.get == "MatMul")
+          emitLayerMatMul(
+            context,
+            scheduler,
+            nodeProto
+          )
+        else if (nodeProto.opType.get == "Gemm")
           emitLayerGemm(
             context,
             scheduler,
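Background on why `MatMul` gets its own emit path rather than reusing `emitLayerGemm` (an inference from the ONNX operator definitions, not stated in the diff): `Gemm` computes alpha·A·B + beta·C and may carry a third bias input C plus attributes, while `MatMul` always has exactly two inputs and no bias. A sketch with invented tensor names:

```scala
import onnx.onnx.NodeProto

// ONNX Gemm: A, B, and an optional bias C, plus alpha/beta/trans attributes.
val gemmNode = NodeProto(
  opType = Some("Gemm"),
  input = Seq("x", "weights_0", "bias_0")
)

// ONNX MatMul: exactly two inputs, never a bias.
val matMulNode = NodeProto(
  opType = Some("MatMul"),
  input = Seq("x", "weights_0")
)
```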
@@ -2042,6 +2053,55 @@ class OnnxFrontend(
     outputTemp
   }
 
+  private def emitLayerMatMul(
+      context: EmitContext,
+      scheduler: Scheduler,
+      matMulProto: NodeProto
+  ): MemoryObject = {
+    context.mm.addPendingConst(
+      matMulProto.input(1),
+      getTensorData(tensorProtos(matMulProto.input(1)))
+    )
+
+    val (weights, bias) =
+      context.mm.getOrEmitWeightsAndBiasObjects(
+        matMulProto.input(1),
+        None
+      )
+
+    val inputVars =
+      context.mm.consumeObject(matMulProto.input(0), Seq(matMulProto.name.get))
+
+    val outputTemp =
+      context.mm.allocateTempObject(
+        matMulProto.output(0),
+        VarsDimensions(
+          inputVars.dims.number,
+          weights.dims.width
+        )
+      )
+
+    val pairs = Seq(
+      MemoryOptionalInputOutputObjects(Some(inputVars), outputTemp)
+    )
+
+    scheduler.emitMatMul(
+      weights,
+      bias,
+      pairs
+    )
+
+    if (graphPrinter.isDefined)
+      graphPrinter.get.printOp(
+        matMulProto,
+        Seq(outputTemp),
+        Seq(inputVars),
+        Seq(("Weights", weights))
+      )
+
+    outputTemp
+  }
+
   private def emitLayerPool(
       context: EmitContext,
       scheduler: Scheduler,
