Skip to content

Commit

Permalink
Another bool writer
Browse files Browse the repository at this point in the history
Like reader it reduces the number of value tests
  • Loading branch information
Melirius committed Nov 21, 2024
1 parent e0559a2 commit 3e32519
Showing 1 changed file with 61 additions and 37 deletions.
98 changes: 61 additions & 37 deletions src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ use crate::metrics::{Metrics, ModelComponent};
use crate::structs::branch::Branch;
use crate::structs::simple_hash::SimpleHash;

// MAX_STREAM_BITS should be a multiple of 8 larger than 8,
// and (MAX_STREAM_BITS + 1 bit of carry + 1 bit of divider)
// should fit into 64 bits of `low_value`
const MAX_STREAM_BITS: i32 = 56; //48; //40;// 32;// 24;// 16;//

pub struct VPXBoolWriter<W> {
low_value: u64,
range: u32,
Expand All @@ -45,18 +40,18 @@ pub struct VPXBoolWriter<W> {

impl<W: Write> VPXBoolWriter<W> {
pub fn new(writer: W) -> Result<Self> {
let mut retval = VPXBoolWriter {
let retval = VPXBoolWriter {
low_value: 1 << 9, // this divider bit keeps track of stream bits number
range: 255,
range: 128, // this is value after putting initial false bit
buffer: Vec::new(),
writer: writer,
model_statistics: Metrics::default(),
hash: SimpleHash::new(),
};

let mut dummy_branch = Branch::new();
// initial false bit is put to not get carry out of stream bits
retval.put_bit(false, &mut dummy_branch, ModelComponent::Dummy)?;
// Putting an initial `false` bit (with a dummy branch) into the stream
// keeps a carry from propagating out of the stream bits;
// doing so is equivalent to changing the initial `range` from 255 to 128.

Ok(retval)
}
Expand Down Expand Up @@ -95,7 +90,6 @@ impl<W: Write> VPXBoolWriter<W> {

let split = 1 + (((*tmp_range - 1) * probability) >> 8);

let mut shift;
branch.record_and_update_bit(value);

if value {
Expand All @@ -105,7 +99,7 @@ impl<W: Write> VPXBoolWriter<W> {
*tmp_range = split;
}

shift = (*tmp_range as u8).leading_zeros() as i32;
let shift = (*tmp_range as u8).leading_zeros() as i32;

#[cfg(feature = "compression_stats")]
{
Expand All @@ -114,28 +108,6 @@ impl<W: Write> VPXBoolWriter<W> {
}

*tmp_range <<= shift;

// check whether we have more than MAX_STREAM_BITS stream bits after shift
let stream_bits = 64 - (*tmp_value).leading_zeros() as i32 - 2;
let count = shift + stream_bits - MAX_STREAM_BITS;
if count >= 0 {
// check carry
*tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (*tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
self.carry();
}
// write all full bytes
let mut sh = MAX_STREAM_BITS - 8;
while sh > 0 {
self.buffer.push((*tmp_value >> sh) as u8);
sh -= 8;
}
*tmp_value &= (1 << 8) - 1; // exclude written bytes
*tmp_value |= 1 << 9; // restore divider bit

shift = count;
}

*tmp_value <<= shift;

Ok(())
Expand All @@ -159,6 +131,14 @@ impl<W: Write> VPXBoolWriter<W> {
self.buffer[x] += 1;
}

// Reports whether `tmp_value` is already too wide to accept `num_bits` more
// encoded bits. Each encoded bit may lengthen the stream by as many as 7 bits,
// so the top `num_bits * 7` bits of the 64-bit accumulator must still be clear.
#[inline(always)]
fn cannot_put_bits(tmp_value: u64, num_bits: u32) -> bool {
    // Number of low bits that may be occupied while still leaving enough headroom.
    let occupied_limit = 64 - num_bits * 7;
    (tmp_value >> occupied_limit) != 0
}

#[inline(always)]
pub fn put_grid<const A: usize>(
&mut self,
Expand All @@ -173,6 +153,11 @@ impl<W: Write> VPXBoolWriter<W> {

let mut index = A.ilog2() - 1;
let mut serialized_so_far = 1;
// grid is 3 or 6 bits long, so a single flush is enough
debug_assert!(A <= 64);
if Self::cannot_put_bits(tmp_value, A.ilog2()) {
tmp_value = self.flush_buffer(tmp_value);
}

loop {
let cur_bit = (v & (1 << index)) != 0;
Expand Down Expand Up @@ -213,6 +198,10 @@ impl<W: Write> VPXBoolWriter<W> {

let mut i: i32 = (num_bits - 1) as i32;
while i >= 0 {
if Self::cannot_put_bits(tmp_value, 1) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(
(bits & (1 << i)) != 0,
&mut branches[i as usize],
Expand Down Expand Up @@ -243,6 +232,11 @@ impl<W: Write> VPXBoolWriter<W> {

for i in 0..A {
let cur_bit = v != i;
debug_assert!(A <= 12);
// ensure we can put 6 bits into the stream
if (i == 0 || i == 6) && Self::cannot_put_bits(tmp_value, 6) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(
cur_bit,
Expand Down Expand Up @@ -271,6 +265,9 @@ impl<W: Write> VPXBoolWriter<W> {
) -> Result<()> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
if Self::cannot_put_bits(tmp_value, 1) {
tmp_value = self.flush_buffer(tmp_value);
}

self.put(value, branch, &mut tmp_value, &mut tmp_range, _cmp)?;

Expand All @@ -280,24 +277,51 @@ impl<W: Write> VPXBoolWriter<W> {
Ok(())
}

// Moves all complete stream bytes except the last one out of `tmp_value` into
// `self.buffer` and returns the shrunken accumulator.
//
// Layout of `tmp_value`, from the top set bit down: divider bit, carry bit,
// then `stream_bits` stream bits. The returned value keeps 8..=15 stream bits,
// a cleared carry bit and a re-placed divider bit, so up to 6 further bits
// (at most 6 * 7 = 42 stream bits) fit afterwards, as 15 + 42 < 62.
fn flush_buffer(&mut self, mut tmp_value: u64) -> u64 {
    // The two topmost used bits are the divider and the carry, not stream data.
    let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;

    // Left-align: the divider is shifted out, the carry lands on bit 63 and
    // the stream bits follow from bit 62 downwards.
    let aligned = tmp_value << (63 - stream_bits);
    if (aligned >> 63) != 0 {
        self.carry();
    }

    // Emit every whole stream byte but one, most significant first;
    // the first byte occupies bits 55..=62 of `aligned`.
    let bytes_to_emit = (stream_bits >> 3) - 1;
    for k in 0..bytes_to_emit {
        self.buffer.push((aligned >> (55 - 8 * k)) as u8);
    }

    // Keep the unwritten low bits (one full byte plus the partial one),
    // dropping the carry, and put the divider two positions above them.
    let kept_bits = 8 + (stream_bits & 7);
    let kept_mask = (1u64 << kept_bits) - 1;
    tmp_value = (tmp_value & kept_mask) | (1u64 << (kept_bits + 1));

    tmp_value
}

// Here we write out only the bytes of the stream necessary for decoding,
// unlike the initial Lepton implementation, which writes out the whole buffer.
pub fn finish(&mut self) -> Result<()> {
let mut tmp_value = self.low_value;
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;

tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
tmp_value <<= 63 - stream_bits;
if tmp_value & (1 << 63) != 0 {
self.carry();
}

let mut shift = MAX_STREAM_BITS - 8;
tmp_value <<= 1; // needed for 8 stream_bytes
let mut shift = 56;
let mut stream_bytes = (stream_bits + 7) >> 3;
while stream_bytes > 0 {
self.buffer.push((tmp_value >> shift) as u8);
shift -= 8;
stream_bytes -= 1;
}
// check that no stream bits remain in the buffer
debug_assert!(if shift == 56 {tmp_value == 0} else {!(u64::MAX << (shift + 8)) & tmp_value == 0});

self.writer.write_all(&self.buffer[..])?;
Ok(())
Expand Down

0 comments on commit 3e32519

Please sign in to comment.