Skip to content

Commit 07e35d6

Browse files
authored
Merge pull request #86 from tensil-ai/peter/sc-479/support-mobilenet
Support MobileNet
2 parents 20b4832 + 533c75d commit 07e35d6

File tree

18 files changed

+524
-95
lines changed

18 files changed

+524
-95
lines changed

common/src/tensil/ArchitectureDataTypeUtil.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,12 @@ object ArchitectureDataTypeUtil {
7272
val csvStream = new DataOutputStream(new FileOutputStream(csvFileName))
7373

7474
for (_ <- 0L until size) {
75-
for (_ <- 0 until arraySize)
76-
csvStream.writeBytes(s"${dataType.readFloatConst(dataStream)},")
75+
for (i <- 0 until arraySize) {
76+
csvStream.writeBytes(s"${dataType.readFloatConst(dataStream)}")
77+
78+
if (i != arraySize - 1)
79+
csvStream.writeBytes(",")
80+
}
7781

7882
csvStream.writeBytes("\r\n")
7983
}

emulator/src/tensil/tools/emulator/Main.scala

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import java.io.{
1515

1616
import scala.reflect.ClassTag
1717
import scala.collection.mutable
18+
import scala.io.Source
1819

1920
import tensil.{
2021
Architecture,
@@ -33,11 +34,16 @@ case class Args(
3334
modelFile: File = new File("."),
3435
inputFiles: Seq[File] = Nil,
3536
outputFiles: Seq[File] = Nil,
37+
compareFiles: Seq[File] = Nil,
3638
numberOfRuns: Int = 1,
3739
localConsts: Boolean = false,
3840
)
3941

4042
object Main extends App {
43+
private val ANSI_RESET = "\u001B[0m"
44+
private val ANSI_RED = "\u001B[31m"
45+
private val ANSI_GREEN = "\u001B[32m"
46+
4147
val argParser = new scopt.OptionParser[Args]("compile") {
4248
help("help").text("Prints this usage text")
4349

@@ -58,6 +64,11 @@ object Main extends App {
5864
.action((x, c) => c.copy(outputFiles = x))
5965
.text("Output (.csv) files")
6066

67+
opt[Seq[File]]('c', "compare")
68+
.valueName("<file>, ...")
69+
.action((x, c) => c.copy(compareFiles = x))
70+
.text("Compare output (.csv) files")
71+
6172
opt[Int]('r', "number-of-runs")
6273
.valueName("<integer>")
6374
.action((x, c) => c.copy(numberOfRuns = x))
@@ -86,6 +97,7 @@ object Main extends App {
8697
model,
8798
args.inputFiles,
8899
args.outputFiles,
100+
args.compareFiles,
89101
args.numberOfRuns,
90102
args.localConsts,
91103
traceContext
@@ -97,6 +109,7 @@ object Main extends App {
97109
model,
98110
args.inputFiles,
99111
args.outputFiles,
112+
args.compareFiles,
100113
args.numberOfRuns,
101114
args.localConsts,
102115
traceContext
@@ -108,6 +121,7 @@ object Main extends App {
108121
model,
109122
args.inputFiles,
110123
args.outputFiles,
124+
args.compareFiles,
111125
args.numberOfRuns,
112126
args.localConsts,
113127
traceContext
@@ -119,6 +133,7 @@ object Main extends App {
119133
model,
120134
args.inputFiles,
121135
args.outputFiles,
136+
args.compareFiles,
122137
args.numberOfRuns,
123138
args.localConsts,
124139
traceContext
@@ -130,6 +145,7 @@ object Main extends App {
130145
model,
131146
args.inputFiles,
132147
args.outputFiles,
148+
args.compareFiles,
133149
args.numberOfRuns,
134150
args.localConsts,
135151
traceContext
@@ -145,6 +161,7 @@ object Main extends App {
145161
model: Model,
146162
inputFiles: Seq[File],
147163
outputFiles: Seq[File],
164+
compareFiles: Seq[File],
148165
numberOfRuns: Int,
149166
localConsts: Boolean,
150167
traceContext: ExecutiveTraceContext
@@ -201,7 +218,22 @@ object Main extends App {
201218
if (outputFiles.size != 0)
202219
for ((output, file) <- model.outputs.zip(outputFiles)) yield {
203220
val outputPrep = new ByteArrayOutputStream()
204-
(output, Some(file), outputPrep, new DataOutputStream(outputPrep))
221+
(
222+
output,
223+
Some((false, file)),
224+
outputPrep,
225+
new DataOutputStream(outputPrep)
226+
)
227+
}
228+
else if (compareFiles.size != 0)
229+
for ((output, file) <- model.outputs.zip(compareFiles)) yield {
230+
val outputPrep = new ByteArrayOutputStream()
231+
(
232+
output,
233+
Some((true, file)),
234+
outputPrep,
235+
new DataOutputStream(outputPrep)
236+
)
205237
}
206238
else
207239
for (output <- model.outputs) yield {
@@ -231,19 +263,79 @@ object Main extends App {
231263
trace.printTrace()
232264
}
233265

234-
for ((output, file, outputPrep, _) <- outputPreps) {
266+
for ((output, compareAndFile, outputPrep, _) <- outputPreps) {
235267
val outputStream = new DataInputStream(
236268
new ByteArrayInputStream(outputPrep.toByteArray())
237269
)
238270

239-
if (file.isDefined)
240-
ArchitectureDataTypeUtil.readToCsv(
241-
dataType,
242-
outputStream,
243-
model.arch.arraySize,
244-
output.size * numberOfRuns,
245-
file.get.getAbsolutePath()
246-
)
271+
if (compareAndFile.isDefined)
272+
if (compareAndFile.get._1) {
273+
val source = Source.fromFile(compareAndFile.get._2.getAbsolutePath())
274+
val compare =
275+
source.getLines().map(_.split(",").map(_.toFloat)).flatten.toArray
276+
source.close()
277+
278+
val result = ArchitectureDataTypeUtil.readResult(
279+
dataType,
280+
outputStream,
281+
model.arch.arraySize,
282+
output.size.toInt * numberOfRuns * model.arch.arraySize
283+
)
284+
285+
require(result.size == compare.size)
286+
287+
for (i <- 0 until numberOfRuns) {
288+
val tb = new TablePrinter(Some(s"COMPARE ${output.name}, RUN ${i}"))
289+
290+
for (j <- 0 until output.size.toInt) {
291+
val offset = (i * output.size.toInt + j) * model.arch.arraySize
292+
val resultVector =
293+
result.slice(offset, offset + model.arch.arraySize)
294+
val compareVector =
295+
compare.slice(offset, offset + model.arch.arraySize)
296+
297+
val withDeltas = resultVector
298+
.zip(compareVector)
299+
.map {
300+
case (resultScalar, compareScalar) =>
301+
(resultScalar, compareScalar, resultScalar - compareScalar)
302+
}
303+
304+
if (withDeltas.exists(t => Math.abs(t._3) > dataType.error)) {
305+
tb
306+
.addLine(
307+
TableLine(
308+
f"${j}%08d",
309+
withDeltas
310+
.grouped(8)
311+
.map(_.map({
312+
case (resultScalar, compareScalar, delta) => {
313+
val mismatch = Math.abs(delta) > dataType.error
314+
val s =
315+
if (mismatch) f"($delta%.4f)"
316+
else f"$resultScalar%.4f"
317+
val sColored = (if (mismatch) ANSI_RED
318+
else ANSI_GREEN) + s + ANSI_RESET
319+
" " * (12 - s.length()) + sColored
320+
}
321+
}).mkString)
322+
.toIterable
323+
)
324+
)
325+
}
326+
}
327+
328+
print(tb.toString())
329+
}
330+
} else {
331+
ArchitectureDataTypeUtil.readToCsv(
332+
dataType,
333+
outputStream,
334+
model.arch.arraySize,
335+
output.size * numberOfRuns,
336+
compareAndFile.get._2.getAbsolutePath()
337+
)
338+
}
247339
else {
248340
val r = ArchitectureDataTypeUtil.readResult(
249341
dataType,

sim/test/src/zynq/tcu/AXIWrapperTCUSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import tensil.tcu.instruction.{
3636
Configure
3737
}
3838
import tensil.mem.MemoryImplementation
39-
import tensil.tools.{ArchitectureDataTypeUtil, ResNet}
39+
import tensil.tools.{ArchitectureDataTypeUtil, Cifar}
4040
import tensil.util.divCeil
4141

4242
class AXIWrapperTCUSpec extends FunUnitSpec {
@@ -328,7 +328,7 @@ class AXIWrapperTCUSpec extends FunUnitSpec {
328328
)
329329

330330
val resultsSize =
331-
divCeil(ResNet.ClassSize, m.layout.arch.arraySize)
331+
divCeil(Cifar.ClassSize, m.layout.arch.arraySize)
332332

333333
for (l <- 0 until batchSize) {
334334

@@ -343,8 +343,8 @@ class AXIWrapperTCUSpec extends FunUnitSpec {
343343

344344
assert(
345345
ArchitectureDataTypeUtil.argMax(
346-
result.flatten.take(ResNet.ClassSize).toArray
347-
) == ResNet.GoldenClasses(k * batchSize + l)
346+
result.flatten.take(Cifar.ClassSize).toArray
347+
) == Cifar.GoldenClasses(k * batchSize + l)
348348
)
349349
}
350350
}

tools/src/tensil/tools/compiler/Frontend.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,9 @@ abstract class Frontend {
99
def traverse(outputNames: Seq[String]): Seq[String]
1010
def rewrite(program: Seq[String]): Seq[Emitter]
1111

12-
def mkConstsDimensions(shape: Shape, transpose: Boolean): MemoryDimensions
12+
def mkConstsDimensions(
13+
shape: Shape,
14+
groupSize: Option[Int],
15+
transpose: Boolean
16+
): MemoryDimensions
1317
}

tools/src/tensil/tools/compiler/HIR.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ trait HIR {
4343
outputObj: MemoryObject
4444
): Unit
4545

46+
def emitClip(
47+
inputObj: MemoryObject,
48+
minObj: MemoryObject,
49+
maxObj: MemoryObject,
50+
outputObj: MemoryObject
51+
): Unit
52+
4653
def emitLeakyRelu(
4754
inputObj: MemoryObject,
4855
alphaObj: MemoryObject,

tools/src/tensil/tools/compiler/MemoryDimensions.scala

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,33 +167,90 @@ class MemoryDimensions private (
167167
dimension
168168
}
169169

170-
def buildConsts(build: (Option[Int]) => Unit): Unit = {
170+
def buildConsts(
171+
sourceShape: Seq[Int],
172+
broadcast: Boolean,
173+
groupSize: Option[Int],
174+
build: (Option[Int]) => Unit
175+
): Unit = {
176+
require(sourceShape.size == order)
177+
171178
def atLayout(i: Int) =
172179
atVectors(layout(i)) * (if (i == order - 1) arraySize else 1)
173180

174181
def modelPos(layoutPos: Int*) =
175182
(0 until layoutPos.size).map(i => layoutPos(layout.indexOf(i)))
176183

184+
/**
185+
* If group size is specified the channel weights will
186+
* be transformed into a square with the side of Ci (Co == Ci).
187+
* The weights will be placed on the diagonal in square blocks
188+
* with the side of Ci / groupSize.
189+
*
190+
* For example, when groupSize == 4 and the channel
191+
* weights in the model are of size 8(Ci)x2(Co):
192+
*
193+
* XX
194+
* XX
195+
* XX
196+
* XX
197+
* XX
198+
* XX
199+
* XX
200+
* XX
201+
*
202+
* This function will transform channel weights into
203+
* 8(Ci)x8(Co) square with zero padding outside of the
204+
* diagonal:
205+
*
206+
* XX------
207+
* XX------
208+
* --XX----
209+
* --XX----
210+
* ----XX--
211+
* ----XX--
212+
* ------XX
213+
* ------XX
214+
*/
215+
216+
val shift = if (groupSize.isDefined) {
217+
require(channelsIn == channelsOut)
218+
Some(channelsIn / groupSize.get)
219+
} else None
220+
221+
def shiftPos(modelPos: Seq[Int]): Seq[Int] =
222+
if (groupSize.isDefined) {
223+
val modelPosArray = modelPos.toArray
224+
modelPosArray(channelsIndex) -= (modelPosArray(
225+
channelsOutIndex
226+
) / shift.get) * shift.get
227+
modelPosArray
228+
} else modelPos
229+
230+
def broadcastPos(modelPos: Seq[Int]): Seq[Int] =
231+
if (broadcast)
232+
modelPos.zip(sourceShape).map { case (pos, shape) => pos % shape }
233+
else modelPos
234+
177235
def modelOffset(modelPos: Seq[Int]): Option[Int] =
178236
if (
179-
dimensions.zipWithIndex
180-
.forall {
181-
case (dim, i) => modelPos(i) < dim
182-
}
237+
modelPos.zip(sourceShape).forall {
238+
case (pos, shape) => pos >= 0 && pos < shape
239+
}
183240
)
184241
Some(
185242
if (order > 0)
186243
modelPos(modelPos.size - 1) +
187244
(if (order > 1)
188245
((if (order > 2)
189246
((if (order > 3)
190-
modelPos(modelPos.size - 4) * dimensions(
247+
modelPos(modelPos.size - 4) * sourceShape(
191248
modelPos.size - 3
192249
)
193-
else 0) + modelPos(modelPos.size - 3)) * dimensions(
250+
else 0) + modelPos(modelPos.size - 3)) * sourceShape(
194251
modelPos.size - 2
195252
)
196-
else 0) + modelPos(modelPos.size - 2)) * dimensions(
253+
else 0) + modelPos(modelPos.size - 2)) * sourceShape(
197254
modelPos.size - 1
198255
)
199256
else 0)
@@ -209,13 +266,19 @@ class MemoryDimensions private (
209266
for (i2 <- 0 until atLayout(2))
210267
if (order > 3)
211268
for (i3 <- 0 until atLayout(3))
212-
build(modelOffset(modelPos(i0, i1, i2, i3)))
269+
build(
270+
modelOffset(
271+
broadcastPos(shiftPos(modelPos(i0, i1, i2, i3)))
272+
)
273+
)
213274
else
214-
build(modelOffset(modelPos(i0, i1, i2)))
275+
build(
276+
modelOffset(broadcastPos(shiftPos(modelPos(i0, i1, i2))))
277+
)
215278
else
216-
build(modelOffset(modelPos(i0, i1)))
279+
build(modelOffset(broadcastPos(shiftPos(modelPos(i0, i1)))))
217280
else
218-
build(modelOffset(modelPos(i0)))
281+
build(modelOffset(broadcastPos(shiftPos(modelPos(i0)))))
219282
}
220283

221284
def numberVectors = atVectors(numberIndex)

0 commit comments

Comments
 (0)