Skip to content

Commit

Permalink
Shards: Use shard_len_64 instead of shard_bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersTrier committed Aug 17, 2024
1 parent 566fde7 commit 772f193
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 61 deletions.
103 changes: 44 additions & 59 deletions src/engine/shards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@ use std::ops::{Bound, Index, IndexMut, RangeBounds};

pub(crate) struct Shards {
shard_count: usize,
shard_bytes: usize,
// Shard length in 64 byte chunks
shard_len_64: usize,

// Flat array of `shard_count * shard_bytes` bytes.
// Flat Vec of `shard_count * shard_len_64 * 64` bytes.
data: Vec<[u8; 64]>,
}

impl Shards {
pub(crate) fn as_ref_mut(&mut self) -> ShardsRefMut {
ShardsRefMut::new(self.shard_count, self.shard_bytes, self.data.as_mut())
ShardsRefMut::new(self.shard_count, self.shard_len_64, self.data.as_mut())
}

pub(crate) fn new() -> Self {
Self {
shard_count: 0,
shard_bytes: 0,
shard_len_64: 0,
data: Vec::new(),
}
}

pub(crate) fn resize(&mut self, shard_count: usize, shard_bytes: usize) {
assert!(shard_bytes > 0 && shard_bytes & 63 == 0);

pub(crate) fn resize(&mut self, shard_count: usize, shard_len_64: usize) {
self.shard_count = shard_count;
self.shard_bytes = shard_bytes;
self.shard_len_64 = shard_len_64;

self.data.resize(shard_count * (shard_bytes / 64), [0; 64]);
self.data
.resize(self.shard_count * self.shard_len_64, [0; 64]);
}
}

Expand All @@ -40,8 +40,7 @@ impl Shards {
impl Index<usize> for Shards {
type Output = [[u8; 64]];
fn index(&self, index: usize) -> &Self::Output {
let shard_chunk_count = self.shard_bytes / 64;
&self.data[index * shard_chunk_count..(index + 1) * shard_chunk_count]
&self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
}
}

Expand All @@ -50,20 +49,18 @@ impl Index<usize> for Shards {

impl IndexMut<usize> for Shards {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
let shard_chunk_count = self.shard_bytes / 64;
&mut self.data[index * shard_chunk_count..(index + 1) * shard_chunk_count]
&mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
}
}

// ======================================================================
// ShardsRefMut - PUBLIC

/// Mutable reference to shard array implemented as flat byte array.
/// Mutable reference to a shard array.
pub struct ShardsRefMut<'a> {
shard_count: usize,
shard_bytes: usize,
shard_len_64: usize,

// Flat array of `shard_count * shard_bytes` bytes.
data: &'a mut [[u8; 64]],
}

Expand All @@ -82,13 +79,11 @@ impl<'a> ShardsRefMut<'a> {
mut pos: usize,
mut dist: usize,
) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
let shard_chunk_count = self.shard_bytes / 64;

pos *= shard_chunk_count;
dist *= shard_chunk_count;
pos *= self.shard_len_64;
dist *= self.shard_len_64;

let (a, b) = self.data[pos..].split_at_mut(dist);
(&mut a[..shard_chunk_count], &mut b[..shard_chunk_count])
(&mut a[..self.shard_len_64], &mut b[..self.shard_len_64])
}

/// Returns mutable references to shards at
Expand All @@ -113,20 +108,18 @@ impl<'a> ShardsRefMut<'a> {
&mut [[u8; 64]],
&mut [[u8; 64]],
) {
let shard_chunk_count = self.shard_bytes / 64;

pos *= shard_chunk_count;
dist *= shard_chunk_count;
pos *= self.shard_len_64;
dist *= self.shard_len_64;

let (ab, cd) = self.data[pos..].split_at_mut(dist * 2);
let (a, b) = ab.split_at_mut(dist);
let (c, d) = cd.split_at_mut(dist);

(
&mut a[..shard_chunk_count],
&mut b[..shard_chunk_count],
&mut c[..shard_chunk_count],
&mut d[..shard_chunk_count],
&mut a[..self.shard_len_64],
&mut b[..self.shard_len_64],
&mut c[..self.shard_len_64],
&mut d[..self.shard_len_64],
)
}

