
Commit 20b4832

Merge pull request #85 from tensil-ai/onnx-flatten-and-gemm-transa-transb
Support ONNX Flatten and Gemm with transA, transB attributes
2 parents 35ed258 + 1f5de28

File tree: 4 files changed (+137, -36 lines)

tools/src/tensil/tools/compiler/Frontend.scala

Lines changed: 1 addition & 1 deletion
@@ -9,5 +9,5 @@ abstract class Frontend {
   def traverse(outputNames: Seq[String]): Seq[String]
   def rewrite(program: Seq[String]): Seq[Emitter]
 
-  def mkConstsDimensions(shape: Shape): MemoryDimensions
+  def mkConstsDimensions(shape: Shape, transpose: Boolean): MemoryDimensions
 }

tools/src/tensil/tools/compiler/MemoryManager.scala

Lines changed: 12 additions & 6 deletions
@@ -31,7 +31,7 @@ class MemoryManager(
     constsStream: OutputStream,
     dataType: ArchitectureDataType,
     arch: Architecture,
-    mkConstsDimensions: (Shape) => MemoryDimensions,
+    mkConstsDimensions: (Shape, Boolean) => MemoryDimensions,
     traceContext: TraceContext,
     tracepointConditions: Seq[TracepointCondition]
 ) {
@@ -192,7 +192,8 @@ class MemoryManager(
 
   def getOrEmitWeightsAndBiasObjects(
       weightsName: String,
-      biasName: Option[String]
+      biasName: Option[String],
+      transposeWeights: Boolean = false
   ): (MemoryObject, Option[MemoryObject]) = {
     val biasObject =
       if (biasName.isDefined) {
@@ -217,7 +218,11 @@ class MemoryManager(
           Nil
         )
       else
-        mkConstObject(resolvedWeightsName, constsDataStream)
+        mkConstObject(
+          resolvedWeightsName,
+          constsDataStream,
+          transpose = transposeWeights
+        )
 
     (weightsObject, biasObject)
   }
@@ -230,7 +235,7 @@ class MemoryManager(
       if (freeableAllocator.hasObject(name))
        freeableAllocator.consumeObject(name, Nil)
      else
-        mkConstObject(name, constsDataStream, broadcastDims)
+        mkConstObject(name, constsDataStream, broadcastDims = broadcastDims)
 
     constObject
   }
@@ -256,7 +261,8 @@ class MemoryManager(
   private def mkConstObject(
       name: String,
       stream: DataOutputStream,
-      broadcastDims: Option[MemoryDimensions] = None
+      broadcastDims: Option[MemoryDimensions] = None,
+      transpose: Boolean = false
   ): MemoryObject = {
     val tensorData = pendingFloatConsts(name)
     val tensorSize = tensorData.shape.product
@@ -269,7 +275,7 @@ class MemoryManager(
         throw new CompilerException("Only scalar broadcast is supported")
 
       (broadcastDims.get, true)
-    } else (mkConstsDimensions(tensorData.shape), false)
+    } else (mkConstsDimensions(tensorData.shape, transpose), false)
 
     dims.buildConsts((offset: Option[Int]) =>
       dataType.writeFloatConst(
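
For context, a hedged sketch of the intended call pattern for the new flag (mm is a MemoryManager instance and the tensor names are hypothetical; the signature is the one introduced above):

    // Emitting Gemm weights with ONNX transB != 0: the weights consts are
    // laid out transposed, while bias handling is unchanged.
    val (weightsObject, biasObject) = mm.getOrEmitWeightsAndBiasObjects(
      weightsName = "gemm_weights",   // hypothetical pending-const name
      biasName = Some("gemm_bias"),   // hypothetical
      transposeWeights = true         // from the Gemm transB attribute
    )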

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 117 additions & 27 deletions
@@ -32,6 +32,16 @@ class OnnxFrontend(
     graphStream: Option[OutputStream],
     options: CompilerOptions
 ) extends Frontend {
+  val opsetVersion = modelProto.opsetImport(0).version.get
+
+  val MinOpsetVersion = 9
+  val MaxOpsetVersion = 10
+
+  if (opsetVersion < MinOpsetVersion || opsetVersion > MaxOpsetVersion)
+    throw new CompilerException(
+      s"ONNX opset ${opsetVersion} is not supported. Supported range is [${MinOpsetVersion}, ${MaxOpsetVersion}]."
+    )
+
   private object VarsDimensions {
     def apply(
         number: Int,
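
Note: with this guard a model exported at an unsupported opset, say 11, now fails fast when the frontend is constructed, with the message "ONNX opset 11 is not supported. Supported range is [9, 10]."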
@@ -89,13 +99,29 @@ class OnnxFrontend(
     )
 
     def apply(height: Int, width: Int): MemoryDimensions =
-      MemoryDimensions(
-        arraySize = arch.arraySize,
-        "HW",
-        "HW",
-        isWeights = true,
-        dimensions = Vector(height, width)
-      )
+      apply(height, width, false)
+
+    def apply(
+        height: Int,
+        width: Int,
+        transpose: Boolean
+    ): MemoryDimensions =
+      if (transpose)
+        MemoryDimensions(
+          arraySize = arch.arraySize,
+          "WH",
+          "HW",
+          isWeights = true,
+          dimensions = Vector(height, width)
+        )
+      else
+        MemoryDimensions(
+          arraySize = arch.arraySize,
+          "HW",
+          "HW",
+          isWeights = true,
+          dimensions = Vector(height, width)
+        )
 
     def apply(
         channelsOut: Int,
@@ -111,21 +137,27 @@ class OnnxFrontend(
         dimensions = Vector(channelsOut, channelsIn, height, width)
       )
 
-    def apply(shape: Shape): MemoryDimensions =
-      if (shape.size == 1)
-        ConstsDimensions(shape(0))
-      else if (shape.size == 2)
-        ConstsDimensions(shape(0), shape(1))
-      else if (shape.size == 4)
-        ConstsDimensions(shape(0), shape(1), shape(2), shape(3))
-      else
+    def apply(shape: Shape, transpose: Boolean): MemoryDimensions =
+      if (transpose && shape.size != 2)
         throw new CompilerException(
-          s"Consts tensor shape of ${shape} is not supported"
+          s"Transposing consts is supported for 2D tensors only"
         )
+      else {
+        if (shape.size == 1)
+          ConstsDimensions(shape(0))
+        else if (shape.size == 2)
+          ConstsDimensions(shape(0), shape(1), transpose)
+        else if (shape.size == 4)
+          ConstsDimensions(shape(0), shape(1), shape(2), shape(3))
+        else
+          throw new CompilerException(
+            s"Consts tensor shape of ${shape} is not supported"
+          )
+      }
   }
 
-  def mkConstsDimensions(shape: Shape): MemoryDimensions =
-    ConstsDimensions(shape)
+  def mkConstsDimensions(shape: Shape, transpose: Boolean): MemoryDimensions =
+    ConstsDimensions(shape, transpose)
 
   private val nodeProtos = mutable.Map.empty[String, NodeProto]
   private val tensorProtos = mutable.Map.empty[String, TensorProto]
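
Design note (my reading of the diff): transposition is realized purely as a consts layout choice. For 2D weights the MemoryDimensions layout descriptors change from ("HW", "HW") to ("WH", "HW"), so the tensor's axes are swapped as the consts are written out and no runtime transpose instructions are needed. That would also explain why ConstsDimensions.apply(shape, transpose) above restricts transposition to 2D tensors.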
@@ -346,6 +378,8 @@ class OnnxFrontend(
           rewriteLayer(remainingProtos, nodeProto, emitters)
         case "Reshape" =>
           rewriteSimple(remainingProtos, emitReshape(_, nodeProto), emitters)
+        case "Flatten" =>
+          rewriteSimple(remainingProtos, emitFlatten(_, nodeProto), emitters)
         case "Split" =>
           rewriteSimple(remainingProtos, emitSplit(_, nodeProto), emitters)
         case "Concat" =>
@@ -1007,17 +1041,54 @@ class OnnxFrontend(
       context: EmitContext,
       reshapeProto: NodeProto
   ): Unit = {
-    val inputVars =
-      context.mm
-        .consumeObject(reshapeProto.input(0), Seq(reshapeProto.name.get))
-    val inputDims = inputVars.dims
-
     val shape = getTensorData(tensorProtos(reshapeProto.input(1)))
       .asInstanceOf[TensorData[Long]]
       .as1D
       .map(_.toInt)
       .toArray
 
+    val inputVars =
+      context.mm
+        .consumeObject(reshapeProto.input(0), Seq(reshapeProto.name.get))
+
+    doEmitReshape(context, reshapeProto, inputVars, shape)
+  }
+
+  private def emitFlatten(
+      context: EmitContext,
+      flattenProto: NodeProto
+  ): Unit = {
+    val axisAttr = getAttr(flattenProto, "axis").get
+
+    require(axisAttr.`type`.get.isInt)
+
+    val axis = axisAttr.i.get.toInt
+
+    val inputVars =
+      context.mm
+        .consumeObject(flattenProto.input(0), Seq(flattenProto.name.get))
+
+    val shape =
+      if (axis == 0) Array(1, inputVars.dims.modelDimensions.product)
+      else
+        Array(
+          inputVars.dims.modelDimensions.slice(0, axis).product,
+          inputVars.dims.modelDimensions
+            .slice(axis, inputVars.dims.order)
+            .product
+        )
+
+    doEmitReshape(context, flattenProto, inputVars, shape)
+  }
+
+  private def doEmitReshape(
+      context: EmitContext,
+      nodeProto: NodeProto,
+      inputVars: MemoryObject,
+      shape: Array[Int]
+  ): Unit = {
+    val inputDims = inputVars.dims
+
     var pixelDims = VarsDimensions(1, arch.arraySize)
     val outputDims = VarsDimensions(Shape(if (shape.exists(_ == -1)) {
       val inferred = inputDims.sizeScalars / shape.filter(_ != -1).product
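
The Flatten shape rule above is easy to sanity-check in isolation. A minimal sketch with assumed dimensions (modelDimensions stands in for inputVars.dims.modelDimensions and order for inputVars.dims.order):

    // ONNX Flatten: dims before `axis` collapse into the first output dim,
    // the remaining dims into the second.
    val modelDimensions = Seq(2, 3, 4, 5) // assumed input shape
    val order           = modelDimensions.size
    val axis            = 2

    val shape =
      if (axis == 0) Array(1, modelDimensions.product)
      else
        Array(
          modelDimensions.slice(0, axis).product,    // 2 * 3 = 6
          modelDimensions.slice(axis, order).product // 4 * 5 = 20
        )
    // shape: Array(6, 20), matching Flatten of (2, 3, 4, 5) at axis = 2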
@@ -1082,7 +1153,7 @@ class OnnxFrontend(
     val outputNames =
       if (groupedByOffsetPairs.size > 1) {
         val adjustedOutputTemp = context.mm.allocateTempObject(
-          reshapeProto.output(0),
+          nodeProto.output(0),
           outputDims
         )
 
@@ -1139,16 +1210,16 @@ class OnnxFrontend(
         Seq(inputVars.name)
 
     val outputVars = context.mm.blendObjects(
-      reshapeProto.output(0),
+      nodeProto.output(0),
       outputDims,
-      findInterLayerOutputs(context, reshapeProto.output(0), None),
+      findInterLayerOutputs(context, nodeProto.output(0), None),
       outputNames,
       outputAddresses
     )
 
     if (context.graphPrinter.isDefined)
       context.graphPrinter.get.printOp(
-        reshapeProto,
+        nodeProto,
         Seq(outputVars),
         Seq(inputVars)
       )
@@ -2111,6 +2182,24 @@ class OnnxFrontend(
       context: EmitContext,
       matMulProto: NodeProto
   ): MemoryObject = {
+    val transAAttr = getAttr(matMulProto, "transA")
+    val transBAttr = getAttr(matMulProto, "transB")
+
+    val transA = if (transAAttr.isDefined) {
+      require(transAAttr.get.`type`.get.isInt)
+      transAAttr.get.i.get.toInt
+    } else 0
+
+    val transB = if (transBAttr.isDefined) {
+      require(transBAttr.get.`type`.get.isInt)
+      transBAttr.get.i.get.toInt
+    } else 0
+
+    if (transA != 0)
+      throw new CompilerException(
+        s"Gemm with transposed input A is not supported"
+      )
+
     context.mm.addPendingConst(
       matMulProto.input(1),
       getTensorData(tensorProtos(matMulProto.input(1)))
@@ -2127,6 +2216,7 @@ class OnnxFrontend(
       matMulProto.input(1),
       if (matMulProto.input.isDefinedAt(2)) Some(matMulProto.input(2))
       else None,
+      transB != 0
     )
 
     val inputVars =
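
For reference, ONNX Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op transposes its argument when the corresponding transA/transB attribute is 1. The code above supports transB by transposing the weight consts at compile time and rejects transA outright. A plain-Scala sketch of the transB semantics (alpha, beta and C omitted; an illustration, not the compiler's data path):

    // Transpose a 2D row-major matrix.
    def transposed(m: Array[Array[Float]]): Array[Array[Float]] =
      m.head.indices.map(j => m.map(_(j))).toArray

    // Y = A * op(B), with op(B) = B.T when transB is set.
    def gemm(
        a: Array[Array[Float]],
        b: Array[Array[Float]],
        transB: Boolean
    ): Array[Array[Float]] = {
      val bOp = if (transB) transposed(b) else b
      a.map(row =>
        bOp.head.indices
          .map(j => row.indices.map(k => row(k) * bOp(k)(j)).sum)
          .toArray
      )
    }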

tools/src/tensil/tools/compiler/TfFrontend.scala

Lines changed: 7 additions & 2 deletions
@@ -124,8 +124,13 @@ class TfFrontend(
     )
   }
 
-  def mkConstsDimensions(shape: Shape): MemoryDimensions =
-    ConstsDimensions(shape)
+  def mkConstsDimensions(shape: Shape, transpose: Boolean): MemoryDimensions =
+    if (transpose)
+      throw new CompilerException(
+        s"Transposing consts is not supported"
+      )
+    else
+      ConstsDimensions(shape)
 
   private val nodeDefs = mutable.Map.empty[String, NodeDef]
   private val nodeEdges = mutable.Map.empty[String, Seq[String]]
