
Commit 5969915

Improve packed pcs (#39)
* wip
* wip
* wip
* lots of duplications + unoptimized, but works
* wip
* parallel
* skip one computation
* wip
* wip
* fixed
* wip
* works, now we need to optimize
* simplify
* sparse point
* parallelize inner loop
* add TODO
* wip (issue with proof size)
* wip
* malloc_vec with vector size >= 8
* better
* same proof size as in branch main
* wip
* wip
* gud
* add TODO
* gud gud
* extension commitment packed
* fix big memory inefficiency
* benchs

---------

Co-authored-by: Tom Wambsgans <[email protected]>
1 parent 0a42e69 commit 5969915

36 files changed: +1832 −1188 lines

Cargo.lock

Lines changed: 9 additions & 9 deletions
Some generated files are not rendered by default.

README.md

Lines changed: 9 additions & 0 deletions
@@ -56,6 +56,15 @@ RUSTFLAGS='-C target-cpu=native' NUM_XMSS_AGGREGATED='500' cargo test --release
 ![Alt text](docs/benchmark_graphs/graphs/xmss_aggregated_time.svg)
 ![Alt text](docs/benchmark_graphs/graphs/xmss_aggregated_overhead.svg)
 
+### Proof size
+
+With the "up to capacity" conjecture, current proofs at rate = 1/2 are roughly 400-500 KiB, of which ≈ 300 KiB comes from WHIR.
+
+- The remaining 100-200 KiB will be significantly reduced in the future (this part has not been optimized at all).
+- WHIR proof size will also be reduced, thanks to Merkle pruning (TODO).
+
+Reasonable target: 256 KiB for fast proofs, 128 KiB for slower proofs (rate = 1/4 or 1/8).
+
 ## Credits
 
 - [Plonky3](https://github.com/Plonky3/Plonky3) for its various performant crates (Finite fields, poseidon2 AIR etc)

TODO.md

Lines changed: 44 additions & 2 deletions
@@ -10,7 +10,6 @@
 - use RowMajorMatrix instead of Vec<Vec> for witness, and avoid any transpositions, as suggested by Thomas
 - Fill Precompile tables during bytecode execution
 - Use Univariate Skip to commit to tables with k.2^n rows (k small)
-- increase density of multi commitments -> we can almost reduce by 2x commitment costs (question: will perf be good enough in order to avoid using the "jagged pcs" (cf sp1 hypercube)?)
 - avoid field embedding in the initial sumcheck of logup*, when table / values are in base field
 - opti logup* GKR:
   - when the indexes are not a power of 2 (which is the case in the execution table)
@@ -23,8 +22,50 @@
 - Sumcheck, case z = 0, no need to fold, only keep first half of the values (done in PR 33 by Lambda) (and also in WHIR?)
 - Custom AVX2 / AVX512 / Neon implem in Plonky3 for all of the finite field operations (done for degree 4 extension, but not degree 5)
 - the 2 executions of the program, before generating the validity proof, can be merged, using some kind of placeholders
-- both WHIR verif + XMSS aggregation programs have 40% of unused memory!! -> TODO improve the compiler to reduce memory fragmentation
 - Many times, we evaluate different multilinear polynomials (different columns of the same table etc) at a common point. OPTI = compute the eq(.) once, and then dot_product with everything
+- To commit to multiple AIR tables using a single pcs, the most general form that our "packed pcs" api should accept is:
+  a list of n columns (n not a power of 2), each ending with m repeated values (in this manner we can reduce proof size when there are a lot of columns (poseidons ...))
+
+About the "packed pcs" (similar to SP1's Jagged PCS, slightly less efficient, but simpler (no sumchecks)):
+- The best strategy is probably to pack as much as possible (the cost of increasing the density = additional inner evaluations), as long as we can fit below a power of 2 minus epsilon (epsilon = 20% for instance, TBD). If the sum of the non-zero data is just above a power of 2, no packing technique, even the best, can help us, so we should spread anyway (to reduce the pressure of inner evaluations) (see the first sketch after this diff).
+- About those inner evaluations, there is a trick: we need to compute M1(a, b, c, d, ...), then M2(b, c, d, ...), then M3(c, d, ...). The trick = compute eq(.) for (c, d, ...), then dot product with M3. Then expand to eq(b, c, d, ...), dot product with M2. Then expand to eq(a, b, c, d, ...), dot product with M1. The idea is that, in this order, computing each eq(.) is easier if we start from the previous one (see the second sketch after this diff).
+- Currently the packed pcs works as follows:
+
+```
+┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘
+┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐
+| || || || || || || || || || || || || || |
+| || || || || || || || || || || || || || |
+└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘
+┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐┌─┐
+└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘└─┘
+```
+
+But we could reduce proof size a lot by using this layout instead (TODO):
+
+```
+┌────────────────────────┐┌──────────┐┌─┐
+|                        ||          || |
+|                        ||          || |
+|                        ||          || |
+|                        ||          || |
+|                        ||          || |
+|                        ||          || |
+└────────────────────────┘└──────────┘└─┘
+┌────────────────────────┐┌──────────┐┌─┐
+|                        ||          || |
+|                        ||          || |
+└────────────────────────┘└──────────┘└─┘
+┌────────────────────────┐┌──────────┐┌─┐
+└────────────────────────┘└──────────┘└─┘
+```
 
 ## Not Perf
 
@@ -37,6 +78,7 @@
 
 - KoalaBear extension of degree 5: the current implem (in a fork of Plonky3) has not been optimized
 - KoalaBear extension of degree 6: in order to use the (proven) Johnson bound in WHIR
+- current "packed PCS" is not optimal in the end: can lead to [16][4][2][2] (instead of [16][8])
 
 ## Known leanISA compiler bugs:
 
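First sketch: the "pack below a power of 2 minus epsilon" rule from the TODO note above, written as a minimal Rust illustration. The 20% epsilon, the helper name, and the cell-count interface are placeholders taken from that note, not the repository's packed-pcs code.

```rust
/// Minimal sketch of the packing decision described in the TODO note above.
/// `non_zero_cells` is the total amount of real (non-padding) data to commit to;
/// `epsilon` is the slack fraction (e.g. 0.2). Returns the padded size if dense
/// packing is worthwhile, or `None` if the data is just above a power of 2 and
/// we should spread instead. Illustration only, not the actual implementation.
fn dense_packing_target(non_zero_cells: usize, epsilon: f64) -> Option<usize> {
    let next_pow2 = non_zero_cells.next_power_of_two();
    if (non_zero_cells as f64) <= (1.0 - epsilon) * next_pow2 as f64 {
        Some(next_pow2) // fits below 2^k with enough slack: pack densely
    } else {
        None // just above a power of 2: no packing trick helps, spread the columns
    }
}

fn main() {
    assert_eq!(dense_packing_target(800, 0.2), Some(1024)); // 800 <= 0.8 * 1024
    assert_eq!(dense_packing_target(900, 0.2), None); //        900 >  0.8 * 1024
    println!("ok");
}
```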

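Second sketch: the incremental eq(.) expansion for the inner evaluations, written out concretely. This is a minimal sketch using `f64` in place of field elements (the real code works over the KoalaBear extension field); the convention that the first variable maps to the most significant index bit is an assumption of the sketch.

```rust
/// eq(point, x) for every x in {0,1}^k, with point[0] as the most significant index bit.
fn eq_table(point: &[f64]) -> Vec<f64> {
    let mut table = vec![1.0];
    for &r in point.iter().rev() {
        table = prepend_var(&table, r);
    }
    table
}

/// Turn the eq table for a point (p_1, ..., p_k) into the table for (v, p_1, ..., p_k).
fn prepend_var(table: &[f64], v: f64) -> Vec<f64> {
    let mut out: Vec<f64> = table.iter().map(|&e| e * (1.0 - v)).collect();
    out.extend(table.iter().map(|&e| e * v));
    out
}

/// Evaluate a multilinear polynomial (given by its evaluations on the hypercube)
/// at the point encoded by `eq`: a plain dot product.
fn eval_multilinear(evals: &[f64], eq: &[f64]) -> f64 {
    assert_eq!(evals.len(), eq.len());
    evals.iter().zip(eq).map(|(m, e)| m * e).sum()
}

fn main() {
    // M1 lives over (a, b, c), M2 over (b, c), M3 over (c).
    let (a, b, c) = (0.3, 0.7, 0.2);
    let m1: Vec<f64> = (0..8).map(|i| i as f64).collect();
    let m2: Vec<f64> = (0..4).map(|i| (i * i) as f64).collect();
    let m3: Vec<f64> = vec![5.0, -1.0];

    // Smallest suffix first, then grow the same eq table instead of rebuilding it.
    let mut eq = eq_table(&[c]);
    let v3 = eval_multilinear(&m3, &eq);
    eq = prepend_var(&eq, b); // now the eq table for (b, c)
    let v2 = eval_multilinear(&m2, &eq);
    eq = prepend_var(&eq, a); // now the eq table for (a, b, c)
    let v1 = eval_multilinear(&m1, &eq);
    println!("M3 = {v3}, M2 = {v2}, M1 = {v1}");
}
```
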
crates/compiler/src/a_simplify_lang.rs

Lines changed: 50 additions & 6 deletions
@@ -127,6 +127,7 @@ pub enum SimpleLine {
         var: Var,
         size: SimpleExpr,
         vectorized: bool,
+        vectorized_len: SimpleExpr,
     },
     ConstMalloc {
         // always not vectorized
@@ -283,6 +284,7 @@ fn simplify_lines(
                     arg1: right,
                 });
             }
+            Expression::Log2Ceil { .. } => unreachable!(),
         },
         Line::ArrayAssign {
             array,
@@ -589,9 +591,17 @@ fn simplify_lines(
             var,
             size,
             vectorized,
+            vectorized_len,
         } => {
             let simplified_size =
                 simplify_expr(size, &mut res, counters, array_manager, const_malloc);
+            let simplified_vectorized_len = simplify_expr(
+                vectorized_len,
+                &mut res,
+                counters,
+                array_manager,
+                const_malloc,
+            );
             if simplified_size.is_constant()
                 && !*vectorized
                 && const_malloc.forbidden_vars.contains(var)
@@ -619,6 +629,7 @@
                 var: var.clone(),
                 size: simplified_size,
                 vectorized: *vectorized,
+                vectorized_len: simplified_vectorized_len,
             });
         }
     }
@@ -724,6 +735,14 @@ fn simplify_expr(
             });
             SimpleExpr::Var(aux_var)
         }
+        Expression::Log2Ceil { value } => {
+            let const_value = simplify_expr(value, lines, counters, array_manager, const_malloc)
+                .as_constant()
+                .unwrap();
+            SimpleExpr::Constant(ConstExpression::Log2Ceil {
+                value: Box::new(const_value),
+            })
+        }
     }
 }
 
@@ -884,6 +903,9 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
             inline_expr(left, args, inlining_count);
             inline_expr(right, args, inlining_count);
         }
+        Expression::Log2Ceil { value } => {
+            inline_expr(value, args, inlining_count);
+        }
     }
 }
 
@@ -1036,6 +1058,9 @@ fn vars_in_expression(expr: &Expression) -> BTreeSet<Var> {
             vars.extend(vars_in_expression(left));
             vars.extend(vars_in_expression(right));
         }
+        Expression::Log2Ceil { value } => {
+            vars.extend(vars_in_expression(value));
+        }
     }
     vars
 }
@@ -1221,6 +1246,15 @@ fn replace_vars_for_unroll_in_expr(
                 internal_vars,
             );
         }
+        Expression::Log2Ceil { value } => {
+            replace_vars_for_unroll_in_expr(
+                value,
+                iterator,
+                unroll_index,
+                iterator_value,
+                internal_vars,
+            );
+        }
     }
 }
 
@@ -1434,6 +1468,7 @@ fn replace_vars_for_unroll(
             var,
             size,
             vectorized: _,
+            vectorized_len,
         } => {
             assert!(var != iterator, "Weird");
             *var = format!("@unrolled_{unroll_index}_{iterator_value}_{var}");
@@ -1444,7 +1479,13 @@
                 iterator_value,
                 internal_vars,
             );
-            // vectorized is not changed
+            replace_vars_for_unroll_in_expr(
+                vectorized_len,
+                iterator,
+                unroll_index,
+                iterator_value,
+                internal_vars,
+            );
         }
         Line::DecomposeBits { var, to_decompose } => {
             assert!(var != iterator, "Weird");
@@ -1734,6 +1775,9 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>)
             replace_vars_by_const_in_expr(left, map);
             replace_vars_by_const_in_expr(right, map);
         }
+        Expression::Log2Ceil { value } => {
+            replace_vars_by_const_in_expr(value, map);
+        }
     }
 }
 
@@ -2029,13 +2073,13 @@ impl SimpleLine {
                 var,
                 size,
                 vectorized,
+                vectorized_len,
             } => {
-                let alloc_type = if *vectorized {
-                    "malloc_vectorized"
+                if *vectorized {
+                    format!("{var} = malloc_vec({size}, {vectorized_len})")
                 } else {
-                    "malloc"
-                };
-                format!("{var} = {alloc_type}({size})")
+                    format!("{var} = malloc({size})")
+                }
             }
             Self::ConstMalloc {
                 var,

crates/compiler/src/b_compile_intermediate.rs

Lines changed: 15 additions & 1 deletion
@@ -491,9 +491,20 @@ fn compile_lines(
 
         SimpleLine::FunctionRet { return_data } => {
             if compiler.func_name == "main" {
+                // pC -> ending_pc, fp -> 0
+                let zero_value_offset = IntermediateValue::MemoryAfterFp {
+                    offset: compiler.stack_size.into(),
+                };
+                compiler.stack_size += 1;
+                instructions.push(IntermediateInstruction::Computation {
+                    operation: Operation::Add,
+                    arg_a: IntermediateValue::Constant(0.into()),
+                    arg_c: IntermediateValue::Constant(0.into()),
+                    res: zero_value_offset.clone(),
+                });
                 instructions.push(IntermediateInstruction::Jump {
                     dest: IntermediateValue::label("@end_program".to_string()),
-                    updated_fp: None,
+                    updated_fp: Some(zero_value_offset),
                 });
             } else {
                 compile_function_ret(&mut instructions, return_data, compiler);
@@ -504,12 +515,14 @@
             var,
             size,
             vectorized,
+            vectorized_len,
         } => {
             declared_vars.insert(var.clone());
             instructions.push(IntermediateInstruction::RequestMemory {
                 offset: compiler.get_offset(&var.clone().into()),
                 size: IntermediateValue::from_simple_expr(size, compiler),
                 vectorized: *vectorized,
+                vectorized_len: IntermediateValue::from_simple_expr(vectorized_len, compiler),
             });
         }
         SimpleLine::ConstMalloc { var, size, label } => {
@@ -631,6 +644,7 @@ fn setup_function_call(
             offset: new_fp_pos.into(),
             size: ConstExpression::function_size(func_name.to_string()).into(),
             vectorized: false,
+            vectorized_len: IntermediateValue::Constant(ConstExpression::zero()),
         },
         IntermediateInstruction::Deref {
             shift_0: new_fp_pos.into(),

crates/compiler/src/c_compile_final.rs

Lines changed: 5 additions & 0 deletions
@@ -303,12 +303,17 @@ fn compile_block(
             offset,
             size,
             vectorized,
+            vectorized_len,
         } => {
             let size = try_as_mem_or_constant(&size).unwrap();
+            let vectorized_len = try_as_constant(&vectorized_len, compiler)
+                .unwrap()
+                .to_usize();
             let hint = Hint::RequestMemory {
                 offset: eval_const_expression_usize(&offset, compiler),
                 vectorized,
                 size,
+                vectorized_len,
             };
             hints.entry(pc).or_default().push(hint);
         }

crates/compiler/src/grammar.pest

Lines changed: 2 additions & 0 deletions
@@ -72,9 +72,11 @@ div_expr = { exp_expr ~ ("/" ~ exp_expr)* }
 exp_expr = { primary ~ ("**" ~ primary)* }
 primary = {
     "(" ~ expression ~ ")" |
+    log2_ceil_expr |
     array_access_expr |
     var_or_constant
 }
+log2_ceil_expr = { "log2_ceil" ~ "(" ~ expression ~ ")" }
 array_access_expr = { identifier ~ "[" ~ expression ~ "]" }
 
 // Basic elements

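For context, `log2_ceil(...)` is only accepted here when its argument folds to a compile-time constant (see the `.as_constant().unwrap()` in the simplify pass diff above), and the result becomes a `ConstExpression::Log2Ceil`. A minimal Rust sketch of the arithmetic it is presumably expected to perform; the behaviour for 0 is an assumption of the sketch:

```rust
/// Ceil of log2, as a plain sketch of what the folded ConstExpression::Log2Ceil
/// presumably computes. Returning 0 for inputs 0 and 1 is an assumption.
fn log2_ceil(value: u64) -> u32 {
    if value <= 1 {
        0
    } else {
        64 - (value - 1).leading_zeros()
    }
}

fn main() {
    for v in [1u64, 2, 3, 8, 9, 1000] {
        println!("log2_ceil({v}) = {}", log2_ceil(v)); // 0, 1, 2, 3, 4, 10
    }
}
```
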
crates/compiler/src/intermediate_bytecode.rs

Lines changed: 13 additions & 8 deletions
@@ -134,7 +134,8 @@ pub enum IntermediateInstruction {
     RequestMemory {
         offset: ConstExpression, // m[fp + offset] where the hint will be stored
         size: IntermediateValue, // the hint
-        vectorized: bool, // if true, will be 8-alligned, and the returned pointer will be "divied" by 8 (i.e. everything is in chunks of 8 field elements)
+        vectorized: bool, // if true, will be (2^vectorized_len)-aligned, and the returned pointer will be "divided" by 2^vectorized_len
+        vectorized_len: IntermediateValue,
     },
     DecomposeBits {
         res_offset: usize, // m[fp + res_offset..fp + res_offset + 31 * len(to_decompose)] will contain the decomposed bits
@@ -297,13 +298,17 @@ impl Display for IntermediateInstruction {
                 offset,
                 size,
                 vectorized,
-            } => write!(
-                f,
-                "m[fp + {}] = {}({})",
-                offset,
-                if *vectorized { "malloc_vec" } else { "malloc" },
-                size
-            ),
+                vectorized_len,
+            } => {
+                if *vectorized {
+                    write!(
+                        f,
+                        "m[fp + {offset}] = request_memory_vec({size}, {vectorized_len})"
+                    )
+                } else {
+                    write!(f, "m[fp + {offset}] = request_memory({size})")
+                }
+            }
             Self::Print { line_info, content } => {
                 write!(f, "print {line_info}: ")?;
                 for (i, c) in content.iter().enumerate() {

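A minimal sketch of what the new `RequestMemory` semantics look like on the side that serves the hint, following the comment in the diff above: the allocation is 2^vectorized_len-aligned and the stored pointer is expressed in chunks of 2^vectorized_len field elements. The bump allocator, and the assumption that `size` counts chunks rather than single cells, are illustrative only, not the repository's runtime.

```rust
/// Illustrative bump allocator for the RequestMemory hint (not the real runtime).
struct Memory {
    cursor: usize, // index of the next free memory cell
}

impl Memory {
    /// Serve one memory request. For vectorized requests the pointer is
    /// 2^vectorized_len-aligned and is returned "divided" by 2^vectorized_len,
    /// i.e. counted in chunks of 2^vectorized_len field elements.
    fn request(&mut self, size: usize, vectorized: bool, vectorized_len: u32) -> usize {
        if vectorized {
            let chunk = 1usize << vectorized_len;
            // Align the free cursor up to a chunk boundary (chunk is a power of 2).
            self.cursor = (self.cursor + chunk - 1) & !(chunk - 1);
            let ptr_in_chunks = self.cursor / chunk;
            self.cursor += size * chunk; // assumption: `size` is a number of chunks
            ptr_in_chunks
        } else {
            let ptr = self.cursor;
            self.cursor += size;
            ptr
        }
    }
}

fn main() {
    let mut mem = Memory { cursor: 3 };
    // 4 chunks of 2^3 = 8 cells: the cursor is first aligned to 8, and the
    // returned pointer is 8 / 8 = 1 (in units of 8-cell chunks).
    assert_eq!(mem.request(4, true, 3), 1);
    assert_eq!(mem.cursor, 8 + 4 * 8);
    println!("ok");
}
```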