Skip to content

Commit

Permalink
No count in bit writer (#116)
Browse files Browse the repository at this point in the history
* No count in bit writer

* Simpler *tmp_value writer

* Simpler finish
  • Loading branch information
Melirius authored Nov 21, 2024
1 parent 5997f8c commit 6c065e0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 49 deletions.
4 changes: 2 additions & 2 deletions src/structs/lepton_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ fn roundtrip_ac_only() {
&block,
&block,
[1; 64],
0xC634E0F1A29033CA,
0x9F5637364D41FE11,
&EnabledFeatures::compat_lepton_vector_read(),
);
}
Expand Down Expand Up @@ -656,7 +656,7 @@ fn roundtrip_large_coef() {
&block,
&block,
[1; 64],
0x12050FD2C854F927,
0x95CBDD4F7D7B72EB,
&EnabledFeatures::compat_lepton_vector_read(),
);

Expand Down
87 changes: 40 additions & 47 deletions src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ use crate::metrics::{Metrics, ModelComponent};
use crate::structs::branch::Branch;
use crate::structs::simple_hash::SimpleHash;

// MAX_STREAM_BITS must be a multiple of 8 and greater than 8,
// and (MAX_STREAM_BITS + 1 carry bit + 1 divider bit)
// must fit into the 64 bits of `low_value`
const MAX_STREAM_BITS: i32 = 56; //48; //40;// 32;// 24;// 16;//

pub struct VPXBoolWriter<W> {
low_value: u32,
low_value: u64,
range: u32,
count: i32,
writer: W,
buffer: Vec<u8>,
model_statistics: Metrics,
Expand All @@ -42,9 +46,8 @@ pub struct VPXBoolWriter<W> {
impl<W: Write> VPXBoolWriter<W> {
pub fn new(writer: W) -> Result<Self> {
let mut retval = VPXBoolWriter {
low_value: 0,
low_value: 1 << 9, // this divider bit keeps track of stream bits number
range: 255,
count: -24,
buffer: Vec::new(),
writer: writer,
model_statistics: Metrics::default(),
Expand All @@ -67,17 +70,15 @@ impl<W: Write> VPXBoolWriter<W> {
&mut self,
value: bool,
branch: &mut Branch,
tmp_value: &mut u32,
tmp_value: &mut u64,
tmp_range: &mut u32,
tmp_count: &mut i32,
_cmp: ModelComponent,
) -> Result<()> {
#[cfg(feature = "detailed_tracing")]
{
// used to detect divergences between the C++ and Rust versions
self.hash.hash(branch.get_u64());
self.hash.hash(*tmp_value);
self.hash.hash(*tmp_count);
self.hash.hash(*tmp_range);

let hashed_value = self.hash.get();
Expand All @@ -98,7 +99,7 @@ impl<W: Write> VPXBoolWriter<W> {
branch.record_and_update_bit(value);

if value {
*tmp_value += split;
*tmp_value += split as u64;
*tmp_range -= split;
} else {
*tmp_range = split;
Expand All @@ -113,22 +114,26 @@ impl<W: Write> VPXBoolWriter<W> {
}

*tmp_range <<= shift;
*tmp_count += shift;

if *tmp_count >= 0 {
let offset = shift - *tmp_count - 1;

*tmp_value <<= offset;
if (*tmp_value & 0x80000000) != 0 {
// check whether we have more than MAX_STREAM_BITS stream bits after shift
let stream_bits = 64 - (*tmp_value).leading_zeros() as i32 - 2;
let count = shift + stream_bits - MAX_STREAM_BITS;
if count >= 0 {
// check carry
*tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (*tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
self.carry();
}
// write all full bytes
let mut sh = MAX_STREAM_BITS - 8;
while sh > 0 {
self.buffer.push((*tmp_value >> sh) as u8);
sh -= 8;
}
*tmp_value &= (1 << 8) - 1; // exclude written bytes
*tmp_value |= 1 << 9; // restore divider bit

*tmp_value <<= 1;
self.buffer.push((*tmp_value >> 24) as u8);
*tmp_value &= 0xffffff;

shift = *tmp_count;
*tmp_count -= 8;
shift = count;
}

*tmp_value <<= shift;
Expand Down Expand Up @@ -165,7 +170,6 @@ impl<W: Write> VPXBoolWriter<W> {
assert!((A & (A - 1)) == 0);
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

let mut index = A.ilog2() - 1;
let mut serialized_so_far = 1;
Expand All @@ -177,7 +181,6 @@ impl<W: Write> VPXBoolWriter<W> {
&mut branches[serialized_so_far],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
cmp,
)?;

Expand All @@ -193,7 +196,6 @@ impl<W: Write> VPXBoolWriter<W> {

self.low_value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

Ok(())
}
Expand All @@ -208,7 +210,6 @@ impl<W: Write> VPXBoolWriter<W> {
) -> Result<()> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

let mut i: i32 = (num_bits - 1) as i32;
while i >= 0 {
Expand All @@ -217,15 +218,13 @@ impl<W: Write> VPXBoolWriter<W> {
&mut branches[i as usize],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
cmp,
)?;
i -= 1;
}

self.low_value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

Ok(())
}
Expand All @@ -241,7 +240,6 @@ impl<W: Write> VPXBoolWriter<W> {

let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

for i in 0..A {
let cur_bit = v != i;
Expand All @@ -251,7 +249,6 @@ impl<W: Write> VPXBoolWriter<W> {
&mut branches[i],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
cmp,
)?;
if !cur_bit {
Expand All @@ -261,7 +258,6 @@ impl<W: Write> VPXBoolWriter<W> {

self.low_value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

Ok(())
}
Expand All @@ -275,36 +271,33 @@ impl<W: Write> VPXBoolWriter<W> {
) -> Result<()> {
let mut tmp_value = self.low_value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

self.put(
value,
branch,
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
_cmp,
)?;
self.put(value, branch, &mut tmp_value, &mut tmp_range, _cmp)?;

self.low_value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

Ok(())
}

// Typically all bytes of `low_value` contain stream bits,
// so just write them all - that is what the initial Lepton implementation does.
// Here we write only the stream bytes necessary for decoding,
// unlike the initial Lepton implementation, which writes out the whole buffer.
pub fn finish(&mut self) -> Result<()> {
let tmp_value = self.low_value << (-self.count - 1);
let mut tmp_value = self.low_value;
let stream_bits = 64 - tmp_value.leading_zeros() as i32 - 2;

if (tmp_value & 0x80000000) != 0 {
tmp_value <<= MAX_STREAM_BITS - stream_bits;
if (tmp_value & (1 << MAX_STREAM_BITS)) != 0 {
self.carry();
}
self.buffer.push((tmp_value >> 23) as u8);
self.buffer.push((tmp_value >> 15) as u8);
self.buffer.push((tmp_value >> 7) as u8);
self.buffer.push((tmp_value << 1) as u8);

let mut shift = MAX_STREAM_BITS - 8;
let mut stream_bytes = (stream_bits + 7) >> 3;
while stream_bytes > 0 {
self.buffer.push((tmp_value >> shift) as u8);
shift -= 8;
stream_bytes -= 1;
}

self.writer.write_all(&self.buffer[..])?;
Ok(())
Expand Down

0 comments on commit 6c065e0

Please sign in to comment.