Skip to content

Commit

Permalink
add defaults and continue implementing polyvec
Browse files Browse the repository at this point in the history
  • Loading branch information
supinie committed Mar 14, 2024
1 parent 3b24e0f commit bcdde44
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 45 deletions.
15 changes: 15 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::params::K;
use core::fmt::{Display, Formatter};

#[derive(Debug, PartialEq, Eq)]
pub enum CrystalsError {
MismatchedSecurityLevels(K, K),
}

impl Display for CrystalsError {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
match *self {
Self::MismatchedSecurityLevels(sec_level_1, sec_level_2) => write!(f, "Mismatched security levels when attempting operation: {sec_level_1:#?} and {sec_level_2:#?}"),
}
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#![no_std]
#![allow(clippy::needless_range_loop)]

mod errors;
mod field_operations;
// mod indcpa;
pub mod kem;
Expand Down
10 changes: 4 additions & 6 deletions src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ pub const SHAREDSECRETBYTES: usize = 32;

pub const POLYBYTES: usize = 384;

#[derive(Clone, Copy, Debug, PartialEq, Eq, IntoPrimitive)]
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, IntoPrimitive)]
#[repr(usize)]
// Get the usize repr using .into()
pub enum K {
Two = 2,
#[default]
Three = 3,
Four = 4,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, IntoPrimitive)]
#[repr(usize)]
// Get the usize repr using .into()
pub enum Eta {
Two = 2,
Three = 3,
Expand All @@ -32,10 +34,6 @@ pub enum SecurityLevel {
TenTwoFour { k: K, eta_1: Eta, eta_2: Eta },
}

pub trait GetSecLevel {
fn sec_level() -> SecurityLevel;
}

impl SecurityLevel {
pub const fn new(k: K) -> Self {
match k {
Expand Down
18 changes: 11 additions & 7 deletions src/polynomials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@ pub struct Poly<S: State> {
}

// Normalised coefficients lie within {0..q-1}
#[derive(Default)]
pub struct Normalised;
#[derive(Default)]
pub struct Unnormalised;

pub trait State {}
pub trait State: Default {}
impl State for Normalised {}
impl State for Unnormalised {}

impl Default for Poly<Normalised> {
/// In all cases, `new()` should be used instead, else the state may be incorrect.
/// Default is defined here for `ArrayVec`.
impl<S: State> Default for Poly<S> {
fn default() -> Self {
Self {
coeffs: [0; N],
state: Normalised,
state: Default::default(),
}
}
}
Expand All @@ -38,7 +42,7 @@ impl<S: State> Poly<S> {
/// ```
/// let new_poly = poly1.add(&poly2);
/// ```
fn add(&self, x: &Self) -> Poly<Unnormalised> {
pub(crate) fn add<T: State>(&self, x: &Poly<T>) -> Poly<Unnormalised> {
let coeffs_arr: [i16; N] = self
.coeffs
.iter()
Expand All @@ -57,7 +61,7 @@ impl<S: State> Poly<S> {
/// ```
/// let new_poly = poly1.sub(&poly2);
/// ```
pub(crate) fn sub(&self, x: &Self) -> Poly<Unnormalised> {
pub(crate) fn sub<T: State>(&self, x: &Poly<T>) -> Poly<Unnormalised> {
let coeffs_arr: [i16; N] = self
.coeffs
.iter()
Expand Down Expand Up @@ -116,7 +120,7 @@ impl<S: State> Poly<S> {
/// Example:
/// ```
/// let new_poly = poly1.pointwise_mul(&poly2);
pub(crate) fn pointwise_mul(&self, x: &Self) -> Poly<Unnormalised> {
pub(crate) fn pointwise_mul<T: State>(&self, x: &Poly<T>) -> Poly<Unnormalised> {
let mut coeffs_arr = self.coeffs;
for ((chunk, x_chunk), &zeta) in coeffs_arr
.chunks_mut(4)
Expand Down Expand Up @@ -152,7 +156,7 @@ impl Poly<Unnormalised> {
/// ```
/// let new_poly = poly.normalise();
/// ```
fn normalise(&self) -> Poly<Normalised> {
pub(crate) fn normalise(&self) -> Poly<Normalised> {
let coeffs_arr: [i16; N] = self
.coeffs
.iter()
Expand Down
7 changes: 6 additions & 1 deletion src/polynomials/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Poly<Normalised> {
// ```
// new_poly = poly.inv_ntt();
// ```
pub(crate) fn inv_ntt(&self) {
pub(crate) fn inv_ntt(&self) -> Self {
let mut coeffs = self.coeffs;
let mut k: usize = 127;

Expand All @@ -97,5 +97,10 @@ impl Poly<Normalised> {
for coeff in &mut coeffs {
*coeff = montgomery_reduce(1441 * i32::from(*coeff));
}

Self {
coeffs,
state: Normalised,
}
}
}
118 changes: 87 additions & 31 deletions src/vectors.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,119 @@
// use core::num::TryFromIntError;

use crate::{
// matrix::{Mat1024, Mat512, Mat768},
params::{Eta, GetSecLevel, SecurityLevel, K, N, POLYBYTES, Q},
polynomials::{Normalised, Poly, State, Unnormalised},
errors::CrystalsError,
params::{SecurityLevel, K},
polynomials::{Normalised, Poly, State, Unnormalised}
};
use tinyvec::array_vec;
use tinyvec::ArrayVec;

struct PolyVec<S: State> {
#[derive(Default)]
pub struct PolyVec<S: State> {
polynomials: ArrayVec<[Poly<S>; 4]>,
sec_level: K,
}

impl<S: State> PolyVec<S> {
fn new(k: K) -> PolyVec<Normalised> {
let polynomials_arr = match k {
const fn sec_level(&self) -> SecurityLevel {
SecurityLevel::new(self.sec_level)
}

fn polynomials(&self) -> &[Poly<S>] {
&self.polynomials.as_slice()[..self.sec_level.into()]
}

fn add<T: State>(&self, addend: &PolyVec<T>) -> Result<PolyVec<Unnormalised>, CrystalsError> {
if self.sec_level == addend.sec_level {
let mut polynomials = ArrayVec::<[Poly<Unnormalised>; 4]>::new();
for (augend_poly, addend_poly) in self.polynomials.iter().zip(addend.polynomials.iter()) {
polynomials.push(augend_poly.add(addend_poly));
}

Ok(
PolyVec {
polynomials,
sec_level: self.sec_level,
}
)
} else {
Err(CrystalsError::MismatchedSecurityLevels(self.sec_level, addend.sec_level))
}
}

fn barrett_reduce(&self) -> PolyVec<Unnormalised> {
let mut polynomials = ArrayVec::<[Poly<Unnormalised>; 4]>::new();
for poly in self.polynomials.iter() {
polynomials.push(poly.barrett_reduce());
}

PolyVec {
polynomials,
sec_level: self.sec_level,
}
}
}

impl PolyVec<Unnormalised> {
fn normalise(&self) -> PolyVec<Normalised> {
let mut polynomials = ArrayVec::<[Poly<Normalised>; 4]>::new();
for poly in self.polynomials.iter() {
polynomials.push(poly.normalise());
}

PolyVec {
polynomials,
sec_level: self.sec_level,
}
}
}

impl PolyVec<Normalised> {
fn new(k: K) -> Self {
let polynomials = match k {
K::Two => array_vec!([Poly<Normalised>; 4] => Poly::new(), Poly::new()),
K::Three => array_vec!([Poly<Normalised>; 4] => Poly::new(), Poly::new(), Poly::new()),
K::Four => {
array_vec!([Poly<Normalised>; 4] => Poly::new(), Poly::new(), Poly::new(), Poly::new())
}
};

PolyVec {
polynomials: polynomials_arr,
Self {
polynomials,
sec_level: k,
}
}

fn ntt(&self) -> PolyVec<Unnormalised> {
let mut polynomials = ArrayVec::<[Poly<Unnormalised>; 4]>::new();
for poly in self.polynomials.iter() {
polynomials.push(poly.ntt());
}

PolyVec {
polynomials,
sec_level: self.sec_level,
}
}

fn inv_ntt(&self) -> Self {
let mut polynomials = ArrayVec::<[Poly<Normalised>; 4]>::new();
for poly in self.polynomials.iter() {
polynomials.push(poly.inv_ntt());
}

Self {
polynomials,
sec_level: self.sec_level,
}
}
}

// make a normalised impl block for safety

struct Matrix<S: State> {
vectors: ArrayVec<[PolyVec<S>; 4]>,
sec_level: K,
}

// pub type PolyVec512 = ArrayVec<[Poly; 2]>;
// pub type PolyVec768 = ArrayVec<[Poly; 3]>;
// pub type PolyVec1024 = ArrayVec<[Poly; 4]>;

// impl GetSecLevel for PolyVec512 {
// fn sec_level() -> SecurityLevel {
// SecurityLevel::new(K::Two)
// }
// }

// impl GetSecLevel for PolyVec768 {
// fn sec_level() -> SecurityLevel {
// SecurityLevel::new(K::Three)
// }
// }

// impl GetSecLevel for PolyVec1024 {
// fn sec_level() -> SecurityLevel {
// SecurityLevel::new(K::Four)
// }
// }

// trait SameSecLevel {}

Expand Down

0 comments on commit bcdde44

Please sign in to comment.