Skip to content

Commit 6f6467a

Browse files
committed
zksdk: add fiat-shamir static verification
1 parent 5a62f90 commit 6f6467a

File tree

10 files changed

+653
-306
lines changed

10 files changed

+653
-306
lines changed

src/zksdk/merlin.zig

Lines changed: 180 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
const std = @import("std");
66
const builtin = @import("builtin");
77
const sig = @import("../sig.zig");
8+
9+
const zksdk = sig.zksdk;
10+
811
const Keccak1600 = std.crypto.core.keccak.KeccakF(1600);
912
const Ed25519 = std.crypto.ecc.Edwards25519;
1013
const Scalar = Ed25519.scalar.Scalar;
@@ -213,40 +216,115 @@ pub const Transcript = struct {
213216
@"batched-grouped-ciphertext-validity-3-handles-instruction",
214217
};
215218

216-
pub fn init(comptime seperator: DomainSeperator) Transcript {
219+
const TranscriptInput = struct {
220+
label: []const u8,
221+
message: Message,
222+
};
223+
224+
const Message = union(enum) {
225+
bytes: []const u8,
226+
227+
point: Ristretto255,
228+
pubkey: zksdk.el_gamal.Pubkey,
229+
scalar: Scalar,
230+
ciphertext: zksdk.el_gamal.Ciphertext,
231+
commitment: zksdk.pedersen.Commitment,
232+
u64: u64,
233+
234+
grouped_2: zksdk.el_gamal.GroupedElGamalCiphertext(2),
235+
grouped_3: zksdk.el_gamal.GroupedElGamalCiphertext(3),
236+
};
237+
238+
pub fn init(comptime seperator: DomainSeperator, inputs: []const TranscriptInput) Transcript {
217239
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
218240
transcript.appendDomSep(seperator);
241+
for (inputs) |input| transcript.appendMessage(input.label, input.message);
219242
return transcript;
220243
}
221244

222245
pub fn initTest(label: []const u8) Transcript {
223246
comptime if (!builtin.is_test) @compileError("should only be used during tests");
224247
var transcript: Transcript = .{ .strobe = Strobe128.init("Merlin v1.0") };
225-
transcript.appendMessage("dom-sep", label);
248+
transcript.appendBytes("dom-sep", label);
226249
return transcript;
227250
}
228251

229-
/// NOTE: be very careful with this function, there are only a specific few
230-
/// usages of it. always use helper functions if possible.
231-
pub fn appendMessage(
232-
self: *Transcript,
233-
comptime label: []const u8,
234-
message: []const u8,
235-
) void {
252+
fn appendBytes(self: *Transcript, label: []const u8, bytes: []const u8) void {
236253
var data_len: [4]u8 = undefined;
237-
std.mem.writeInt(u32, &data_len, @intCast(message.len), .little);
254+
std.mem.writeInt(u32, &data_len, @intCast(bytes.len), .little);
238255
self.strobe.metaAd(label, false);
239256
self.strobe.metaAd(&data_len, true);
240-
self.strobe.ad(message, false);
257+
self.strobe.ad(bytes, false);
241258
}
242259

243-
pub fn appendDomSep(self: *Transcript, comptime seperator: DomainSeperator) void {
244-
self.appendMessage("dom-sep", @tagName(seperator));
260+
fn appendMessage(self: *Transcript, label: []const u8, message: Message) void {
261+
var buffer: [64]u8 = @splat(0);
262+
const bytes: []const u8 = switch (message) {
263+
.bytes => |b| b,
264+
.point => |*point| &point.toBytes(),
265+
.pubkey => |*pubkey| &pubkey.toBytes(),
266+
.scalar => |*scalar| &scalar.toBytes(),
267+
.ciphertext => |*ct| b: {
268+
@memcpy(buffer[0..32], &ct.commitment.point.toBytes());
269+
@memcpy(buffer[32..64], &ct.handle.point.toBytes());
270+
break :b &buffer;
271+
},
272+
.commitment => |*c| &c.toBytes(),
273+
.u64 => |x| b: {
274+
std.mem.writeInt(u64, buffer[0..8], x, .little);
275+
break :b buffer[0..8];
276+
},
277+
inline .grouped_2, .grouped_3 => |*g| &g.toBytes(),
278+
};
279+
self.appendBytes(label, bytes);
245280
}
246281

247-
pub fn challengeBytes(
282+
pub inline fn append(
248283
self: *Transcript,
284+
comptime session: *Session,
285+
comptime t: Input.Type,
249286
comptime label: []const u8,
287+
data: t.Data(),
288+
) if (t == .validate_point) error{IdentityElement}!void else void {
289+
// if validate_point fails to validate, we no longer want to check the contract
290+
// because the function calling append will now return early.
291+
errdefer session.cancel();
292+
293+
if (t == .bytes and !builtin.is_test)
294+
@compileError("message type `bytes` only allowed in tests");
295+
296+
// assert correctness
297+
const input = comptime session.nextInput(t, label);
298+
if (t == .validate_point) try data.rejectIdentity();
299+
300+
// add the message
301+
self.appendMessage(input.label, @unionInit(
302+
Message,
303+
@tagName(switch (t) {
304+
.validate_point => .point,
305+
else => t,
306+
}),
307+
data,
308+
));
309+
}
310+
311+
/// Helper function to be used in proof creation. We often need to test what will
312+
/// happen if points are zeroed, and to make sure that the verification fails.
313+
/// Shouldn't be used outside of the `init` functions.
314+
pub inline fn appendNoValidate(
315+
self: *Transcript,
316+
comptime session: *Session,
317+
comptime label: []const u8,
318+
point: Ristretto255,
319+
) void {
320+
const input = comptime session.nextInput(.validate_point, label);
321+
point.rejectIdentity() catch {}; // ignore the error
322+
self.appendMessage(input.label, .{ .point = point });
323+
}
324+
325+
fn challengeBytes(
326+
self: *Transcript,
327+
label: []const u8,
250328
destination: []u8,
251329
) void {
252330
var data_len: [4]u8 = undefined;
@@ -257,68 +335,24 @@ pub const Transcript = struct {
257335
self.strobe.prf(destination, false);
258336
}
259337

260-
pub fn challengeScalar(
338+
pub inline fn challengeScalar(
261339
self: *Transcript,
340+
comptime session: *Session,
262341
comptime label: []const u8,
263342
) Scalar {
264-
var buffer: [64]u8 = .{0} ** 64;
265-
self.challengeBytes(label, &buffer);
343+
const input = comptime session.nextInput(.challenge, label);
344+
var buffer: [64]u8 = @splat(0);
345+
self.challengeBytes(input.label, &buffer);
266346
// Specifically need reduce64 instead of Scalar.fromBytes64, since
267347
// we need the Barret reduction to be done with 10 limbs, not 5.
268348
const compressed = Ed25519.scalar.reduce64(buffer);
269349
return Scalar.fromBytes(compressed);
270350
}
271351

272-
pub fn validateAndAppendPoint(
273-
self: *Transcript,
274-
comptime label: []const u8,
275-
point: Ristretto255,
276-
) !void {
277-
try point.rejectIdentity();
278-
self.appendPoint(label, point);
279-
}
280-
281-
// helper functions
282-
283-
pub fn appendPoint(self: *Transcript, comptime label: []const u8, point: Ristretto255) void {
284-
self.appendMessage(label, &point.toBytes());
285-
}
286-
287-
pub fn appendScalar(self: *Transcript, comptime label: []const u8, scalar: Scalar) void {
288-
self.appendMessage(label, &scalar.toBytes());
289-
}
290-
291-
pub fn appendPubkey(
292-
self: *Transcript,
293-
comptime label: []const u8,
294-
pubkey: sig.zksdk.ElGamalPubkey,
295-
) void {
296-
self.appendPoint(label, pubkey.point);
297-
}
298-
299-
pub fn appendCiphertext(
300-
self: *Transcript,
301-
comptime label: []const u8,
302-
ciphertext: sig.zksdk.ElGamalCiphertext,
303-
) void {
304-
var buffer: [64]u8 = .{0} ** 64;
305-
@memcpy(buffer[0..32], &ciphertext.commitment.point.toBytes());
306-
@memcpy(buffer[32..64], &ciphertext.handle.point.toBytes());
307-
self.appendMessage(label, &buffer);
308-
}
309-
310-
pub fn appendCommitment(
311-
self: *Transcript,
312-
comptime label: []const u8,
313-
commitment: sig.zksdk.pedersen.Commitment,
314-
) void {
315-
self.appendMessage(label, &commitment.point.toBytes());
316-
}
352+
// domain seperation helpers
317353

318-
pub fn appendU64(self: *Transcript, comptime label: []const u8, x: u64) void {
319-
var buffer: [8]u8 = .{0} ** 8;
320-
std.mem.writeInt(u64, &buffer, x, .little);
321-
self.appendMessage(label, &buffer);
354+
pub fn appendDomSep(self: *Transcript, comptime seperator: DomainSeperator) void {
355+
self.appendBytes("dom-sep", @tagName(seperator));
322356
}
323357

324358
pub fn appendHandleDomSep(
@@ -330,10 +364,10 @@ pub const Transcript = struct {
330364
.batched => .@"batched-validity-proof",
331365
.unbatched => .@"validity-proof",
332366
});
333-
self.appendU64("handles", switch (handles) {
367+
self.appendMessage("handles", .{ .u64 = switch (handles) {
334368
.two => 2,
335369
.three => 3,
336-
});
370+
} });
337371
}
338372

339373
pub fn appendRangeProof(
@@ -345,14 +379,90 @@ pub const Transcript = struct {
345379
.range => .@"range-proof",
346380
.inner => .@"inner-product",
347381
});
348-
self.appendU64("n", n);
382+
self.appendMessage("n", .{ .u64 = n });
383+
}
384+
385+
// sessions
386+
387+
pub const Input = struct {
388+
label: []const u8,
389+
type: Type,
390+
391+
const Type = enum {
392+
bytes,
393+
scalar,
394+
challenge,
395+
point,
396+
validate_point,
397+
pubkey,
398+
399+
pub fn Data(comptime t: Type) type {
400+
return switch (t) {
401+
.bytes => []const u8,
402+
.scalar => Scalar,
403+
.validate_point, .point => Ristretto255,
404+
.pubkey => zksdk.el_gamal.Pubkey,
405+
.challenge => unreachable, // call `challenge*`
406+
};
407+
}
408+
};
409+
410+
fn check(self: Input, t: Type, label: []const u8) void {
411+
std.debug.assert(self.type == t);
412+
std.debug.assert(std.mem.eql(u8, self.label, label));
413+
}
414+
};
415+
416+
pub const Contract = []const Input;
417+
418+
pub const Session = struct {
419+
i: u8,
420+
contract: Contract,
421+
err: bool, // if validate_point errors, we skip the finish() check
422+
423+
pub inline fn nextInput(comptime self: *Session, t: Input.Type, label: []const u8) Input {
424+
comptime {
425+
defer self.i += 1;
426+
const input = self.contract[self.i];
427+
input.check(t, label);
428+
return input;
429+
}
430+
}
431+
432+
pub inline fn finish(comptime self: *Session) void {
433+
// For performance, we have certain computations (specifically in `init` functions)
434+
// which skip the last parts of transcript when they aren't needed (i.e ciphertext_ciphertext proof).
435+
//
436+
// By performing this check, we still ensure that they do those extra computations when in Debug mode,
437+
// but are allowed to skip them in a release build.
438+
if (builtin.mode == .Debug and !self.err and self.i != self.contract.len) {
439+
@compileError("contract unfulfilled");
440+
}
441+
}
442+
443+
inline fn cancel(comptime self: *Session) void {
444+
comptime self.err = true;
445+
}
446+
};
447+
448+
pub inline fn getSession(comptime contract: []const Input) Session {
449+
comptime {
450+
// contract should always end in a challenge
451+
const last_contract = contract[contract.len - 1];
452+
std.debug.assert(last_contract.type == .challenge);
453+
return .{ .i = 0, .contract = contract, .err = false };
454+
}
349455
}
350456
};
351457

352458
test "equivalence" {
353459
var transcript = Transcript.initTest("test protocol");
354460

355-
transcript.appendMessage("some label", "some data");
461+
comptime var session = Transcript.getSession(&.{
462+
.{ .label = "some label", .type = .bytes },
463+
.{ .label = "challenge", .type = .challenge },
464+
});
465+
transcript.append(&session, .bytes, "some label", "some data");
356466

357467
var bytes: [32]u8 = undefined;
358468
transcript.challengeBytes("challenge", &bytes);

0 commit comments

Comments
 (0)