Skip to content

Commit

Permalink
feat: update memory chiplet to be element-addressable
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Dec 12, 2024
1 parent b7ff6c8 commit 74efdb5
Show file tree
Hide file tree
Showing 28 changed files with 1,252 additions and 770 deletions.
6 changes: 3 additions & 3 deletions air/src/constraints/chiplets/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use winter_air::TransitionConstraintDegree;
use super::{EvaluationFrame, FieldElement};
use crate::{
trace::chiplets::{
memory::NUM_ELEMENTS, MEMORY_ADDR_COL_IDX, MEMORY_CLK_COL_IDX, MEMORY_CTX_COL_IDX,
memory::NUM_ELEMENTS_IN_BATCH, MEMORY_ADDR_COL_IDX, MEMORY_CLK_COL_IDX, MEMORY_CTX_COL_IDX,
MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX, MEMORY_D_INV_COL_IDX, MEMORY_TRACE_OFFSET,
MEMORY_V_COL_RANGE,
},
Expand Down Expand Up @@ -152,13 +152,13 @@ fn enforce_values<E: FieldElement>(
let mut index = 0;

// initialize memory to zero when reading from new context and address pair.
for i in 0..NUM_ELEMENTS {
for i in 0..NUM_ELEMENTS_IN_BATCH {
result[index] = memory_flag * frame.init_read_flag() * frame.v(i);
index += 1;
}

// copy previous values when reading memory that was previously accessed.
for i in 0..NUM_ELEMENTS {
for i in 0..NUM_ELEMENTS_IN_BATCH {
result[index] = memory_flag * frame.copy_read_flag() * (frame.v_next(i) - frame.v(i));
index += 1;
}
Expand Down
13 changes: 7 additions & 6 deletions air/src/constraints/chiplets/memory/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use rand_utils::rand_value;

use super::{
EvaluationFrame, MEMORY_ADDR_COL_IDX, MEMORY_CLK_COL_IDX, MEMORY_CTX_COL_IDX,
MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX, MEMORY_D_INV_COL_IDX, MEMORY_V_COL_RANGE, NUM_ELEMENTS,
MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX, MEMORY_D_INV_COL_IDX, MEMORY_V_COL_RANGE,
NUM_ELEMENTS_IN_BATCH,
};
use crate::{
chiplets::memory,
trace::{
chiplets::{
memory::{Selectors, MEMORY_COPY_READ, MEMORY_INIT_READ, MEMORY_WRITE},
memory::{Selectors, MEMORY_COPY_READ, MEMORY_INIT_READ, MEMORY_WRITE_SELECTOR},
MEMORY_TRACE_OFFSET,
},
TRACE_WIDTH,
Expand All @@ -30,7 +31,7 @@ fn test_memory_write() {

// Write to a new context.
let result = get_constraint_evaluation(
MEMORY_WRITE,
MEMORY_WRITE_SELECTOR,
MemoryTestDeltaType::Context,
&old_values,
&new_values,
Expand All @@ -39,7 +40,7 @@ fn test_memory_write() {

// Write to a new address in the same context.
let result = get_constraint_evaluation(
MEMORY_WRITE,
MEMORY_WRITE_SELECTOR,
MemoryTestDeltaType::Address,
&old_values,
&new_values,
Expand All @@ -48,7 +49,7 @@ fn test_memory_write() {

// Write to the same context and address at a new clock cycle.
let result = get_constraint_evaluation(
MEMORY_WRITE,
MEMORY_WRITE_SELECTOR,
MemoryTestDeltaType::Clock,
&old_values,
&new_values,
Expand Down Expand Up @@ -160,7 +161,7 @@ fn get_test_frame(
next[MEMORY_CLK_COL_IDX] = Felt::new(delta_row[2]);

// Set the old and new values.
for idx in 0..NUM_ELEMENTS {
for idx in 0..NUM_ELEMENTS_IN_BATCH {
let old_value = Felt::new(old_values[idx] as u64);
// Add a write for the old values to the current row.
current[MEMORY_V_COL_RANGE.start + idx] = old_value;
Expand Down
41 changes: 32 additions & 9 deletions air/src/trace/chiplets/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use super::{create_range, Felt, Range, ONE, ZERO};
// ================================================================================================

/// Number of columns needed to record an execution trace of the memory chiplet.
pub const TRACE_WIDTH: usize = 12;
pub const TRACE_WIDTH: usize = 15;

// TODO(plafer): get rid of all "selector" constants
/// Number of selector columns in the trace.
pub const NUM_SELECTORS: usize = 2;

Expand All @@ -15,16 +16,27 @@ pub const NUM_SELECTORS: usize = 2;
/// read / write) is to be applied at a specific row of the memory execution trace.
pub type Selectors = [Felt; NUM_SELECTORS];

// --- OPERATION SELECTORS ------------------------------------------------------------------------

/// Specifies an operation that initializes new memory and then reads it.
pub const MEMORY_INIT_READ: Selectors = [ONE, ZERO];

/// Specifies an operation that copies existing memory and then reads it.
pub const MEMORY_COPY_READ: Selectors = [ONE, ONE];

/// Specifies a memory write operation.
pub const MEMORY_WRITE: Selectors = [ZERO, ZERO];
pub const MEMORY_WRITE_SELECTOR: Selectors = [ZERO, ZERO];

// --- OPERATION SELECTORS ------------------------------------------------------------------------

/// Specifies the value of the `READ_WRITE` column when the operation is a write.
pub const MEMORY_WRITE: Felt = ZERO;
/// Specifies the value of the `READ_WRITE` column when the operation is a read.
pub const MEMORY_READ: Felt = ONE;
/// Specifies the value of the `ELEMENT_OR_WORD` column when the operation is over an element.
pub const MEMORY_ACCESS_ELEMENT: Felt = ZERO;
/// Specifies the value of the `ELEMENT_OR_WORD` column when the operation is over a word.
pub const MEMORY_ACCESS_WORD: Felt = ONE;

// TODO(plafer): figure out the new labels

/// Unique label computed as 1 plus the full chiplet selector with the bits reversed.
/// mem_read selector=[1, 1, 0, 1], rev(selector)=[1, 0, 1, 1], +1=[1, 1, 0, 0]
Expand All @@ -37,17 +49,25 @@ pub const MEMORY_WRITE_LABEL: u8 = 0b0100;
// --- COLUMN ACCESSOR INDICES WITHIN THE CHIPLET -------------------------------------------------

/// The number of elements accessible in one read or write memory access.
pub const NUM_ELEMENTS: usize = 4;
pub const NUM_ELEMENTS_IN_BATCH: usize = 4;

/// Column to hold the whether the operation is a read or write.
pub const READ_WRITE_COL_IDX: usize = 0;
/// Column to hold the whether the operation was over an element or a word.
pub const ELEMENT_OR_WORD_COL_IDX: usize = READ_WRITE_COL_IDX + 1;
/// Column to hold the context ID of the current memory context.
pub const CTX_COL_IDX: usize = NUM_SELECTORS;
pub const CTX_COL_IDX: usize = ELEMENT_OR_WORD_COL_IDX + 1;
/// Column to hold the memory address.
pub const ADDR_COL_IDX: usize = CTX_COL_IDX + 1;
pub const BATCH_COL_IDX: usize = CTX_COL_IDX + 1;
/// Column to hold the first bit of the index of the address in the batch.
pub const IDX0_COL_IDX: usize = BATCH_COL_IDX + 1;
/// Column to hold the second bit of the index of the address in the batch.
pub const IDX1_COL_IDX: usize = IDX0_COL_IDX + 1;
/// Column for the clock cycle in which the memory operation occurred.
pub const CLK_COL_IDX: usize = ADDR_COL_IDX + 1;
pub const CLK_COL_IDX: usize = IDX1_COL_IDX + 1;
/// Columns to hold the values stored at a given memory context, address, and clock cycle after
/// the memory operation. When reading from a new address, these are initialized to zero.
pub const V_COL_RANGE: Range<usize> = create_range(CLK_COL_IDX + 1, NUM_ELEMENTS);
pub const V_COL_RANGE: Range<usize> = create_range(CLK_COL_IDX + 1, NUM_ELEMENTS_IN_BATCH);
/// Column for the lower 16-bits of the delta between two consecutive context IDs, addresses, or
/// clock cycles.
pub const D0_COL_IDX: usize = V_COL_RANGE.end;
Expand All @@ -57,3 +77,6 @@ pub const D1_COL_IDX: usize = D0_COL_IDX + 1;
/// Column for the inverse of the delta between two consecutive context IDs, addresses, or clock
/// cycles, used to enforce that changes are correctly constrained.
pub const D_INV_COL_IDX: usize = D1_COL_IDX + 1;
/// Column to hold the flag indicating whether the current memory operation is in the same batch and
/// same context as the previous operation.
pub const FLAG_SAME_BATCH_AND_CONTEXT: usize = D_INV_COL_IDX + 1;
2 changes: 1 addition & 1 deletion air/src/trace/chiplets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub const MEMORY_SELECTORS_COL_IDX: usize = MEMORY_TRACE_OFFSET;
/// The index within the main trace of the column containing the memory context.
pub const MEMORY_CTX_COL_IDX: usize = MEMORY_TRACE_OFFSET + memory::CTX_COL_IDX;
/// The index within the main trace of the column containing the memory address.
pub const MEMORY_ADDR_COL_IDX: usize = MEMORY_TRACE_OFFSET + memory::ADDR_COL_IDX;
pub const MEMORY_ADDR_COL_IDX: usize = MEMORY_TRACE_OFFSET + memory::BATCH_COL_IDX;
/// The index within the main trace of the column containing the clock cycle of the memory
/// access.
pub const MEMORY_CLK_COL_IDX: usize = MEMORY_TRACE_OFFSET + memory::CLK_COL_IDX;
Expand Down
2 changes: 1 addition & 1 deletion air/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub const RANGE_CHECK_TRACE_RANGE: Range<usize> =

// Chiplets trace
pub const CHIPLETS_OFFSET: usize = RANGE_CHECK_TRACE_RANGE.end;
pub const CHIPLETS_WIDTH: usize = 17;
pub const CHIPLETS_WIDTH: usize = 18;
pub const CHIPLETS_RANGE: Range<usize> = range(CHIPLETS_OFFSET, CHIPLETS_WIDTH);

pub const TRACE_WIDTH: usize = CHIPLETS_OFFSET + CHIPLETS_WIDTH;
Expand Down
9 changes: 5 additions & 4 deletions assembly/src/assembler/instruction/mem_ops.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::string::ToString;

use vm_core::{Felt, Operation::*};
use vm_core::{Felt, Operation::*, WORD_SIZE};

use super::{push_felt, push_u32_value, validate_param, BasicBlockBuilder};
use crate::{assembler::ProcedureContext, diagnostics::Report, AssemblyError};
Expand Down Expand Up @@ -111,7 +111,7 @@ pub fn mem_write_imm(
/// Returns an error if index is greater than the number of procedure locals.
pub fn local_to_absolute_addr(
block_builder: &mut BasicBlockBuilder,
index: u16,
index_of_local: u16,
num_proc_locals: u16,
) -> Result<(), AssemblyError> {
if num_proc_locals == 0 {
Expand All @@ -125,9 +125,10 @@ pub fn local_to_absolute_addr(
}

let max = num_proc_locals - 1;
validate_param(index, 0..=max)?;
validate_param(index_of_local, 0..=max)?;

push_felt(block_builder, -Felt::from(max - index));
let fmp_offset_of_local = (max - index_of_local) * WORD_SIZE as u16;
push_felt(block_builder, -Felt::from(fmp_offset_of_local));
block_builder.push_op(FmpAdd);

Ok(())
Expand Down
14 changes: 8 additions & 6 deletions assembly/src/assembler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vm_core::{
crypto::hash::RpoDigest,
debuginfo::SourceSpan,
mast::{DecoratorId, MastNodeId},
DecoratorList, Felt, Kernel, Operation, Program,
DecoratorList, Felt, Kernel, Operation, Program, WORD_SIZE,
};

use crate::{
Expand Down Expand Up @@ -574,12 +574,14 @@ impl Assembler {
let proc_body_id = if num_locals > 0 {
// for procedures with locals, we need to update fmp register before and after the
// procedure body is executed. specifically:
// - to allocate procedure locals we need to increment fmp by the number of locals
// - to deallocate procedure locals we need to decrement it by the same amount
let num_locals = Felt::from(num_locals);
// - to allocate procedure locals we need to increment fmp by 4 times the number of
// locals
// - to deallocate procedure locals we need to decrement it by the same amount We leave
// 4 elements between locals to properly support reading and writing words to locals.
let locals_frame = Felt::from(num_locals * WORD_SIZE as u16);
let wrapper = BodyWrapper {
prologue: vec![Operation::Push(num_locals), Operation::FmpUpdate],
epilogue: vec![Operation::Push(-num_locals), Operation::FmpUpdate],
prologue: vec![Operation::Push(locals_frame), Operation::FmpUpdate],
epilogue: vec![Operation::Push(-locals_frame), Operation::FmpUpdate],
};
self.compile_body(proc.iter(), &mut proc_ctx, Some(wrapper), mast_forest_builder)?
} else {
Expand Down
11 changes: 6 additions & 5 deletions assembly/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1800,18 +1800,19 @@ fn program_with_proc_locals() -> TestResult {
mul \
end \
begin \
push.4 push.3 push.2 \
push.10 push.9 push.8 \
exec.foo \
end"
);
let program = context.assemble(source)?;
// Note: 18446744069414584317 == -4 (mod 2^64 - 2^32 + 1)
let expected = "\
begin
basic_block
push(10)
push(9)
push(8)
push(4)
push(3)
push(2)
push(1)
fmpupdate
pad
fmpadd
Expand All @@ -1822,7 +1823,7 @@ begin
fmpadd
mload
mul
push(18446744069414584320)
push(18446744069414584317)
fmpupdate
end
end";
Expand Down
9 changes: 4 additions & 5 deletions miden/src/cli/debug/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl DebugExecutor {

/// print all memory entries.
pub fn print_memory(&self) {
for (address, mem) in self.vm_state.memory.iter() {
for &(address, mem) in self.vm_state.memory.iter() {
Self::print_memory_data(address, mem)
}
}
Expand All @@ -167,7 +167,7 @@ impl DebugExecutor {
});

match entry {
Some(mem) => Self::print_memory_data(&address, mem),
Some(&mem) => Self::print_memory_data(address, mem),
None => println!("memory at address '{address}' not found"),
}
}
Expand All @@ -176,9 +176,8 @@ impl DebugExecutor {
// --------------------------------------------------------------------------------------------

/// print memory data.
fn print_memory_data(address: &u64, memory: &[Felt]) {
let mem_int = memory.iter().map(|&x| x.as_int()).collect::<Vec<_>>();
println!("{address} {mem_int:?}");
fn print_memory_data(address: u64, mem_value: Felt) {
println!("{address} {mem_value:?}");
}

/// print help message
Expand Down
21 changes: 10 additions & 11 deletions miden/src/repl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::BTreeSet, path::PathBuf};

use assembly::{Assembler, Library};
use miden_vm::{math::Felt, DefaultHost, StackInputs, Word};
use miden_vm::{math::Felt, DefaultHost, StackInputs};
use processor::ContextId;
use rustyline::{error::ReadlineError, DefaultEditor};
use stdlib::StdLibrary;
Expand Down Expand Up @@ -171,7 +171,7 @@ pub fn start_repl(library_paths: &Vec<PathBuf>, use_stdlib: bool) {
let mut should_print_stack = false;

// state of the entire memory at the latest clock cycle.
let mut memory: Vec<(u64, Word)> = Vec::new();
let mut memory: Vec<(u64, Felt)> = Vec::new();

// initializing readline.
let mut rl = DefaultEditor::new().expect("Readline couldn't be initialized");
Expand Down Expand Up @@ -224,9 +224,9 @@ pub fn start_repl(library_paths: &Vec<PathBuf>, use_stdlib: bool) {
println!("The memory has not been initialized yet");
continue;
}
for (addr, mem) in &memory {
for &(addr, mem) in &memory {
// prints out the address and memory value at that address.
print_mem_address(*addr, mem);
print_mem_address(addr, mem);
}
} else if line.len() > 6 && &line[..5] == "!mem[" {
// if user wants to see the state of a particular address in a memory, the input
Expand All @@ -238,8 +238,8 @@ pub fn start_repl(library_paths: &Vec<PathBuf>, use_stdlib: bool) {
// extracts the address from user input.
match read_mem_address(&line) {
Ok(addr) => {
for (i, memory_value) in &memory {
if *i == addr {
for &(i, memory_value) in &memory {
if i == addr {
// prints the address and memory value at that address.
print_mem_address(addr, memory_value);
// sets the flag to true as the address has been initialized.
Expand Down Expand Up @@ -305,7 +305,7 @@ pub fn start_repl(library_paths: &Vec<PathBuf>, use_stdlib: bool) {
fn execute(
program: String,
provided_libraries: &[Library],
) -> Result<(Vec<(u64, Word)>, Vec<Felt>), String> {
) -> Result<(Vec<(u64, Felt)>, Vec<Felt>), String> {
// compile program
let mut assembler = Assembler::default();

Expand All @@ -329,7 +329,7 @@ fn execute(
}

// loads the memory at the latest clock cycle.
let mem_state = chiplets.get_mem_state_at(ContextId::root(), system.clk());
let mem_state = chiplets.memory().get_state_at(ContextId::root(), system.clk());
// loads the stack along with the overflow values at the latest clock cycle.
let stack_state = stack.get_state_at(system.clk());

Expand Down Expand Up @@ -404,7 +404,6 @@ fn print_stack(stack: Vec<Felt>) {

/// Accepts and returns a memory at an address by converting its register into integer
/// from Felt.
fn print_mem_address(addr: u64, mem: &Word) {
let mem_int = mem.iter().map(|&x| x.as_int()).collect::<Vec<_>>();
println!("{} {:?}", addr, mem_int)
fn print_mem_address(addr: u64, mem_value: Felt) {
println!("{addr} {mem_value}")
}
Loading

0 comments on commit 74efdb5

Please sign in to comment.