Skip to content

Commit 1ad44af

Browse files
authored
Improve perofrmnace of output (#93)
- Introduced `OutputBenchmark` - Removed unused `counter` - Follow benchmark results
1 parent 18894f0 commit 1ad44af

File tree

3 files changed

+1085
-101
lines changed

3 files changed

+1085
-101
lines changed
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
/*
2+
* scala-blake3 - highly optimized blake3 implementation for scala, scala-js and scala-native.
3+
*
4+
* Written in 2020, 2021 by Kirill A. Korinsky <[email protected]>
5+
*
6+
* Supported since 2022 by Kcrypt Lab UG <[email protected]>
7+
*
8+
* This work is released into the public domain with CC0 1.0.
9+
* Alternatively, it is licensed under the Apache License 2.0.
10+
*/
11+
12+
package pt.kcry.blake3
13+
package benchmark
14+
15+
import CompressRounds._
16+
17+
import org.openjdk.jmh.annotations._
18+
19+
import scala.util.Random
20+
21+
@State(Scope.Benchmark)
22+
class OutputBenchmark {
23+
private val bytes = new Array[Byte](CHUNK_LEN)
24+
25+
private val blockWords = new Array[Int](BLOCK_LEN_WORDS)
26+
private val tmpBlockWords = new Array[Int](BLOCK_LEN_WORDS)
27+
private val inputChainingValue = new Array[Int](BLOCK_LEN_WORDS)
28+
29+
private var blockLen: Int = 0
30+
private var flags: Int = 0
31+
32+
@Setup
33+
def setup(): Unit = {
34+
val random = new Random()
35+
for (i <- 0 until BLOCK_LEN_WORDS) {
36+
blockWords(i) = random.nextInt()
37+
tmpBlockWords(i) = random.nextInt()
38+
inputChainingValue(i) = random.nextInt()
39+
}
40+
41+
blockLen = random.nextInt()
42+
flags = random.nextInt()
43+
}
44+
45+
@Benchmark
46+
def hugeChunk(): Unit = hugeChunkImpl(bytes, 0, CHUNK_LEN)
47+
48+
@Benchmark
49+
def subLoop(): Unit = subLoopImpl(bytes, 0, CHUNK_LEN)
50+
51+
@Benchmark
52+
def subLoopInline(): Unit = subLoopInlineImpl(bytes, 0, CHUNK_LEN)
53+
54+
private def subLoopImpl(out: Array[Byte], off: Int, len: Int): Unit = {
55+
var outputBlockCounter = 0
56+
var pos = off
57+
val lim = off + len
58+
59+
val blockLenWords = BLOCK_LEN_WORDS
60+
val words = tmpBlockWords
61+
val flags = this.flags | ROOT
62+
63+
while (pos < lim) {
64+
compressRounds(words, blockWords, inputChainingValue, outputBlockCounter,
65+
blockLen, flags)
66+
67+
var wordIdx = 0
68+
while (wordIdx < blockLenWords && pos < lim) {
69+
val word = words(wordIdx)
70+
lim - pos match {
71+
case x if x <= 0 =>
72+
throw new RuntimeException(
73+
s"x: $x; pos: $pos; lim: $lim; wordIdx: $wordIdx; off: $off; len: $len"
74+
)
75+
76+
case 1 =>
77+
out(pos) = word.toByte
78+
pos += 1
79+
80+
case 2 =>
81+
out(pos) = word.toByte
82+
pos += 1
83+
out(pos) = (word >>> 8).toByte
84+
pos += 1
85+
86+
case 3 =>
87+
out(pos) = word.toByte
88+
pos += 1
89+
out(pos) = (word >>> 8).toByte
90+
pos += 1
91+
out(pos) = (word >>> 16).toByte
92+
pos += 1
93+
94+
case _ =>
95+
out(pos) = word.toByte
96+
pos += 1
97+
out(pos) = (word >>> 8).toByte
98+
pos += 1
99+
out(pos) = (word >>> 16).toByte
100+
pos += 1
101+
out(pos) = (word >>> 24).toByte
102+
pos += 1
103+
}
104+
wordIdx += 1
105+
}
106+
107+
outputBlockCounter += 1
108+
}
109+
}
110+
111+
private def subLoopInlineImpl(out: Array[Byte], off: Int, len: Int): Unit = {
112+
var outputBlockCounter = 0
113+
var pos = off
114+
val lim = off + len
115+
116+
val blockLenWords = BLOCK_LEN_WORDS
117+
val words = tmpBlockWords
118+
val flags = this.flags | ROOT
119+
120+
while (pos < lim) {
121+
compressRounds(words, blockWords, inputChainingValue, outputBlockCounter,
122+
blockLen, flags)
123+
124+
var wordIdx = 0
125+
while (wordIdx < blockLenWords && pos < lim) {
126+
val word = words(wordIdx)
127+
lim - pos match {
128+
case x if x <= 0 =>
129+
throw new RuntimeException(
130+
s"x: $x; pos: $pos; lim: $lim; wordIdx: $wordIdx; off: $off; len: $len"
131+
)
132+
133+
case 1 =>
134+
out(pos) = word.toByte
135+
pos += 1
136+
137+
case 2 =>
138+
out(pos) = word.toByte
139+
out(1 + pos) = (word >>> 8).toByte
140+
pos += 2
141+
142+
case 3 =>
143+
out(pos) = word.toByte
144+
out(1 + pos) = (word >>> 8).toByte
145+
out(2 + pos) = (word >>> 16).toByte
146+
pos += 3
147+
148+
case _ =>
149+
out(pos) = word.toByte
150+
out(1 + pos) = (word >>> 8).toByte
151+
out(2 + pos) = (word >>> 16).toByte
152+
out(3 + pos) = (word >>> 24).toByte
153+
pos += 4
154+
}
155+
wordIdx += 1
156+
}
157+
158+
outputBlockCounter += 1
159+
}
160+
}
161+
162+
private def hugeChunkImpl(out: Array[Byte], off: Int, len: Int): Unit = {
163+
var outputBlockCounter = 0
164+
var pos = off
165+
166+
val words = tmpBlockWords
167+
val flags = this.flags | ROOT
168+
169+
var lim = off + len - 63
170+
while (pos < lim) {
171+
compressRounds(words, blockWords, inputChainingValue, outputBlockCounter,
172+
blockLen, flags)
173+
174+
val word_0 = words(0)
175+
val word_1 = words(1)
176+
val word_2 = words(2)
177+
val word_3 = words(3)
178+
val word_4 = words(4)
179+
val word_5 = words(5)
180+
val word_6 = words(6)
181+
val word_7 = words(7)
182+
val word_8 = words(8)
183+
val word_9 = words(9)
184+
val word_10 = words(10)
185+
val word_11 = words(11)
186+
val word_12 = words(12)
187+
val word_13 = words(13)
188+
val word_14 = words(14)
189+
val word_15 = words(15)
190+
191+
out(pos) = word_0.toByte
192+
out(1 + pos) = (word_0 >>> 8).toByte
193+
out(2 + pos) = (word_0 >>> 16).toByte
194+
out(3 + pos) = (word_0 >>> 24).toByte
195+
out(4 + pos) = word_1.toByte
196+
out(5 + pos) = (word_1 >>> 8).toByte
197+
out(6 + pos) = (word_1 >>> 16).toByte
198+
out(7 + pos) = (word_1 >>> 24).toByte
199+
out(8 + pos) = word_2.toByte
200+
out(9 + pos) = (word_2 >>> 8).toByte
201+
out(10 + pos) = (word_2 >>> 16).toByte
202+
out(11 + pos) = (word_2 >>> 24).toByte
203+
out(12 + pos) = word_3.toByte
204+
out(13 + pos) = (word_3 >>> 8).toByte
205+
out(14 + pos) = (word_3 >>> 16).toByte
206+
out(15 + pos) = (word_3 >>> 24).toByte
207+
out(16 + pos) = word_4.toByte
208+
out(17 + pos) = (word_4 >>> 8).toByte
209+
out(18 + pos) = (word_4 >>> 16).toByte
210+
out(19 + pos) = (word_4 >>> 24).toByte
211+
out(20 + pos) = word_5.toByte
212+
out(21 + pos) = (word_5 >>> 8).toByte
213+
out(22 + pos) = (word_5 >>> 16).toByte
214+
out(23 + pos) = (word_5 >>> 24).toByte
215+
out(24 + pos) = word_6.toByte
216+
out(25 + pos) = (word_6 >>> 8).toByte
217+
out(26 + pos) = (word_6 >>> 16).toByte
218+
out(27 + pos) = (word_6 >>> 24).toByte
219+
out(28 + pos) = word_7.toByte
220+
out(29 + pos) = (word_7 >>> 8).toByte
221+
out(30 + pos) = (word_7 >>> 16).toByte
222+
out(31 + pos) = (word_7 >>> 24).toByte
223+
out(32 + pos) = word_8.toByte
224+
out(33 + pos) = (word_8 >>> 8).toByte
225+
out(34 + pos) = (word_8 >>> 16).toByte
226+
out(35 + pos) = (word_8 >>> 24).toByte
227+
out(36 + pos) = word_9.toByte
228+
out(37 + pos) = (word_9 >>> 8).toByte
229+
out(38 + pos) = (word_9 >>> 16).toByte
230+
out(39 + pos) = (word_9 >>> 24).toByte
231+
out(40 + pos) = word_10.toByte
232+
out(41 + pos) = (word_10 >>> 8).toByte
233+
out(42 + pos) = (word_10 >>> 16).toByte
234+
out(43 + pos) = (word_10 >>> 24).toByte
235+
out(44 + pos) = word_11.toByte
236+
out(45 + pos) = (word_11 >>> 8).toByte
237+
out(46 + pos) = (word_11 >>> 16).toByte
238+
out(47 + pos) = (word_11 >>> 24).toByte
239+
out(48 + pos) = word_12.toByte
240+
out(49 + pos) = (word_12 >>> 8).toByte
241+
out(50 + pos) = (word_12 >>> 16).toByte
242+
out(51 + pos) = (word_12 >>> 24).toByte
243+
out(52 + pos) = word_13.toByte
244+
out(53 + pos) = (word_13 >>> 8).toByte
245+
out(54 + pos) = (word_13 >>> 16).toByte
246+
out(55 + pos) = (word_13 >>> 24).toByte
247+
out(56 + pos) = word_14.toByte
248+
out(57 + pos) = (word_14 >>> 8).toByte
249+
out(58 + pos) = (word_14 >>> 16).toByte
250+
out(59 + pos) = (word_14 >>> 24).toByte
251+
out(60 + pos) = word_15.toByte
252+
out(61 + pos) = (word_15 >>> 8).toByte
253+
out(62 + pos) = (word_15 >>> 16).toByte
254+
out(63 + pos) = (word_15 >>> 24).toByte
255+
256+
pos += 64
257+
outputBlockCounter += 1
258+
}
259+
260+
lim += 63
261+
if (pos < lim) {
262+
compressRounds(words, blockWords, inputChainingValue, outputBlockCounter,
263+
blockLen, flags)
264+
265+
var wordIdx = 0
266+
while (pos < lim) {
267+
val word = words(wordIdx)
268+
lim - pos match {
269+
case x if x <= 0 =>
270+
throw new RuntimeException(
271+
s"x: $x; pos: $pos; lim: $lim; wordIdx: $wordIdx; off: $off; len: $len"
272+
)
273+
274+
case 1 =>
275+
out(pos) = word.toByte
276+
pos += 1
277+
278+
case 2 =>
279+
out(pos) = word.toByte
280+
out(1 + pos) = (word >>> 8).toByte
281+
pos += 2
282+
283+
case 3 =>
284+
out(pos) = word.toByte
285+
out(1 + pos) = (word >>> 8).toByte
286+
out(2 + pos) = (word >>> 16).toByte
287+
pos += 3
288+
289+
case _ =>
290+
out(pos) = word.toByte
291+
out(1 + pos) = (word >>> 8).toByte
292+
out(2 + pos) = (word >>> 16).toByte
293+
out(3 + pos) = (word >>> 24).toByte
294+
pos += 4
295+
}
296+
wordIdx += 1
297+
}
298+
}
299+
}
300+
}

shared/src/main/scala/pt/kcry/blake3/HasherImpl.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ private[blake3] class HasherImpl(val key: Array[Int], val flags: Int)
2727
new ChunkState(key, 0, flags, tmpChunkCV, tmpBlockWords)
2828

2929
private val output =
30-
new Output(key, tmpBlockWords, 0, BLOCK_LEN, flags, tmpChunkCV)
30+
new Output(key, tmpBlockWords, BLOCK_LEN, flags, tmpChunkCV)
3131

3232
// Space for 54 subtree chaining values
3333
private val cvStack: Array[Array[Int]] = {
@@ -199,7 +199,6 @@ private[blake3] class HasherImpl(val key: Array[Int], val flags: Int)
199199

200200
// reset cached output
201201
output.inputChainingValue = inputChainingValue
202-
output.counter = counter
203202
output.blockLen = blockLen
204203
output.flags = outputFlags
205204

0 commit comments

Comments
 (0)