88 */
99package org .elasticsearch .benchmark .vector ;
1010
11+ import org .apache .lucene .index .VectorSimilarityFunction ;
1112import org .apache .lucene .store .Directory ;
1213import org .apache .lucene .store .IOContext ;
1314import org .apache .lucene .store .IndexInput ;
1415import org .apache .lucene .store .IndexOutput ;
1516import org .apache .lucene .store .MMapDirectory ;
1617import org .apache .lucene .util .VectorUtil ;
18+ import org .apache .lucene .util .quantization .OptimizedScalarQuantizer ;
1719import org .elasticsearch .common .logging .LogConfigurator ;
1820import org .elasticsearch .core .IOUtils ;
1921import org .elasticsearch .simdvec .ES91Int4VectorsScorer ;
@@ -52,20 +54,26 @@ public class Int4ScorerBenchmark {
5254 LogConfigurator .configureESLogging (); // native access requires logging to be initialized
5355 }
5456
55- @ Param ({ "384" , "702 " , "1024" })
57+ @ Param ({ "384" , "782 " , "1024" })
5658 int dims ;
5759
58- int numVectors = 200 ;
59- int numQueries = 10 ;
60+ int numVectors = 20 * ES91Int4VectorsScorer . BULK_SIZE ;
61+ int numQueries = 5 ;
6062
6163 byte [] scratch ;
6264 byte [][] binaryVectors ;
6365 byte [][] binaryQueries ;
66+ float [] scores = new float [ES91Int4VectorsScorer .BULK_SIZE ];
67+
68+ float [] scratchFloats = new float [3 ];
6469
6570 ES91Int4VectorsScorer scorer ;
6671 Directory dir ;
6772 IndexInput in ;
6873
74+ OptimizedScalarQuantizer .QuantizationResult queryCorrections ;
75+ float centroidDp ;
76+
6977 @ Setup
7078 public void setup () throws IOException {
7179 binaryVectors = new byte [numVectors ][dims ];
@@ -77,9 +85,19 @@ public void setup() throws IOException {
7785 binaryVector [i ] = (byte ) ThreadLocalRandom .current ().nextInt (16 );
7886 }
7987 out .writeBytes (binaryVector , 0 , binaryVector .length );
88+ ThreadLocalRandom .current ().nextBytes (binaryVector );
89+ out .writeBytes (binaryVector , 0 , 14 ); // corrections
8090 }
8191 }
8292
93+ queryCorrections = new OptimizedScalarQuantizer .QuantizationResult (
94+ ThreadLocalRandom .current ().nextFloat (),
95+ ThreadLocalRandom .current ().nextFloat (),
96+ ThreadLocalRandom .current ().nextFloat (),
97+ Short .toUnsignedInt ((short ) ThreadLocalRandom .current ().nextInt ())
98+ );
99+ centroidDp = ThreadLocalRandom .current ().nextFloat ();
100+
83101 in = dir .openInput ("vectors" , IOContext .DEFAULT );
84102 binaryQueries = new byte [numVectors ][dims ];
85103 for (byte [] binaryVector : binaryVectors ) {
@@ -105,18 +123,66 @@ public void scoreFromArray(Blackhole bh) throws IOException {
105123 in .seek (0 );
106124 for (int i = 0 ; i < numVectors ; i ++) {
107125 in .readBytes (scratch , 0 , dims );
108- bh .consume (VectorUtil .int4DotProduct (binaryQueries [j ], scratch ));
126+ int dp = VectorUtil .int4DotProduct (binaryQueries [j ], scratch );
127+ in .readFloats (scratchFloats , 0 , 3 );
128+ float score = scorer .applyCorrections (
129+ queryCorrections .lowerInterval (),
130+ queryCorrections .upperInterval (),
131+ queryCorrections .quantizedComponentSum (),
132+ queryCorrections .additionalCorrection (),
133+ VectorSimilarityFunction .EUCLIDEAN ,
134+ centroidDp , // assuming no centroid dot product for this benchmark
135+ scratchFloats [0 ],
136+ scratchFloats [1 ],
137+ Short .toUnsignedInt (in .readShort ()),
138+ scratchFloats [2 ],
139+ dp
140+ );
141+ bh .consume (score );
109142 }
110143 }
111144 }
112145
113146 @ Benchmark
114147 @ Fork (jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
115- public void scoreFromMemorySegmentOnlyVector (Blackhole bh ) throws IOException {
148+ public void scoreFromMemorySegment (Blackhole bh ) throws IOException {
116149 for (int j = 0 ; j < numQueries ; j ++) {
117150 in .seek (0 );
118151 for (int i = 0 ; i < numVectors ; i ++) {
119- bh .consume (scorer .int4DotProduct (binaryQueries [j ]));
152+ bh .consume (
153+ scorer .score (
154+ binaryQueries [j ],
155+ queryCorrections .lowerInterval (),
156+ queryCorrections .upperInterval (),
157+ queryCorrections .quantizedComponentSum (),
158+ queryCorrections .additionalCorrection (),
159+ VectorSimilarityFunction .EUCLIDEAN ,
160+ centroidDp
161+ )
162+ );
163+ }
164+ }
165+ }
166+
167+ @ Benchmark
168+ @ Fork (jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
169+ public void scoreFromMemorySegmentBulk (Blackhole bh ) throws IOException {
170+ for (int j = 0 ; j < numQueries ; j ++) {
171+ in .seek (0 );
172+ for (int i = 0 ; i < numVectors ; i += ES91Int4VectorsScorer .BULK_SIZE ) {
173+ scorer .scoreBulk (
174+ binaryQueries [j ],
175+ queryCorrections .lowerInterval (),
176+ queryCorrections .upperInterval (),
177+ queryCorrections .quantizedComponentSum (),
178+ queryCorrections .additionalCorrection (),
179+ VectorSimilarityFunction .EUCLIDEAN ,
180+ centroidDp ,
181+ scores
182+ );
183+ for (float score : scores ) {
184+ bh .consume (score );
185+ }
120186 }
121187 }
122188 }
0 commit comments