Skip to content

Commit 66cdb49

Browse files
committed
Clarify HIR
1 parent 94aad03 commit 66cdb49

File tree

4 files changed

+119
-25
lines changed

4 files changed

+119
-25
lines changed

tools/src/tensil/tools/compiler/HIR.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ trait HIR {
2121
outputObj: MemoryObject
2222
): Unit
2323

24-
def emitSIMDOp(
25-
simdOp: Int,
24+
def emitSub(
25+
input0Obj: MemoryObject,
26+
input1Obj: MemoryObject,
27+
outputObj: MemoryObject
28+
): Unit
29+
30+
def emitMul(
2631
input0Obj: MemoryObject,
2732
input1Obj: MemoryObject,
2833
outputObj: MemoryObject

tools/src/tensil/tools/compiler/OnnxFrontend.scala

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,10 @@ class OnnxFrontend(
258258
rewriteSimple(remainingProtos, emitActivate(_, nodeProto), emitters)
259259
case "Add" =>
260260
rewriteSimple(remainingProtos, emitAdd(_, nodeProto), emitters)
261-
case "Sub" | "Mul" =>
262-
rewriteSimple(remainingProtos, emitSIMDOp(_, nodeProto), emitters)
261+
case "Sub" =>
262+
rewriteSimple(remainingProtos, emitSub(_, nodeProto), emitters)
263+
case "Mul" =>
264+
rewriteSimple(remainingProtos, emitMul(_, nodeProto), emitters)
263265
case "Transpose" =>
264266
rewriteSimple(
265267
remainingProtos,
@@ -1835,7 +1837,7 @@ class OnnxFrontend(
18351837
finishLayer(scheduler, context)
18361838
}
18371839

1838-
private def emitSIMDOp(
1840+
private def emitSub(
18391841
context: EmitContext,
18401842
nodeProto: NodeProto
18411843
): EmitResult = {
@@ -1869,7 +1871,54 @@ class OnnxFrontend(
18691871
scheduler.emitLoad(input1VarsOrConst, input1Temp)
18701872

18711873
val outputTemp =
1872-
emitLayerSIMDOp(context, scheduler, nodeProto, input0Temp, input1Temp)
1874+
emitLayerSub(context, scheduler, nodeProto, input0Temp, input1Temp)
1875+
1876+
val outputVars = context.mm.allocateVarsObject(
1877+
outputTemp.name,
1878+
outputTemp.dims,
1879+
findInterLayerOutputs(context, nodeProto.output(0), None)
1880+
)
1881+
1882+
scheduler.emitSave(outputTemp, outputVars)
1883+
1884+
finishLayer(scheduler, context)
1885+
}
1886+
1887+
private def emitMul(
1888+
context: EmitContext,
1889+
nodeProto: NodeProto
1890+
): EmitResult = {
1891+
val scheduler = startLayer(Seq(nodeProto))
1892+
1893+
val input0Vars =
1894+
context.mm.consumeObject(nodeProto.input(0), Seq(nodeProto.name.get))
1895+
1896+
val input0Temp = context.mm.allocateTempObject(
1897+
input0Vars.name,
1898+
input0Vars.dims
1899+
)
1900+
1901+
scheduler.emitLoad(input0Vars, input0Temp)
1902+
1903+
val input1VarsOrConst = if (tensorProtos.isDefinedAt(nodeProto.input(1))) {
1904+
context.mm.addPendingConst(
1905+
nodeProto.input(1),
1906+
getTensorData(tensorProtos(nodeProto.input(1)))
1907+
)
1908+
1909+
context.mm.getOrEmitConstObject(nodeProto.input(1), Some(input0Temp.dims))
1910+
} else
1911+
context.mm.consumeObject(nodeProto.input(1), Seq(nodeProto.name.get))
1912+
1913+
val input1Temp = context.mm.allocateTempObject(
1914+
input1VarsOrConst.name,
1915+
input1VarsOrConst.dims
1916+
)
1917+
1918+
scheduler.emitLoad(input1VarsOrConst, input1Temp)
1919+
1920+
val outputTemp =
1921+
emitLayerMul(context, scheduler, nodeProto, input0Temp, input1Temp)
18731922

18741923
val outputVars = context.mm.allocateVarsObject(
18751924
outputTemp.name,
@@ -2574,7 +2623,7 @@ class OnnxFrontend(
25742623
outputTemp
25752624
}
25762625

2577-
private def emitLayerSIMDOp(
2626+
private def emitLayerSub(
25782627
context: EmitContext,
25792628
scheduler: Scheduler,
25802629
nodeProto: NodeProto,
@@ -2586,13 +2635,35 @@ class OnnxFrontend(
25862635
input0Temp.dims
25872636
)
25882637

2589-
val op = nodeProto.opType.get match {
2590-
case "Sub" => SIMDOp.Subtract
2591-
case "Mul" => SIMDOp.Multiply
2592-
}
2638+
scheduler.emitSub(
2639+
input0Temp,
2640+
input1Temp,
2641+
outputTemp
2642+
)
2643+
2644+
if (graphPrinter.isDefined)
2645+
graphPrinter.get.printOp(
2646+
nodeProto,
2647+
Seq(outputTemp),
2648+
Seq(input0Temp, input1Temp)
2649+
)
2650+
2651+
outputTemp
2652+
}
2653+
2654+
private def emitLayerMul(
2655+
context: EmitContext,
2656+
scheduler: Scheduler,
2657+
nodeProto: NodeProto,
2658+
input0Temp: MemoryObject,
2659+
input1Temp: MemoryObject,
2660+
): MemoryObject = {
2661+
val outputTemp = context.mm.allocateTempObject(
2662+
nodeProto.output(0),
2663+
input0Temp.dims
2664+
)
25932665

2594-
scheduler.emitSIMDOp(
2595-
op,
2666+
scheduler.emitMul(
25962667
input0Temp,
25972668
input1Temp,
25982669
outputTemp

tools/src/tensil/tools/compiler/Scheduler.scala

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,7 @@ class Scheduler(
192192
}
193193
}
194194

195-
def emitSIMDOp(
196-
op: Int,
195+
def emitSub(
197196
input0Obj: MemoryObject,
198197
input1Obj: MemoryObject,
199198
outputObj: MemoryObject
@@ -203,8 +202,27 @@ class Scheduler(
203202
for (i <- 0 until outputObj.dims.sizeVectors) {
204203
val output = outputObj.mkAddress(i)
205204
require(!tempOutputNodes.contains(output))
206-
tempOutputNodes(output) = new SIMDNode(
207-
op,
205+
tempOutputNodes(output) = new BinarySIMDNode(
206+
SIMDOp.Subtract,
207+
input0Obj.mkAddress(i),
208+
input1Obj.mkAddress(i),
209+
output
210+
)
211+
}
212+
}
213+
214+
def emitMul(
215+
input0Obj: MemoryObject,
216+
input1Obj: MemoryObject,
217+
outputObj: MemoryObject
218+
): Unit = {
219+
require(input0Obj.dims.sizeVectors == outputObj.dims.sizeVectors)
220+
require(input1Obj.dims.sizeVectors == outputObj.dims.sizeVectors)
221+
for (i <- 0 until outputObj.dims.sizeVectors) {
222+
val output = outputObj.mkAddress(i)
223+
require(!tempOutputNodes.contains(output))
224+
tempOutputNodes(output) = new BinarySIMDNode(
225+
SIMDOp.Multiply,
208226
input0Obj.mkAddress(i),
209227
input1Obj.mkAddress(i),
210228
output
@@ -1279,15 +1297,15 @@ class Scheduler(
12791297
addRollup.finalEmit()
12801298

12811299
for (
1282-
subNode <-
1300+
binarySIMDNode <-
12831301
nodes
1284-
.filter(_.isInstanceOf[SIMDNode])
1285-
.map(_.asInstanceOf[SIMDNode])
1302+
.filter(_.isInstanceOf[BinarySIMDNode])
1303+
.map(_.asInstanceOf[BinarySIMDNode])
12861304
.sortBy(_.output)
12871305
) {
1288-
val outputAccAddress = allocateAccumulator(subNode.output)
1289-
val input0AccAddress = locateAccumulator(subNode.input0)
1290-
val input1AccAddress = locateAccumulator(subNode.input1)
1306+
val outputAccAddress = allocateAccumulator(binarySIMDNode.output)
1307+
val input0AccAddress = locateAccumulator(binarySIMDNode.input0)
1308+
val input1AccAddress = locateAccumulator(binarySIMDNode.input1)
12911309

12921310
computeLir.emitSIMD(
12931311
accumulate = false,
@@ -1301,7 +1319,7 @@ class Scheduler(
13011319

13021320
computeLir.emitSIMD(
13031321
accumulate = false,
1304-
subNode.op,
1322+
binarySIMDNode.op,
13051323
SIMDSource.Register1,
13061324
SIMDSource.Input,
13071325
SIMDDestination.Output,

tools/src/tensil/tools/compiler/scheduler/Node.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class AddNode(
9494
if (input1.tag == MemoryTag.Consts) Seq(input1) else Nil
9595
}
9696

97-
class SIMDNode(
97+
class BinarySIMDNode(
9898
val op: Int,
9999
val input0: MemoryAddress,
100100
val input1: MemoryAddress,

0 commit comments

Comments
 (0)