Skip to content

Commit

Permalink
No count in stream reader (#107)
Browse files Browse the repository at this point in the history
* outline get_unary

* work

* split to cold version

* Get rid of count

* Exclude index marker handling, simpler reading

* Fix test

* Simpler testing for number of stream bits in value

* Restore compression_stats and detailed_tracing features, some comments

* Fix comment

* Format

* Energy saving

* Iterations number fix
Each iteration, although it consumes at most 7 bits, needs at least 8 stream bits in `value` for correct handling, so in the worst case 56 stream bits can sustain at most 7 iterations

* Removed cold function - the same perf

* Correct limits

---------

Co-authored-by: Kristof <[email protected]>
  • Loading branch information
Melirius and mcroomp authored Nov 12, 2024
1 parent ab125e6 commit 223e411
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 84 deletions.
214 changes: 131 additions & 83 deletions src/structs/vpx_bool_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ use crate::metrics::{Metrics, ModelComponent};

use super::{branch::Branch, simple_hash::SimpleHash};

const BITS_IN_BYTE: i32 = 8;
const BITS_IN_VALUE: i32 = 64;
const BITS_IN_VALUE_MINUS_LAST_BYTE: i32 = BITS_IN_VALUE - BITS_IN_BYTE;
const BITS_IN_BYTE: u32 = 8;
const BITS_IN_VALUE: u32 = 64;
const BITS_IN_VALUE_MINUS_LAST_BYTE: u32 = BITS_IN_VALUE - BITS_IN_BYTE;
const VALUE_MASK: u64 = (1 << BITS_IN_VALUE_MINUS_LAST_BYTE) - 1;

pub struct VPXBoolReader<R> {
value: u64,
range: u64, // 128 << BITS_IN_VALUE_MINUS_LAST_BYTE <= range <= 255 << BITS_IN_VALUE_MINUS_LAST_BYTE
count: i32,
upstream_reader: R,
model_statistics: Metrics,
#[allow(dead_code)]
Expand All @@ -46,15 +46,12 @@ impl<R: Read> VPXBoolReader<R> {
pub fn new(reader: R) -> Result<Self> {
let mut r = VPXBoolReader {
upstream_reader: reader,
value: 0,
count: -8,
value: 1 << (BITS_IN_VALUE - 1), // guard bit
range: 255 << BITS_IN_VALUE_MINUS_LAST_BYTE,
model_statistics: Metrics::default(),
hash: SimpleHash::new(),
};

Self::vpx_reader_fill(&mut r.value, &mut r.count, &mut r.upstream_reader)?;

let mut dummy_branch = Branch::new();
r.get_bit(&mut dummy_branch, ModelComponent::Dummy)?; // marker bit

Expand Down Expand Up @@ -87,21 +84,20 @@ impl<R: Read> VPXBoolReader<R> {
// Second, `range` and `split` are also stored in 8 MSBs of the same size variables (it is new
// and it allows to reduce number of operations to compute `split` - previously `big_split` -
// and to update `range` and `shift`). Third, we use local values for all stream state variables
// to reduce number of memory load/store operations in decoding of many-bit values.
// to reduce number of memory load/store operations in decoding of many-bit values. Fourth,
// we use in `value` a set bit after the stream bits as a guard - completely getting rid
// of bit counter and not changing comparison result `value >= split`.
#[inline(always)]
pub fn get(
&mut self,
branch: &mut Branch,
tmp_value: &mut u64,
tmp_range: &mut u64,
tmp_count: &mut i32,
_cmp: ModelComponent,
) -> bool {
let probability = branch.get_probability() as u64;

let split = ((((*tmp_range - (1 << BITS_IN_VALUE_MINUS_LAST_BYTE)) >> 8) * probability)
& (0xFF << BITS_IN_VALUE_MINUS_LAST_BYTE))
+ (1 << BITS_IN_VALUE_MINUS_LAST_BYTE);
let split = mul_prob(*tmp_range, probability);

// So optimizer understands that 0 should never happen and uses a cold jump
// if we don't have LZCNT on x86 CPUs (older BSR instruction requires check for zero).
Expand All @@ -124,11 +120,10 @@ impl<R: Read> VPXBoolReader<R> {
*tmp_range = split;
}

let shift = (*tmp_range).leading_zeros() as i32;
let shift = (*tmp_range).leading_zeros();

*tmp_value <<= shift;
*tmp_range <<= shift;
*tmp_count -= shift;

#[cfg(feature = "compression_stats")]
{
Expand All @@ -140,7 +135,6 @@ impl<R: Read> VPXBoolReader<R> {
{
self.hash.hash(branch.get_u64());
self.hash.hash(*tmp_value);
self.hash.hash(*tmp_count);
self.hash.hash(*tmp_range);

let hash = self.hash.get();
Expand All @@ -163,28 +157,23 @@ impl<R: Read> VPXBoolReader<R> {
_cmp: ModelComponent,
) -> Result<usize> {
// check if A is a power of 2
assert!((A & (A - 1)) == 0);
debug_assert!((A & (A - 1)) == 0);

let mut tmp_value = self.value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

let mut decoded_so_far = 1;
// We need to refill only every 7th iteration: at least 56 bits are in `value` after `vpx_reader_fill`,
// and one `get` needs 8 bits but consumes at most 7 bits (with `range` coming from >127 to 1).
// As Lepton uses only 3 and 6 iterations, a single refill is enough.
debug_assert!(A <= 128);
tmp_value = Self::vpx_reader_fill(tmp_value, &mut self.upstream_reader)?;

for index in 0..A.ilog2() {
// We can read only each 8-th iteration: minimum 57 bits are in `value` after `vpx_reader_fill`,
// and one `get` consumes at most 7 bits (with `range` coming from >127 to 1).
// Reading like this instead of old `tmp_count < 0` condition we got perfect branch prediction
// or no branching at all for unrolled loop, possible since number of iterations is known beforehand.
if index & 7 == 0 {
Self::vpx_reader_fill(&mut tmp_value, &mut tmp_count, &mut self.upstream_reader)?;
}

for _index in 0..A.ilog2() {
let cur_bit = self.get(
&mut branches[decoded_so_far],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
_cmp,
) as usize;
decoded_so_far <<= 1;
Expand All @@ -196,7 +185,6 @@ impl<R: Read> VPXBoolReader<R> {

self.value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

Ok(value)
}
Expand All @@ -209,36 +197,94 @@ impl<R: Read> VPXBoolReader<R> {
) -> Result<usize> {
let mut tmp_value = self.value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

let mut value = 0;
for value in 0..A {
let split = mul_prob(tmp_range, branches[value].get_probability() as u64);

while value != A {
// Reading like this instead of old `tmp_count < 0` condition we got perfect branch prediction
// or no branching at all for unrolled loop, possible since number of iterations is known beforehand.
if value & 7 == 0 {
Self::vpx_reader_fill(&mut tmp_value, &mut tmp_count, &mut self.upstream_reader)?;
// We know that after this we have min 56 stream bits in `tmp_value`,
// and can have at least 7 iterations, so we can decode 7 bits at once.
// Each iteration needs at least 8 bits of stream in `tmp_value` and
// consumes max 7 of them.
debug_assert!(A <= 14);
if value == 0 || value == 7 {
tmp_value = Self::vpx_reader_fill(tmp_value, &mut self.upstream_reader)?;
}

let cur_bit = self.get(
&mut branches[value],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
_cmp,
);
if !cur_bit {
break;
}
if tmp_value >= split {
branches[value].record_and_update_bit(true);

tmp_range -= split;
tmp_value -= split;

let shift = tmp_range.leading_zeros();

tmp_value <<= shift;
tmp_range <<= shift;

#[cfg(feature = "compression_stats")]
{
self.model_statistics
.record_compression_stats(_cmp, 1, i64::from(shift));
}

#[cfg(feature = "detailed_tracing")]
{
self.hash.hash(branches[value].get_u64());
self.hash.hash(tmp_value);
self.hash.hash(tmp_range);

let hash = self.hash.get();
//if hash == 0x88f9c945
{
print!("({0}:{1:x})", true as u8, hash);
if hash % 8 == 0 {
println!();
}
}
}
} else {
branches[value].record_and_update_bit(false);

tmp_range = split;

let shift = tmp_range.leading_zeros();

value += 1;
tmp_value <<= shift;
tmp_range <<= shift;

#[cfg(feature = "compression_stats")]
{
self.model_statistics
.record_compression_stats(_cmp, 1, i64::from(shift));
}

#[cfg(feature = "detailed_tracing")]
{
self.hash.hash(branches[value].get_u64());
self.hash.hash(tmp_value);
self.hash.hash(tmp_range);

let hash = self.hash.get();
//if hash == 0x88f9c945
{
print!("({0}:{1:x})", false as u8, hash);
if hash % 8 == 0 {
println!();
}
}
}

self.value = tmp_value;
self.range = tmp_range;

return Ok(value);
}
}

self.value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

return Ok(value);
Ok(A)
}

#[inline(always)]
Expand All @@ -252,31 +298,23 @@ impl<R: Read> VPXBoolReader<R> {

let mut tmp_value = self.value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

let mut coef = 0;
for i in (0..n).rev() {
// Here the fastest way is to use this old condition, presumably as
// Here the fastest way is to use condition of `get_bit`, presumably as
// this loop cannot be unrolled due to the variable number of iterations.
// Moreover, this condition holds very rarely as `value` is usually already filled
// by previous `get_bit` sign reading.
if tmp_count < 0 {
Self::vpx_reader_fill(&mut tmp_value, &mut tmp_count, &mut self.upstream_reader)?;
if tmp_value & VALUE_MASK == 0 {
tmp_value = Self::vpx_reader_fill(tmp_value, &mut self.upstream_reader)?;
}

coef |= (self.get(
&mut branches[i],
&mut tmp_value,
&mut tmp_range,
&mut tmp_count,
_cmp,
) as usize)
<< i;
coef |=
(self.get(&mut branches[i], &mut tmp_value, &mut tmp_range, _cmp) as usize) << i;
}

self.value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

return Ok(coef);
}
Expand All @@ -285,43 +323,53 @@ impl<R: Read> VPXBoolReader<R> {
pub fn get_bit(&mut self, branch: &mut Branch, _cmp: ModelComponent) -> Result<bool> {
let mut tmp_value = self.value;
let mut tmp_range = self.range;
let mut tmp_count = self.count;

if tmp_count < 0 {
Self::vpx_reader_fill(&mut tmp_value, &mut tmp_count, &mut self.upstream_reader)?;
// We ensure that the guard bit never comes into the first byte,
// thus having in `value` at least 8 stream bits.
if tmp_value & VALUE_MASK == 0 {
tmp_value = Self::vpx_reader_fill(tmp_value, &mut self.upstream_reader)?;
}

let bit = self.get(branch, &mut tmp_value, &mut tmp_range, &mut tmp_count, _cmp);
let bit = self.get(branch, &mut tmp_value, &mut tmp_range, _cmp);

self.value = tmp_value;
self.range = tmp_range;
self.count = tmp_count;

return Ok(bit);
}

#[cold]
// Fill `tmp_value` maximally still preserving space for the guard bit,
// after this returned value has `56 | (63 - shift)` stream bits
#[inline(always)]
fn vpx_reader_fill(
tmp_value: &mut u64,
tmp_count: &mut i32,
upstream_reader: &mut R,
) -> Result<()> {
let mut shift = BITS_IN_VALUE_MINUS_LAST_BYTE - (*tmp_count + BITS_IN_BYTE);
fn vpx_reader_fill(mut tmp_value: u64, upstream_reader: &mut R) -> Result<u64> {
// This `if` does not change performance but reduces the instruction count by 3 %
if tmp_value & 0xFF == 0 {
let mut shift: i32 = tmp_value.trailing_zeros() as i32;
// Unset the last guard bit and set a new one
tmp_value &= tmp_value - 1;
tmp_value |= 1 << (shift & 7);

while shift >= 0 {
// BufReader is already pretty efficient handling small reads, so optimization doesn't help that much
let mut v = [0u8; 1];
let bytes_read = upstream_reader.read(&mut v)?;
if bytes_read == 0 {
break;
}
shift -= 7;

while shift > 0 {
let bytes_read = upstream_reader.read(&mut v)?;
if bytes_read == 0 {
break;
}

*tmp_value |= (v[0] as u64) << shift;
shift -= BITS_IN_BYTE;
*tmp_count += BITS_IN_BYTE;
tmp_value |= (v[0] as u64) << shift;
shift -= 8;
}
}

return Ok(());
return Ok(tmp_value);
}
}

// Compute the `split` threshold for one decoding step from the current range
// and the branch probability.
//
// `tmp_range` is reduced by one unit of the top byte, scaled down by 8 bits,
// multiplied by `probability`, truncated to the top byte and finally raised
// by one unit again, so the result always has at least the lowest bit of the
// top byte set and nothing below the top byte.
fn mul_prob(tmp_range: u64, probability: u64) -> u64 {
    // One unit in the most significant byte, and the mask selecting that byte.
    let top_byte_one: u64 = 1 << BITS_IN_VALUE_MINUS_LAST_BYTE;
    let top_byte_mask: u64 = 0xFF << BITS_IN_VALUE_MINUS_LAST_BYTE;

    // Shifting right by 8 before multiplying leaves room for the 8-bit
    // probability factor without overflowing 64 bits.
    let scaled = ((tmp_range - top_byte_one) >> 8) * probability;

    (scaled & top_byte_mask) + top_byte_one
}
3 changes: 2 additions & 1 deletion src/structs/vpx_bool_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ impl<W: Write> VPXBoolWriter<W> {
}

pub fn finish(&mut self) -> Result<()> {
// push real stream bits out of `value`
for _i in 0..32 {
let mut dummy_branch = Branch::new();
self.put_bit(false, &mut dummy_branch, ModelComponent::Dummy)?;
Expand Down Expand Up @@ -359,7 +360,7 @@ fn test_roundtrip_vpxboolwriter_n_bits() {

#[test]
fn test_roundtrip_vpxboolwriter_unary() {
const MAX_UNARY: usize = 8;
const MAX_UNARY: usize = 11; // the size used in Lepton

#[derive(Default)]
struct BranchData {
Expand Down

0 comments on commit 223e411

Please sign in to comment.