Expand All @@ -144,42 +137,40 @@ impl<'a> ShardsRefMut<'a> {
///
/// # Panics
///
/// If `data` is smaller than `shard_count * shard_bytes` bytes.
pub fn new(shard_count: usize, shard_bytes: usize, data: &'a mut [[u8; 64]]) -> Self {
/// If `data.len() < shard_count * shard_len_64`.
pub fn new(shard_count: usize, shard_len_64: usize, data: &'a mut [[u8; 64]]) -> Self {
assert!(data.len() >= shard_count * shard_len_64);

Self {
shard_count,
shard_bytes,
data: &mut data[..shard_count * (shard_bytes / 64)],
shard_len_64,
data: &mut data[..shard_count * shard_len_64],
}
}

/// Splits this [`ShardsRefMut`] into two so that
/// first includes shards `0..mid` and second includes shards `mid..`.
pub fn split_at_mut(&mut self, mid: usize) -> (ShardsRefMut, ShardsRefMut) {
let shard_chunk_count = self.shard_bytes / 64;

let (a, b) = self.data.split_at_mut(mid * shard_chunk_count);
let (a, b) = self.data.split_at_mut(mid * self.shard_len_64);

(
ShardsRefMut::new(mid, self.shard_bytes, a),
ShardsRefMut::new(self.shard_count - mid, self.shard_bytes, b),
ShardsRefMut::new(mid, self.shard_len_64, a),
ShardsRefMut::new(self.shard_count - mid, self.shard_len_64, b),
)
}

/// Fills the given shard-range with `0u8`:s.
pub fn zero<R: RangeBounds<usize>>(&mut self, range: R) {
let shard_chunk_count = self.shard_bytes / 64;

let start = match range.start_bound() {
Bound::Included(start) => start * shard_chunk_count,
Bound::Excluded(start) => (start + 1) * shard_chunk_count,
Bound::Included(start) => start * self.shard_len_64,
Bound::Excluded(start) => (start + 1) * self.shard_len_64,
Bound::Unbounded => 0,
};

let end = match range.end_bound() {
Bound::Included(end) => (end + 1) * shard_chunk_count,
Bound::Excluded(end) => end * shard_chunk_count,
Bound::Unbounded => self.shard_count * shard_chunk_count,
Bound::Included(end) => (end + 1) * self.shard_len_64,
Bound::Excluded(end) => end * self.shard_len_64,
Bound::Unbounded => self.shard_count * self.shard_len_64,
};

self.data[start..end].fill([0; 64]);
Expand All @@ -192,8 +183,7 @@ impl<'a> ShardsRefMut<'a> {
impl<'a> Index<usize> for ShardsRefMut<'a> {
type Output = [[u8; 64]];
fn index(&self, index: usize) -> &Self::Output {
let shard_chunk_count = self.shard_bytes / 64;
&self.data[index * shard_chunk_count..(index + 1) * shard_chunk_count]
&self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
}
}

Expand All @@ -202,8 +192,7 @@ impl<'a> Index<usize> for ShardsRefMut<'a> {

impl<'a> IndexMut<usize> for ShardsRefMut<'a> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
let shard_chunk_count = self.shard_bytes / 64;
&mut self.data[index * shard_chunk_count..(index + 1) * shard_chunk_count]
&mut self.data[index * self.shard_len_64..(index + 1) * self.shard_len_64]
}
}

Expand All @@ -212,11 +201,9 @@ impl<'a> IndexMut<usize> for ShardsRefMut<'a> {

impl<'a> ShardsRefMut<'a> {
pub(crate) fn copy_within(&mut self, mut src: usize, mut dest: usize, mut count: usize) {
let shard_chunk_count = self.shard_bytes / 64;

src *= shard_chunk_count;
dest *= shard_chunk_count;
count *= shard_chunk_count;
src *= self.shard_len_64;
dest *= self.shard_len_64;
count *= self.shard_len_64;

self.data.copy_within(src..src + count, dest);
}
Expand All @@ -231,11 +218,9 @@ impl<'a> ShardsRefMut<'a> {
mut y: usize,
mut count: usize,
) -> (&mut [[u8; 64]], &mut [[u8; 64]]) {
let shard_chunk_count = self.shard_bytes / 64;

x *= shard_chunk_count;
y *= shard_chunk_count;
count *= shard_chunk_count;
x *= self.shard_len_64;
y *= self.shard_len_64;
count *= self.shard_len_64;

if x < y {
let (head, tail) = self.data.split_at_mut(y);
Expand Down
4 changes: 3 additions & 1 deletion src/rate/decoder_work.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ impl DecoderWork {
recovery_base_pos: usize,
work_count: usize,
) {
assert!(shard_bytes % 64 == 0);

self.original_count = original_count;
self.recovery_count = recovery_count;
self.shard_bytes = shard_bytes;
Expand All @@ -176,7 +178,7 @@ impl DecoderWork {
self.received.grow(max_received_pos);
}

self.shards.resize(work_count, shard_bytes);
self.shards.resize(work_count, shard_bytes / 64);
}

pub(crate) fn reset_received(&mut self) {
Expand Down
4 changes: 3 additions & 1 deletion src/rate/encoder_work.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ impl EncoderWork {
shard_bytes: usize,
work_count: usize,
) {
assert!(shard_bytes % 64 == 0);

self.original_count = original_count;
self.recovery_count = recovery_count;
self.shard_bytes = shard_bytes;

self.original_received_count = 0;
self.shards.resize(work_count, shard_bytes);
self.shards.resize(work_count, shard_bytes / 64);
}

pub(crate) fn reset_received(&mut self) {
Expand Down

0 comments on commit 772f193

Please sign in to comment.