Skip to content

Commit 2d13a23

Browse files
committed
Test with speech commands model
1 parent 66cdb49 commit 2d13a23

File tree

3 files changed

+141
-3
lines changed

3 files changed

+141
-3
lines changed

tools/test/src/tools/CompilerSpec.scala

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2536,4 +2536,78 @@ class CompilerSpec extends FlatSpec {
25362536
traceContext = traceContext
25372537
)
25382538
}
2539+
2540+
val SpeechCommandsFp16bp8Architecture = Architecture.mkWithDefaults(
2541+
dataType = ArchitectureDataType.FP16BP8,
2542+
arraySize = 8,
2543+
accumulatorDepth = Kibi * 2,
2544+
localDepth = Kibi * 8,
2545+
stride0Depth = 8,
2546+
stride1Depth = 8,
2547+
)
2548+
2549+
val SpeechCommandsFloatArchitecture = Architecture.mkWithDefaults(
2550+
dataType = ArchitectureDataType.FLOAT32,
2551+
arraySize = 8,
2552+
accumulatorDepth = Kibi * 2,
2553+
localDepth = Kibi * 8,
2554+
stride0Depth = 8,
2555+
stride1Depth = 8,
2556+
)
2557+
2558+
it should "Compile ONNX fixed16bp8 SpeechCommands" in {
2559+
val name = "speech_commands_fixed16bp8_onnx"
2560+
val traceContext = new ExecutiveTraceContext()
2561+
val options = CompilerOptions(
2562+
arch = SpeechCommandsFp16bp8Architecture,
2563+
printSummary = true,
2564+
printLayersSummary = true,
2565+
printGraphFileName = Some(s"${name}.dot"),
2566+
tracepointConditions = List(
2567+
TracepointCondition(MemoryTag.Vars, "dense_3")
2568+
)
2569+
)
2570+
2571+
Compiler.compile(
2572+
name,
2573+
s"${Models}/speech_commands.onnx",
2574+
List("dense_3"),
2575+
options,
2576+
traceContext
2577+
)
2578+
2579+
EmulatorHelper.test(
2580+
name,
2581+
inputBatchSize = options.inputShapes.batchSize,
2582+
traceContext = traceContext
2583+
)
2584+
}
2585+
2586+
it should "Compile ONNX float SpeechCommands" taggedAs (Slow) in {
2587+
val name = "speech_commands_fixed16bp8_onnx"
2588+
val traceContext = new ExecutiveTraceContext()
2589+
val options = CompilerOptions(
2590+
arch = SpeechCommandsFloatArchitecture,
2591+
printSummary = true,
2592+
printLayersSummary = true,
2593+
printGraphFileName = Some(s"${name}.dot"),
2594+
tracepointConditions = List(
2595+
TracepointCondition(MemoryTag.Vars, "dense_3")
2596+
)
2597+
)
2598+
2599+
Compiler.compile(
2600+
name,
2601+
s"${Models}/speech_commands.onnx",
2602+
List("dense_3"),
2603+
options,
2604+
traceContext
2605+
)
2606+
2607+
EmulatorHelper.test(
2608+
name,
2609+
inputBatchSize = options.inputShapes.batchSize,
2610+
traceContext = traceContext
2611+
)
2612+
}
25392613
}

tools/test/src/tools/EmulatorHelper.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ object EmulatorHelper {
105105
arraySize,
106106
count
107107
)
108-
} else
108+
} else if (modelName.startsWith("speech_commands"))
109+
SpeechCommands.prepareInputStream(dataType, arraySize, count)
110+
else
109111
throw new IllegalArgumentException()
110112

111113
private def assertOutput(
@@ -182,15 +184,20 @@ object EmulatorHelper {
182184
val yoloPattern(yoloSize) = modelName
183185
TinyYolo(yoloSize.toInt, onnx = modelName.endsWith("onnx"))
184186
.assertOutput(outputName, dataType, arraySize, bytes)
185-
} else
187+
} else if (modelName.startsWith("speech_commands"))
188+
SpeechCommands.assertOutput(dataType, arraySize, bytes, count)
189+
else
186190
throw new IllegalArgumentException()
187191

188192
private def minimumInputCount(modelName: String): Int =
189193
if (modelName.startsWith("xor"))
190194
4
191195
else if (modelName.startsWith("resnet50v2"))
192196
3
193-
else if (modelName.startsWith("resnet20v2"))
197+
else if (
198+
modelName
199+
.startsWith("resnet20v2") || modelName.startsWith("speech_commands")
200+
)
194201
10
195202
else
196203
1
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* SPDX-License-Identifier: Apache-2.0 */
2+
/* Copyright © 2019-2022 Tensil AI Company */
3+
4+
package tensil.tools
5+
6+
import java.io._
7+
import scala.reflect.ClassTag
8+
import scala.io.Source
9+
import tensil.tools.emulator.{Emulator, ExecutiveTraceContext}
10+
import tensil.ArchitectureDataType
11+
12+
object SpeechCommands {
13+
def prepareInputStream(
14+
dataType: ArchitectureDataType,
15+
arraySize: Int,
16+
count: Int
17+
): InputStream = {
18+
val fileName = s"./models/data/speech_commands_input_${count}x${arraySize}.csv"
19+
20+
val inputPrep = new ByteArrayOutputStream()
21+
val inputPrepDataStream = new DataOutputStream(inputPrep)
22+
23+
ArchitectureDataTypeUtil.writeFromCsv(
24+
dataType,
25+
inputPrepDataStream,
26+
arraySize,
27+
fileName
28+
)
29+
30+
new ByteArrayInputStream(inputPrep.toByteArray())
31+
}
32+
33+
val ClassSize = 8
34+
val GoldenClasses: Array[Int] = Array(
35+
7, 2, 4, 6, 1, 6, 6, 5, 0, 5
36+
)
37+
38+
def assertOutput(
39+
dataType: ArchitectureDataType,
40+
arraySize: Int,
41+
bytes: Array[Byte],
42+
count: Int
43+
): Unit = {
44+
val output =
45+
new DataInputStream(new ByteArrayInputStream(bytes))
46+
47+
for (i <- 0 until count) {
48+
val expected = GoldenClasses(i)
49+
val actual = ArchitectureDataTypeUtil.argMax(
50+
ArchitectureDataTypeUtil.readResult(dataType, output, arraySize, ClassSize)
51+
)
52+
53+
println(s"expected=$expected, actual=$actual")
54+
assert(expected == actual)
55+
}
56+
}
57+
}

0 commit comments

Comments
 (0)