5
5
const std = @import ("std" );
6
6
const builtin = @import ("builtin" );
7
7
const sig = @import ("../sig.zig" );
8
+
9
+ const zksdk = sig .zksdk ;
10
+
8
11
const Keccak1600 = std .crypto .core .keccak .KeccakF (1600 );
9
12
const Ed25519 = std .crypto .ecc .Edwards25519 ;
10
13
const Scalar = Ed25519 .scalar .Scalar ;
@@ -213,40 +216,115 @@ pub const Transcript = struct {
213
216
@"batched-grouped-ciphertext-validity-3-handles-instruction" ,
214
217
};
215
218
216
- pub fn init (comptime seperator : DomainSeperator ) Transcript {
219
+ const TranscriptInput = struct {
220
+ label : []const u8 ,
221
+ message : Message ,
222
+ };
223
+
224
+ const Message = union (enum ) {
225
+ bytes : []const u8 ,
226
+
227
+ point : Ristretto255 ,
228
+ pubkey : zksdk.el_gamal.Pubkey ,
229
+ scalar : Scalar ,
230
+ ciphertext : zksdk.el_gamal.Ciphertext ,
231
+ commitment : zksdk.pedersen.Commitment ,
232
+ u64 : u64 ,
233
+
234
+ grouped_2 : zksdk .el_gamal .GroupedElGamalCiphertext (2 ),
235
+ grouped_3 : zksdk .el_gamal .GroupedElGamalCiphertext (3 ),
236
+ };
237
+
238
+ pub fn init (comptime seperator : DomainSeperator , inputs : []const TranscriptInput ) Transcript {
217
239
var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
218
240
transcript .appendDomSep (seperator );
241
+ for (inputs ) | input | transcript .appendMessage (input .label , input .message );
219
242
return transcript ;
220
243
}
221
244
222
245
pub fn initTest (label : []const u8 ) Transcript {
223
246
comptime if (! builtin .is_test ) @compileError ("should only be used during tests" );
224
247
var transcript : Transcript = .{ .strobe = Strobe128 .init ("Merlin v1.0" ) };
225
- transcript .appendMessage ("dom-sep" , label );
248
+ transcript .appendBytes ("dom-sep" , label );
226
249
return transcript ;
227
250
}
228
251
229
- /// NOTE: be very careful with this function, there are only a specific few
230
- /// usages of it. always use helper functions if possible.
231
- pub fn appendMessage (
232
- self : * Transcript ,
233
- comptime label : []const u8 ,
234
- message : []const u8 ,
235
- ) void {
252
+ fn appendBytes (self : * Transcript , label : []const u8 , bytes : []const u8 ) void {
236
253
var data_len : [4 ]u8 = undefined ;
237
- std .mem .writeInt (u32 , & data_len , @intCast (message .len ), .little );
254
+ std .mem .writeInt (u32 , & data_len , @intCast (bytes .len ), .little );
238
255
self .strobe .metaAd (label , false );
239
256
self .strobe .metaAd (& data_len , true );
240
- self .strobe .ad (message , false );
257
+ self .strobe .ad (bytes , false );
241
258
}
242
259
243
- pub fn appendDomSep (self : * Transcript , comptime seperator : DomainSeperator ) void {
244
- self .appendMessage ("dom-sep" , @tagName (seperator ));
260
+ fn appendMessage (self : * Transcript , label : []const u8 , message : Message ) void {
261
+ var buffer : [64 ]u8 = @splat (0 );
262
+ const bytes : []const u8 = switch (message ) {
263
+ .bytes = > | b | b ,
264
+ .point = > | * point | & point .toBytes (),
265
+ .pubkey = > | * pubkey | & pubkey .toBytes (),
266
+ .scalar = > | * scalar | & scalar .toBytes (),
267
+ .ciphertext = > | * ct | b : {
268
+ @memcpy (buffer [0.. 32], & ct .commitment .point .toBytes ());
269
+ @memcpy (buffer [32.. 64], & ct .handle .point .toBytes ());
270
+ break :b & buffer ;
271
+ },
272
+ .commitment = > | * c | & c .toBytes (),
273
+ .u64 = > | x | b : {
274
+ std .mem .writeInt (u64 , buffer [0.. 8], x , .little );
275
+ break :b buffer [0.. 8];
276
+ },
277
+ inline .grouped_2 , .grouped_3 = > | * g | & g .toBytes (),
278
+ };
279
+ self .appendBytes (label , bytes );
245
280
}
246
281
247
- pub fn challengeBytes (
282
+ pub inline fn append (
248
283
self : * Transcript ,
284
+ comptime session : * Session ,
285
+ comptime t : Input.Type ,
249
286
comptime label : []const u8 ,
287
+ data : t .Data (),
288
+ ) if (t == .validate_point ) error {IdentityElement }! void else void {
289
+ // if validate_point fails to validate, we no longer want to check the contract
290
+ // because the function calling append will now return early.
291
+ errdefer session .cancel ();
292
+
293
+ if (t == .bytes and ! builtin .is_test )
294
+ @compileError ("message type `bytes` only allowed in tests" );
295
+
296
+ // assert correctness
297
+ const input = comptime session .nextInput (t , label );
298
+ if (t == .validate_point ) try data .rejectIdentity ();
299
+
300
+ // add the message
301
+ self .appendMessage (input .label , @unionInit (
302
+ Message ,
303
+ @tagName (switch (t ) {
304
+ .validate_point = > .point ,
305
+ else = > t ,
306
+ }),
307
+ data ,
308
+ ));
309
+ }
310
+
311
+ /// Helper function to be used in proof creation. We often need to test what will
312
+ /// happen if points are zeroed, and to make sure that the verification fails.
313
+ /// Shouldn't be used outside of the `init` functions.
314
+ pub inline fn appendNoValidate (
315
+ self : * Transcript ,
316
+ comptime session : * Session ,
317
+ comptime label : []const u8 ,
318
+ point : Ristretto255 ,
319
+ ) void {
320
+ const input = comptime session .nextInput (.validate_point , label );
321
+ point .rejectIdentity () catch {}; // ignore the error
322
+ self .appendMessage (input .label , .{ .point = point });
323
+ }
324
+
325
+ fn challengeBytes (
326
+ self : * Transcript ,
327
+ label : []const u8 ,
250
328
destination : []u8 ,
251
329
) void {
252
330
var data_len : [4 ]u8 = undefined ;
@@ -257,68 +335,24 @@ pub const Transcript = struct {
257
335
self .strobe .prf (destination , false );
258
336
}
259
337
260
- pub fn challengeScalar (
338
+ pub inline fn challengeScalar (
261
339
self : * Transcript ,
340
+ comptime session : * Session ,
262
341
comptime label : []const u8 ,
263
342
) Scalar {
264
- var buffer : [64 ]u8 = .{0 } ** 64 ;
265
- self .challengeBytes (label , & buffer );
343
+ const input = comptime session .nextInput (.challenge , label );
344
+ var buffer : [64 ]u8 = @splat (0 );
345
+ self .challengeBytes (input .label , & buffer );
266
346
// Specifically need reduce64 instead of Scalar.fromBytes64, since
267
347
// we need the Barret reduction to be done with 10 limbs, not 5.
268
348
const compressed = Ed25519 .scalar .reduce64 (buffer );
269
349
return Scalar .fromBytes (compressed );
270
350
}
271
351
272
- pub fn validateAndAppendPoint (
273
- self : * Transcript ,
274
- comptime label : []const u8 ,
275
- point : Ristretto255 ,
276
- ) ! void {
277
- try point .rejectIdentity ();
278
- self .appendPoint (label , point );
279
- }
280
-
281
- // helper functions
282
-
283
- pub fn appendPoint (self : * Transcript , comptime label : []const u8 , point : Ristretto255 ) void {
284
- self .appendMessage (label , & point .toBytes ());
285
- }
286
-
287
- pub fn appendScalar (self : * Transcript , comptime label : []const u8 , scalar : Scalar ) void {
288
- self .appendMessage (label , & scalar .toBytes ());
289
- }
290
-
291
- pub fn appendPubkey (
292
- self : * Transcript ,
293
- comptime label : []const u8 ,
294
- pubkey : sig.zksdk.ElGamalPubkey ,
295
- ) void {
296
- self .appendPoint (label , pubkey .point );
297
- }
298
-
299
- pub fn appendCiphertext (
300
- self : * Transcript ,
301
- comptime label : []const u8 ,
302
- ciphertext : sig.zksdk.ElGamalCiphertext ,
303
- ) void {
304
- var buffer : [64 ]u8 = .{0 } ** 64 ;
305
- @memcpy (buffer [0.. 32], & ciphertext .commitment .point .toBytes ());
306
- @memcpy (buffer [32.. 64], & ciphertext .handle .point .toBytes ());
307
- self .appendMessage (label , & buffer );
308
- }
309
-
310
- pub fn appendCommitment (
311
- self : * Transcript ,
312
- comptime label : []const u8 ,
313
- commitment : sig.zksdk.pedersen.Commitment ,
314
- ) void {
315
- self .appendMessage (label , & commitment .point .toBytes ());
316
- }
352
+ // domain seperation helpers
317
353
318
- pub fn appendU64 (self : * Transcript , comptime label : []const u8 , x : u64 ) void {
319
- var buffer : [8 ]u8 = .{0 } ** 8 ;
320
- std .mem .writeInt (u64 , & buffer , x , .little );
321
- self .appendMessage (label , & buffer );
354
+ pub fn appendDomSep (self : * Transcript , comptime seperator : DomainSeperator ) void {
355
+ self .appendBytes ("dom-sep" , @tagName (seperator ));
322
356
}
323
357
324
358
pub fn appendHandleDomSep (
@@ -330,10 +364,10 @@ pub const Transcript = struct {
330
364
.batched = > .@"batched-validity-proof" ,
331
365
.unbatched = > .@"validity-proof" ,
332
366
});
333
- self .appendU64 ("handles" , switch (handles ) {
367
+ self .appendMessage ("handles" , .{ . u64 = switch (handles ) {
334
368
.two = > 2 ,
335
369
.three = > 3 ,
336
- });
370
+ } } );
337
371
}
338
372
339
373
pub fn appendRangeProof (
@@ -345,14 +379,90 @@ pub const Transcript = struct {
345
379
.range = > .@"range-proof" ,
346
380
.inner = > .@"inner-product" ,
347
381
});
348
- self .appendU64 ("n" , n );
382
+ self .appendMessage ("n" , .{ .u64 = n });
383
+ }
384
+
385
+ // sessions
386
+
387
+ pub const Input = struct {
388
+ label : []const u8 ,
389
+ type : Type ,
390
+
391
+ const Type = enum {
392
+ bytes ,
393
+ scalar ,
394
+ challenge ,
395
+ point ,
396
+ validate_point ,
397
+ pubkey ,
398
+
399
+ pub fn Data (comptime t : Type ) type {
400
+ return switch (t ) {
401
+ .bytes = > []const u8 ,
402
+ .scalar = > Scalar ,
403
+ .validate_point , .point = > Ristretto255 ,
404
+ .pubkey = > zksdk .el_gamal .Pubkey ,
405
+ .challenge = > unreachable , // call `challenge*`
406
+ };
407
+ }
408
+ };
409
+
410
+ fn check (self : Input , t : Type , label : []const u8 ) void {
411
+ std .debug .assert (self .type == t );
412
+ std .debug .assert (std .mem .eql (u8 , self .label , label ));
413
+ }
414
+ };
415
+
416
+ pub const Contract = []const Input ;
417
+
418
+ pub const Session = struct {
419
+ i : u8 ,
420
+ contract : Contract ,
421
+ err : bool , // if validate_point errors, we skip the finish() check
422
+
423
+ pub inline fn nextInput (comptime self : * Session , t : Input.Type , label : []const u8 ) Input {
424
+ comptime {
425
+ defer self .i += 1 ;
426
+ const input = self .contract [self .i ];
427
+ input .check (t , label );
428
+ return input ;
429
+ }
430
+ }
431
+
432
+ pub inline fn finish (comptime self : * Session ) void {
433
+ // For performance, we have certain computations (specifically in `init` functions)
434
+ // which skip the last parts of transcript when they aren't needed (i.e ciphertext_ciphertext proof).
435
+ //
436
+ // By performing this check, we still ensure that they do those extra computations when in Debug mode,
437
+ // but are allowed to skip them in a release build.
438
+ if (builtin .mode == .Debug and ! self .err and self .i != self .contract .len ) {
439
+ @compileError ("contract unfulfilled" );
440
+ }
441
+ }
442
+
443
+ inline fn cancel (comptime self : * Session ) void {
444
+ comptime self .err = true ;
445
+ }
446
+ };
447
+
448
+ pub inline fn getSession (comptime contract : []const Input ) Session {
449
+ comptime {
450
+ // contract should always end in a challenge
451
+ const last_contract = contract [contract .len - 1 ];
452
+ std .debug .assert (last_contract .type == .challenge );
453
+ return .{ .i = 0 , .contract = contract , .err = false };
454
+ }
349
455
}
350
456
};
351
457
352
458
test "equivalence" {
353
459
var transcript = Transcript .initTest ("test protocol" );
354
460
355
- transcript .appendMessage ("some label" , "some data" );
461
+ comptime var session = Transcript .getSession (&.{
462
+ .{ .label = "some label" , .type = .bytes },
463
+ .{ .label = "challenge" , .type = .challenge },
464
+ });
465
+ transcript .append (& session , .bytes , "some label" , "some data" );
356
466
357
467
var bytes : [32 ]u8 = undefined ;
358
468
transcript .challengeBytes ("challenge" , & bytes );
0 commit comments