@@ -15,6 +15,7 @@ import java.io.{
15
15
16
16
import scala .reflect .ClassTag
17
17
import scala .collection .mutable
18
+ import scala .io .Source
18
19
19
20
import tensil .{
20
21
Architecture ,
@@ -33,11 +34,16 @@ case class Args(
33
34
modelFile : File = new File (" ." ),
34
35
inputFiles : Seq [File ] = Nil ,
35
36
outputFiles : Seq [File ] = Nil ,
37
+ compareFiles : Seq [File ] = Nil ,
36
38
numberOfRuns : Int = 1 ,
37
39
localConsts : Boolean = false ,
38
40
)
39
41
40
42
object Main extends App {
43
+ private val ANSI_RESET = " \u001B [0m"
44
+ private val ANSI_RED = " \u001B [31m"
45
+ private val ANSI_GREEN = " \u001B [32m"
46
+
41
47
val argParser = new scopt.OptionParser [Args ](" compile" ) {
42
48
help(" help" ).text(" Prints this usage text" )
43
49
@@ -58,6 +64,11 @@ object Main extends App {
58
64
.action((x, c) => c.copy(outputFiles = x))
59
65
.text(" Output (.csv) files" )
60
66
67
+ opt[Seq [File ]]('c' , " compare" )
68
+ .valueName(" <file>, ..." )
69
+ .action((x, c) => c.copy(compareFiles = x))
70
+ .text(" Compare output (.csv) files" )
71
+
61
72
opt[Int ]('r' , " number-of-runs" )
62
73
.valueName(" <integer>" )
63
74
.action((x, c) => c.copy(numberOfRuns = x))
@@ -86,6 +97,7 @@ object Main extends App {
86
97
model,
87
98
args.inputFiles,
88
99
args.outputFiles,
100
+ args.compareFiles,
89
101
args.numberOfRuns,
90
102
args.localConsts,
91
103
traceContext
@@ -97,6 +109,7 @@ object Main extends App {
97
109
model,
98
110
args.inputFiles,
99
111
args.outputFiles,
112
+ args.compareFiles,
100
113
args.numberOfRuns,
101
114
args.localConsts,
102
115
traceContext
@@ -108,6 +121,7 @@ object Main extends App {
108
121
model,
109
122
args.inputFiles,
110
123
args.outputFiles,
124
+ args.compareFiles,
111
125
args.numberOfRuns,
112
126
args.localConsts,
113
127
traceContext
@@ -119,6 +133,7 @@ object Main extends App {
119
133
model,
120
134
args.inputFiles,
121
135
args.outputFiles,
136
+ args.compareFiles,
122
137
args.numberOfRuns,
123
138
args.localConsts,
124
139
traceContext
@@ -130,6 +145,7 @@ object Main extends App {
130
145
model,
131
146
args.inputFiles,
132
147
args.outputFiles,
148
+ args.compareFiles,
133
149
args.numberOfRuns,
134
150
args.localConsts,
135
151
traceContext
@@ -145,6 +161,7 @@ object Main extends App {
145
161
model : Model ,
146
162
inputFiles : Seq [File ],
147
163
outputFiles : Seq [File ],
164
+ compareFiles : Seq [File ],
148
165
numberOfRuns : Int ,
149
166
localConsts : Boolean ,
150
167
traceContext : ExecutiveTraceContext
@@ -201,7 +218,22 @@ object Main extends App {
201
218
if (outputFiles.size != 0 )
202
219
for ((output, file) <- model.outputs.zip(outputFiles)) yield {
203
220
val outputPrep = new ByteArrayOutputStream ()
204
- (output, Some (file), outputPrep, new DataOutputStream (outputPrep))
221
+ (
222
+ output,
223
+ Some ((false , file)),
224
+ outputPrep,
225
+ new DataOutputStream (outputPrep)
226
+ )
227
+ }
228
+ else if (compareFiles.size != 0 )
229
+ for ((output, file) <- model.outputs.zip(compareFiles)) yield {
230
+ val outputPrep = new ByteArrayOutputStream ()
231
+ (
232
+ output,
233
+ Some ((true , file)),
234
+ outputPrep,
235
+ new DataOutputStream (outputPrep)
236
+ )
205
237
}
206
238
else
207
239
for (output <- model.outputs) yield {
@@ -231,19 +263,79 @@ object Main extends App {
231
263
trace.printTrace()
232
264
}
233
265
234
- for ((output, file , outputPrep, _) <- outputPreps) {
266
+ for ((output, compareAndFile , outputPrep, _) <- outputPreps) {
235
267
val outputStream = new DataInputStream (
236
268
new ByteArrayInputStream (outputPrep.toByteArray())
237
269
)
238
270
239
- if (file.isDefined)
240
- ArchitectureDataTypeUtil .readToCsv(
241
- dataType,
242
- outputStream,
243
- model.arch.arraySize,
244
- output.size * numberOfRuns,
245
- file.get.getAbsolutePath()
246
- )
271
+ if (compareAndFile.isDefined)
272
+ if (compareAndFile.get._1) {
273
+ val source = Source .fromFile(compareAndFile.get._2.getAbsolutePath())
274
+ val compare =
275
+ source.getLines().map(_.split(" ," ).map(_.toFloat)).flatten.toArray
276
+ source.close()
277
+
278
+ val result = ArchitectureDataTypeUtil .readResult(
279
+ dataType,
280
+ outputStream,
281
+ model.arch.arraySize,
282
+ output.size.toInt * numberOfRuns * model.arch.arraySize
283
+ )
284
+
285
+ require(result.size == compare.size)
286
+
287
+ for (i <- 0 until numberOfRuns) {
288
+ val tb = new TablePrinter (Some (s " COMPARE ${output.name}, RUN ${i}" ))
289
+
290
+ for (j <- 0 until output.size.toInt) {
291
+ val offset = (i * output.size.toInt + j) * model.arch.arraySize
292
+ val resultVector =
293
+ result.slice(offset, offset + model.arch.arraySize)
294
+ val compareVector =
295
+ compare.slice(offset, offset + model.arch.arraySize)
296
+
297
+ val withDeltas = resultVector
298
+ .zip(compareVector)
299
+ .map {
300
+ case (resultScalar, compareScalar) =>
301
+ (resultScalar, compareScalar, resultScalar - compareScalar)
302
+ }
303
+
304
+ if (withDeltas.exists(t => Math .abs(t._3) > dataType.error)) {
305
+ tb
306
+ .addLine(
307
+ TableLine (
308
+ f " ${j}%08d " ,
309
+ withDeltas
310
+ .grouped(8 )
311
+ .map(_.map({
312
+ case (resultScalar, compareScalar, delta) => {
313
+ val mismatch = Math .abs(delta) > dataType.error
314
+ val s =
315
+ if (mismatch) f " ( $delta%.4f) "
316
+ else f " $resultScalar%.4f "
317
+ val sColored = (if (mismatch) ANSI_RED
318
+ else ANSI_GREEN ) + s + ANSI_RESET
319
+ " " * (12 - s.length()) + sColored
320
+ }
321
+ }).mkString)
322
+ .toIterable
323
+ )
324
+ )
325
+ }
326
+ }
327
+
328
+ print(tb.toString())
329
+ }
330
+ } else {
331
+ ArchitectureDataTypeUtil .readToCsv(
332
+ dataType,
333
+ outputStream,
334
+ model.arch.arraySize,
335
+ output.size * numberOfRuns,
336
+ compareAndFile.get._2.getAbsolutePath()
337
+ )
338
+ }
247
339
else {
248
340
val r = ArchitectureDataTypeUtil .readResult(
249
341
dataType,
0 commit comments