Skip to content

Commit a4d144a

Browse files
committed
transposed_par_iter_mut
1 parent 8624d29 commit a4d144a

File tree

2 files changed

+30
-39
lines changed

2 files changed

+30
-39
lines changed

crates/lean_prover/witness_generation/src/execution_trace.rs

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1-
use std::mem::transmute;
1+
use std::array;
22

33
use crate::instruction_encoder::field_representation;
44
use crate::{
55
COL_INDEX_FP, COL_INDEX_MEM_ADDRESS_A, COL_INDEX_MEM_ADDRESS_B, COL_INDEX_MEM_ADDRESS_C,
66
COL_INDEX_MEM_VALUE_A, COL_INDEX_MEM_VALUE_B, COL_INDEX_MEM_VALUE_C, COL_INDEX_PC,
7-
N_EXEC_COLUMNS, N_INSTRUCTION_COLUMNS,
7+
N_TOTAL_COLUMNS,
88
};
99
use lean_vm::*;
1010
use p3_field::Field;
1111
use p3_field::PrimeCharacteristicRing;
1212
use rayon::prelude::*;
13-
use utils::{SyncUnsafeCell, ToUsize};
13+
use utils::{ToUsize, transposed_par_iter_mut};
1414

1515
#[derive(Debug)]
1616
pub struct ExecutionTrace {
17-
pub full_trace: Vec<Vec<F>>,
17+
pub full_trace: [Vec<F>; N_TOTAL_COLUMNS],
1818
pub n_poseidons_16: usize,
1919
pub n_poseidons_24: usize,
2020
pub poseidons_16: Vec<WitnessPoseidon16>, // padded with empty poseidons
@@ -34,16 +34,13 @@ pub fn get_execution_trace(
3434
let n_cycles = execution_result.pcs.len();
3535
let memory = &execution_result.memory;
3636
let log_n_cycles_rounded_up = n_cycles.next_power_of_two().ilog2() as usize;
37-
let trace = (0..N_INSTRUCTION_COLUMNS + N_EXEC_COLUMNS)
38-
.map(|_| unsafe { transmute(F::zero_vec(1 << log_n_cycles_rounded_up)) })
39-
.collect::<Vec<Vec<SyncUnsafeCell<F>>>>();
37+
let mut trace: [Vec<F>; N_TOTAL_COLUMNS] =
38+
array::from_fn(|_| F::zero_vec(1 << log_n_cycles_rounded_up));
4039

41-
execution_result
42-
.pcs
43-
.par_iter()
40+
transposed_par_iter_mut(&mut trace)
41+
.zip(execution_result.pcs.par_iter())
4442
.zip(execution_result.fps.par_iter())
45-
.enumerate()
46-
.for_each(|(cycle, (&pc, &fp))| {
43+
.for_each(|((trace_row, &pc), &fp)| {
4744
let instruction = &bytecode.instructions[pc];
4845
let field_repr = field_representation(instruction);
4946

@@ -71,23 +68,19 @@ pub fn get_execution_trace(
7168
}
7269
let value_c = memory.0[addr_c.to_usize()].unwrap();
7370

74-
unsafe {
75-
for (j, field) in field_repr.iter().enumerate() {
76-
*trace[j][cycle].get() = *field;
77-
}
78-
*trace[COL_INDEX_MEM_VALUE_A][cycle].get() = value_a;
79-
*trace[COL_INDEX_MEM_VALUE_B][cycle].get() = value_b;
80-
*trace[COL_INDEX_MEM_VALUE_C][cycle].get() = value_c;
81-
*trace[COL_INDEX_PC][cycle].get() = F::from_usize(pc);
82-
*trace[COL_INDEX_FP][cycle].get() = F::from_usize(fp);
83-
*trace[COL_INDEX_MEM_ADDRESS_A][cycle].get() = addr_a;
84-
*trace[COL_INDEX_MEM_ADDRESS_B][cycle].get() = addr_b;
85-
*trace[COL_INDEX_MEM_ADDRESS_C][cycle].get() = addr_c;
71+
for (j, field) in field_repr.iter().enumerate() {
72+
*trace_row[j] = *field;
8673
}
74+
*trace_row[COL_INDEX_MEM_VALUE_A] = value_a;
75+
*trace_row[COL_INDEX_MEM_VALUE_B] = value_b;
76+
*trace_row[COL_INDEX_MEM_VALUE_C] = value_c;
77+
*trace_row[COL_INDEX_PC] = F::from_usize(pc);
78+
*trace_row[COL_INDEX_FP] = F::from_usize(fp);
79+
*trace_row[COL_INDEX_MEM_ADDRESS_A] = addr_a;
80+
*trace_row[COL_INDEX_MEM_ADDRESS_B] = addr_b;
81+
*trace_row[COL_INDEX_MEM_ADDRESS_C] = addr_c;
8782
});
8883

89-
let mut trace: Vec<Vec<F>> = unsafe { transmute(trace) };
90-
9184
// repeat the last row to get to a power of two
9285
trace.par_iter_mut().for_each(|column| {
9386
let last_value = column[n_cycles - 1];

crates/utils/src/misc.rs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::cell::UnsafeCell;
2-
31
use p3_field::{BasedVectorSpace, ExtensionField, Field, dot_product};
42
use rayon::prelude::*;
53

@@ -135,17 +133,17 @@ pub fn transpose<F: Copy + Send + Sync>(
135133
res
136134
}
137135

138-
#[derive(Debug)]
139-
pub struct SyncUnsafeCell<T>(UnsafeCell<T>);
140-
141-
unsafe impl<T> Sync for SyncUnsafeCell<T> {}
136+
struct SendPtr<T>(*mut T);
137+
unsafe impl<T> Send for SendPtr<T> {}
138+
unsafe impl<T> Sync for SendPtr<T> {}
142139

143-
impl<T> SyncUnsafeCell<T> {
144-
pub fn new(value: T) -> Self {
145-
Self(UnsafeCell::new(value))
146-
}
140+
pub fn transposed_par_iter_mut<A: Send + Sync, const N: usize>(
141+
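array: &mut [Vec<A>; N], // all vectors must have the same length: only array[0].len() is read, and every vector is then accessed through raw pointers at each index below that length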
142+
) -> impl IndexedParallelIterator<Item = [&mut A; N]> + '_ {
143+
let len = array[0].len();
144+
let data_ptrs: [SendPtr<A>; N] = array.each_mut().map(|v| SendPtr(v.as_mut_ptr()));
147145

148-
pub fn get(&self) -> *mut T {
149-
self.0.get()
150-
}
146+
(0..len)
147+
.into_par_iter()
148+
.map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].0.add(i)) })
151149
}

0 commit comments

Comments
 (0)