Skip to content

Commit

Permalink
naive fri
Browse files Browse the repository at this point in the history
  • Loading branch information
katat committed Dec 6, 2023
1 parent baab065 commit 751b7aa
Show file tree
Hide file tree
Showing 7 changed files with 475 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
book
target
75 changes: 75 additions & 0 deletions src/fri/code/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions src/fri/code/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "demo"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
rand = "0.8.5"
20 changes: 20 additions & 0 deletions src/fri/code/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# what is this
This is an ongoing project to implement a FRI prover and verifier in rust.
The code aims to serve as a walkthrough of the FRI protocol.

# run test
`cargo test `

# plans
## Naive FRI
It is to create layers with folded polynomial and then sample the result. No commitment involved.

This is to demonstrate the natural structure of FRI, drawing the connection between the math and the idea of low degree testing.

From PCP(probabilistic checkable proof) perspective, the naive FRI can save verifying time by sampling resulted layers.

## FRI with commitment
*todo*

## Soundness calculation
*todo*
1 change: 1 addition & 0 deletions src/fri/code/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod poly;
212 changes: 212 additions & 0 deletions src/fri/code/src/poly.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use std::collections::HashMap; // For generating random numbers

#[derive(Debug, Clone)]
struct Polynomial {
coefficients: Vec<i32>,
}

impl Polynomial {
// Constructor for the Polynomial struct
pub fn new(coefficients: Vec<i32>) -> Self {
Polynomial { coefficients }
}

// Function to evaluate the polynomial at a given value of x
pub fn evaluate(&self, x: i32) -> i32 {
self.coefficients
.iter()
.enumerate()
.fold(0, |acc, (power, &coeff)| acc + coeff * x.pow(power as u32))
}

pub fn add(&self, other: &Polynomial) -> Polynomial {
let max_len = usize::max(self.coefficients.len(), other.coefficients.len());
let mut result = vec![0; max_len];

for i in 0..max_len {
let a = *self.coefficients.get(i).unwrap_or(&0);
let b = *other.coefficients.get(i).unwrap_or(&0);
result[i] = a + b;
}

Polynomial::new(result)
}
}

impl PartialEq for Polynomial {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
}
}

// Trait for displaying a polynomial
trait DisplayPolynomial {
fn format(&self) -> String;
}

// Implementing DisplayPolynomial for Polynomial
impl DisplayPolynomial for Polynomial {
fn format(&self) -> String {
let mut formatted_string = String::new();
for (i, &coeff) in self.coefficients.iter().enumerate() {
if coeff != 0 {
if i == 0 {
// First coefficient
formatted_string.push_str(&format!("{}", coeff));
} else {
// Add + sign for positive coefficients, except the first term
if formatted_string.len() > 0 && coeff > 0 {
formatted_string.push_str(" + ");
} else if coeff < 0 {
formatted_string.push_str(" - ");
}

// Add coefficient (absolute value) and variable part
let abs_coeff = coeff.abs();
if abs_coeff != 1 {
formatted_string.push_str(&format!("{}", abs_coeff));
}
formatted_string.push_str(&format!("x^{}", i));
}
}
}
formatted_string
}
}

fn fold_polynomial(poly: &Polynomial, beta: i32) -> Polynomial {
let even_coef: Vec<i32> = poly.coefficients.iter().step_by(2).cloned().collect();
let odd_coef: Vec<i32> = poly
.coefficients
.iter()
.skip(1)
.step_by(2)
.map(|&coef| coef * beta) // Multiply each odd coefficient by beta
.collect();

let even_poly = Polynomial::new(even_coef);
let odd_poly = Polynomial::new(odd_coef);

even_poly.add(&odd_poly)
}

fn recursively_fold_polynomials(poly: Polynomial, beta: i32) -> Vec<Polynomial> {
let mut folded_polynomials = Vec::new();
let mut current_poly = poly;

loop {
folded_polynomials.push(current_poly.clone());

// Check if the degree of the current polynomial is 0
if current_poly.coefficients.len() <= 1 {
break;
}

// Fold the polynomial
current_poly = fold_polynomial(&current_poly, beta);
}

folded_polynomials
}

fn create_layers(x_values: &[i32], poly_by_layer: Vec<Polynomial>) -> Vec<HashMap<i32, i32>> {
let mut layer_evals: Vec<HashMap<i32, i32>> = Vec::new();

for &x in x_values {
println!("----\nz = {}\n", x);

for (i, poly) in poly_by_layer.iter().enumerate() {
let degree = poly.coefficients.len() - 1;
let exponent = 2i32.pow(i as u32);
let elm_point = x.pow(exponent as u32);
let symmetric_elm_point = -elm_point;
let elm = poly.evaluate(elm_point);
let symmetric_elm = poly.evaluate(symmetric_elm_point);

let poly_info = format!(
"p(x) at layer {}: {:?}, degree: {}",
i,
poly.format(),
degree
);
let evaluations = if i == 0 {
format!(
"y_{} = z^{}, p(y_{}) = {}, p(-y_{}) = {}",
i, exponent, i, elm, i, symmetric_elm
)
} else {
format!(
"y_{} = y_{}^2 = z^{}, p(y_{}) = {}, p(-y_{}) = {}",
i,
i - 1,
exponent,
i,
elm,
i,
symmetric_elm
)
};

println!("{}\n{}\n", poly_info, evaluations);

if layer_evals.get(i).is_none() {
layer_evals.push(HashMap::new());
}

let point_to_eval = layer_evals.get_mut(i).unwrap();
point_to_eval.insert(elm_point, elm);
point_to_eval.insert(symmetric_elm_point, symmetric_elm);
}
}

layer_evals
}

fn check_layers(layer_evals: &[HashMap<i32, i32>], query: i32, beta: i32) {
for (i, layer) in layer_evals.iter().enumerate() {
println!("current layer: {}", i);

// Skip first layer
if i == 0 {
continue;
}

let prev_exponent = 2i32.pow((i - 1) as u32);
let prev_query_point = query.pow(prev_exponent as u32);
let prev_symmetric_query_point = -(query).pow(prev_exponent as u32);

let prev_layer = layer_evals.get(i - 1).unwrap();
let prev_elm = prev_layer.get(&prev_query_point).unwrap();
let prev_symmetric_elm = prev_layer.get(&prev_symmetric_query_point).unwrap();

let current_query_point = prev_query_point.pow(2);
println!("prev_query_point: {}", current_query_point);

let current_elm = layer.get(&current_query_point).unwrap();

let expected = (prev_elm + prev_symmetric_elm) / 2
+ beta * (prev_elm - prev_symmetric_elm) / (2 * prev_query_point);

assert_eq!(expected, *current_elm);
}

// Assert the values in the last layer are the same
let last_layer = layer_evals.last().unwrap();
let mut values = std::collections::HashSet::new();
for (_, &value) in last_layer.iter() {
values.insert(value);
}
assert_eq!(values.len(), 1);
}

#[test]
fn naive_query() {
let beta = 1;
let poly = Polynomial::new(vec![1, 2, 3]);
let poly_by_layer = recursively_fold_polynomials(poly, beta);

let layer_evals = create_layers(&[1, 2, 3], poly_by_layer);

let query = 2;
check_layers(&layer_evals, query, beta);
}
Loading

0 comments on commit 751b7aa

Please sign in to comment.