diff --git a/src/arm/itx.h b/src/arm/itx.h index c180a8635..b57184c45 100644 --- a/src/arm/itx.h +++ b/src/arm/itx.h @@ -49,7 +49,9 @@ decl_itx_fn(BF(dav1d_inv_txfm_add_dct_dct_64x16, neon)); decl_itx_fn(BF(dav1d_inv_txfm_add_dct_dct_64x32, neon)); decl_itx_fn(BF(dav1d_inv_txfm_add_dct_dct_64x64, neon)); -static ALWAYS_INLINE void itx_dsp_init_arm(Dav1dInvTxfmDSPContext *const c, int bpc) { +static ALWAYS_INLINE void itx_dsp_init_arm(Dav1dInvTxfmDSPContext *const c, int bpc, + int *const all_simd) +{ const unsigned flags = dav1d_get_cpu_flags(); if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return; @@ -77,4 +79,5 @@ static ALWAYS_INLINE void itx_dsp_init_arm(Dav1dInvTxfmDSPContext *const c, int assign_itx1_fn (R, 64, 16, neon); assign_itx1_fn (R, 64, 32, neon); assign_itx1_fn ( , 64, 64, neon); + *all_simd = 1; } diff --git a/src/in_range.rs b/src/in_range.rs index c27a29708..aae3a6d37 100644 --- a/src/in_range.rs +++ b/src/in_range.rs @@ -23,7 +23,7 @@ where impl InRange where - T: TryFrom + PartialEq + Eq + PartialOrd + Ord, + T: TryFrom + PartialEq + Eq + PartialOrd + Ord + Copy, { fn in_bounds(&self) -> bool { *self >= Self::min() && *self <= Self::max() @@ -43,6 +43,10 @@ where unsafe { assert_unchecked(self.in_bounds()) }; self.0 } + + pub const fn const_get(&'static self) -> T { + self.0 + } } impl Default for InRange diff --git a/src/itx.rs b/src/itx.rs index a2646d917..2e693d6ee 100644 --- a/src/itx.rs +++ b/src/itx.rs @@ -32,7 +32,9 @@ use crate::levels::{ FLIPADST_ADST, FLIPADST_DCT, FLIPADST_FLIPADST, H_ADST, H_DCT, H_FLIPADST, IDTX, N_TX_TYPES_PLUS_LL, V_ADST, V_DCT, V_FLIPADST, WHT_WHT, }; +use crate::scan::DAV1D_LAST_NONZERO_COL_FROM_EOB; use crate::strided::Strided as _; +use crate::tables::DAV1D_TXFM_DIMENSIONS; use crate::wrap_fn_ptr::wrap_fn_ptr; pub type Itx1dFn = fn(c: &mut [i32], stride: NonZeroUsize, min: i32, max: i32); @@ -42,16 +44,18 @@ fn inv_txfm_add( dst: Rav1dPictureDataComponentOffset, coeff: &mut [BD::Coef], eob: i32, - w: usize, - h: usize, + tx: TxfmSize, shift: u8, - first_1d_fn: Itx1dFn, - second_1d_fn: Itx1dFn, - has_dc_only: bool, + txtp: TxfmType, bd: BD, ) { let bitdepth_max = bd.bitdepth_max().as_::(); + let t_dim = &DAV1D_TXFM_DIMENSIONS[tx as usize]; + let w = 4 * t_dim.w as usize; + let h = 4 * t_dim.h as usize; + let has_dc_only = txtp == DCT_DCT; + assert!(w >= 4 && w <= 64); assert!(h >= 4 && h <= 64); assert!(eob >= 0); @@ -78,6 +82,63 @@ fn inv_txfm_add( return; } + #[derive(PartialEq, Clone, Copy)] + enum Type { + Identity, + Dct, + Adst, + FlipAdst, + } + use Type::*; + // For some reason, this is flipped. + let (second, first) = match txtp { + IDTX => (Identity, Identity), + DCT_DCT => (Dct, Dct), + ADST_DCT => (Adst, Dct), + FLIPADST_DCT => (FlipAdst, Dct), + H_DCT => (Identity, Dct), + DCT_ADST => (Dct, Adst), + ADST_ADST => (Adst, Adst), + FLIPADST_ADST => (FlipAdst, Adst), + DCT_FLIPADST => (Dct, FlipAdst), + ADST_FLIPADST => (Adst, FlipAdst), + FLIPADST_FLIPADST => (FlipAdst, FlipAdst), + V_DCT => (Dct, Identity), + H_ADST => (Identity, Adst), + H_FLIPADST => (Identity, FlipAdst), + V_ADST => (Adst, Identity), + V_FLIPADST => (FlipAdst, Identity), + + #[cfg(not(all(feature = "asm", target_feature = "neon")))] + WHT_WHT if (w, h) == (4, 4) => return inv_txfm_add_wht_wht_4x4_rust(dst, coeff, bd), + + _ => unreachable!(), + }; + + fn resolve_1d_fn(r#type: Type, n: usize) -> Itx1dFn { + match (r#type, n) { + (Identity, 4) => rav1d_inv_identity4_1d_c, + (Identity, 8) => rav1d_inv_identity8_1d_c, + (Identity, 16) => rav1d_inv_identity16_1d_c, + (Identity, 32) => rav1d_inv_identity32_1d_c, + (Dct, 4) => rav1d_inv_dct4_1d_c, + (Dct, 8) => rav1d_inv_dct8_1d_c, + (Dct, 16) => rav1d_inv_dct16_1d_c, + (Dct, 32) => rav1d_inv_dct32_1d_c, + (Dct, 64) => rav1d_inv_dct64_1d_c, + (Adst, 4) => rav1d_inv_adst4_1d_c, + (Adst, 8) => rav1d_inv_adst8_1d_c, + (Adst, 16) => rav1d_inv_adst16_1d_c, + (FlipAdst, 4) => rav1d_inv_flipadst4_1d_c, + (FlipAdst, 8) => rav1d_inv_flipadst8_1d_c, + (FlipAdst, 16) => rav1d_inv_flipadst16_1d_c, + _ => unreachable!(), + } + } + + let first_1d_fn = resolve_1d_fn(first, w); + let second_1d_fn = resolve_1d_fn(second, h); + let sh = cmp::min(h, 32); let sw = cmp::min(w, 32); @@ -96,8 +157,18 @@ fn inv_txfm_add( let col_clip_max = !col_clip_min; let mut tmp = [0; 64 * 64]; - let mut c = &mut tmp[..]; - for y in 0..sh { + let mut c = &mut tmp[..sh * w]; + let eob = eob as usize; + // in first 1d itx + let last_nonzero_col = if second == Identity && first != Identity { + std::cmp::min(sh - 1, eob) + } else if first == Identity && second != Identity { + eob >> (t_dim.lw + 2) + } else { + DAV1D_LAST_NONZERO_COL_FROM_EOB[tx as usize][eob as usize] as usize + }; + assert!(last_nonzero_col < sh); + for y in 0..=last_nonzero_col { if is_rect2 { for x in 0..sw { c[x] = coeff[y + x * sh].as_::() * 181 + 128 >> 8; @@ -110,6 +181,8 @@ fn inv_txfm_add( first_1d_fn(c, 1.try_into().unwrap(), row_clip_min, row_clip_max); c = &mut c[w..]; } + // fill remaining values in slice `c` with 0 + c.fill(0); coeff.fill(0.into()); for i in 0..w * sh { @@ -162,76 +235,9 @@ fn inv_txfm_add_rust 2, _ => unreachable!(), }; - let has_dc_only = TYPE == DCT_DCT; - - enum Type { - Identity, - Dct, - Adst, - FlipAdst, - } - use Type::*; - // For some reason, this is flipped. - let (second, first) = match TYPE { - IDTX => (Identity, Identity), - DCT_DCT => (Dct, Dct), - ADST_DCT => (Adst, Dct), - FLIPADST_DCT => (FlipAdst, Dct), - H_DCT => (Identity, Dct), - DCT_ADST => (Dct, Adst), - ADST_ADST => (Adst, Adst), - FLIPADST_ADST => (FlipAdst, Adst), - DCT_FLIPADST => (Dct, FlipAdst), - ADST_FLIPADST => (Adst, FlipAdst), - FLIPADST_FLIPADST => (FlipAdst, FlipAdst), - V_DCT => (Dct, Identity), - H_ADST => (Identity, Adst), - H_FLIPADST => (Identity, FlipAdst), - V_ADST => (Adst, Identity), - V_FLIPADST => (FlipAdst, Identity), - - #[cfg(not(all(feature = "asm", target_feature = "neon")))] - WHT_WHT if (W, H) == (4, 4) => return inv_txfm_add_wht_wht_4x4_rust(dst, coeff, bd), - - _ => unreachable!(), - }; - - fn resolve_1d_fn(r#type: Type, n: usize) -> Itx1dFn { - match (r#type, n) { - (Identity, 4) => rav1d_inv_identity4_1d_c, - (Identity, 8) => rav1d_inv_identity8_1d_c, - (Identity, 16) => rav1d_inv_identity16_1d_c, - (Identity, 32) => rav1d_inv_identity32_1d_c, - (Dct, 4) => rav1d_inv_dct4_1d_c, - (Dct, 8) => rav1d_inv_dct8_1d_c, - (Dct, 16) => rav1d_inv_dct16_1d_c, - (Dct, 32) => rav1d_inv_dct32_1d_c, - (Dct, 64) => rav1d_inv_dct64_1d_c, - (Adst, 4) => rav1d_inv_adst4_1d_c, - (Adst, 8) => rav1d_inv_adst8_1d_c, - (Adst, 16) => rav1d_inv_adst16_1d_c, - (FlipAdst, 4) => rav1d_inv_flipadst4_1d_c, - (FlipAdst, 8) => rav1d_inv_flipadst8_1d_c, - (FlipAdst, 16) => rav1d_inv_flipadst16_1d_c, - _ => unreachable!(), - } - } - let first_1d_fn = resolve_1d_fn(first, W); - let second_1d_fn = resolve_1d_fn(second, H); - - inv_txfm_add( - dst, - coeff, - eob, - W, - H, - shift, - first_1d_fn, - second_1d_fn, - has_dc_only, - bd, - ) + let tx = TxfmSize::from_wh(W, H); + inv_txfm_add(dst, coeff, eob, tx, shift, TYPE, bd) } /// # Safety diff --git a/src/itx_1d.c b/src/itx_1d.c index 8f75c653a..14e89ca0c 100644 --- a/src/itx_1d.c +++ b/src/itx_1d.c @@ -89,8 +89,8 @@ inv_dct4_1d_internal_c(int32_t *const c, const ptrdiff_t stride, c[3 * stride] = CLIP(t0 - t3); } -void dav1d_inv_dct4_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_dct4_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { inv_dct4_1d_internal_c(c, stride, min, max, 0); } @@ -142,8 +142,8 @@ inv_dct8_1d_internal_c(int32_t *const c, const ptrdiff_t stride, c[7 * stride] = CLIP(t0 - t7); } -void dav1d_inv_dct8_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_dct8_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { inv_dct8_1d_internal_c(c, stride, min, max, 0); } @@ -237,8 +237,8 @@ inv_dct16_1d_internal_c(int32_t *const c, const ptrdiff_t stride, c[15 * stride] = CLIP(t0 - t15a); } -void dav1d_inv_dct16_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_dct16_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { inv_dct16_1d_internal_c(c, stride, min, max, 0); } @@ -427,14 +427,14 @@ inv_dct32_1d_internal_c(int32_t *const c, const ptrdiff_t stride, c[31 * stride] = CLIP(t0 - t31); } -void dav1d_inv_dct32_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_dct32_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { inv_dct32_1d_internal_c(c, stride, min, max, 0); } -void dav1d_inv_dct64_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_dct64_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { assert(stride > 0); inv_dct32_1d_internal_c(c, stride << 1, min, max, 1); @@ -962,13 +962,13 @@ inv_adst16_1d_internal_c(const int32_t *const in, const ptrdiff_t in_s, } #define inv_adst_1d(sz) \ -void dav1d_inv_adst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \ - const int min, const int max) \ +static void inv_adst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \ + const int min, const int max) \ { \ inv_adst##sz##_1d_internal_c(c, stride, min, max, c, stride); \ } \ -void dav1d_inv_flipadst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \ - const int min, const int max) \ +static void inv_flipadst##sz##_1d_c(int32_t *const c, const ptrdiff_t stride, \ + const int min, const int max) \ { \ inv_adst##sz##_1d_internal_c(c, stride, min, max, \ &c[(sz - 1) * stride], -stride); \ @@ -980,8 +980,8 @@ inv_adst_1d(16) #undef inv_adst_1d -void dav1d_inv_identity4_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_identity4_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { assert(stride > 0); for (int i = 0; i < 4; i++) { @@ -990,16 +990,16 @@ void dav1d_inv_identity4_1d_c(int32_t *const c, const ptrdiff_t stride, } } -void dav1d_inv_identity8_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_identity8_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { assert(stride > 0); for (int i = 0; i < 8; i++) c[stride * i] *= 2; } -void dav1d_inv_identity16_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_identity16_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { assert(stride > 0); for (int i = 0; i < 16; i++) { @@ -1008,14 +1008,57 @@ void dav1d_inv_identity16_1d_c(int32_t *const c, const ptrdiff_t stride, } } -void dav1d_inv_identity32_1d_c(int32_t *const c, const ptrdiff_t stride, - const int min, const int max) +static void inv_identity32_1d_c(int32_t *const c, const ptrdiff_t stride, + const int min, const int max) { assert(stride > 0); for (int i = 0; i < 32; i++) c[stride * i] *= 4; } +const itx_1d_fn dav1d_tx1d_fns[N_TX_SIZES][N_TX_1D_TYPES] = { + [TX_4X4] = { + [DCT] = inv_dct4_1d_c, + [ADST] = inv_adst4_1d_c, + [FLIPADST] = inv_flipadst4_1d_c, + [IDENTITY] = inv_identity4_1d_c, + }, [TX_8X8] = { + [DCT] = inv_dct8_1d_c, + [ADST] = inv_adst8_1d_c, + [FLIPADST] = inv_flipadst8_1d_c, + [IDENTITY] = inv_identity8_1d_c, + }, [TX_16X16] = { + [DCT] = inv_dct16_1d_c, + [ADST] = inv_adst16_1d_c, + [FLIPADST] = inv_flipadst16_1d_c, + [IDENTITY] = inv_identity16_1d_c, + }, [TX_32X32] = { + [DCT] = inv_dct32_1d_c, + [IDENTITY] = inv_identity32_1d_c, + }, [TX_64X64] = { + [DCT] = inv_dct64_1d_c, + }, +}; + +const uint8_t /* enum Tx1dType */ dav1d_tx1d_types[N_TX_TYPES][2] = { + [DCT_DCT] = { DCT, DCT }, + [ADST_DCT] = { ADST, DCT }, + [DCT_ADST] = { DCT, ADST }, + [ADST_ADST] = { ADST, ADST }, + [FLIPADST_DCT] = { FLIPADST, DCT }, + [DCT_FLIPADST] = { DCT, FLIPADST }, + [FLIPADST_FLIPADST] = { FLIPADST, FLIPADST }, + [ADST_FLIPADST] = { ADST, FLIPADST }, + [FLIPADST_ADST] = { FLIPADST, ADST }, + [IDTX] = { IDENTITY, IDENTITY }, + [V_DCT] = { DCT, IDENTITY }, + [H_DCT] = { IDENTITY, DCT }, + [V_ADST] = { ADST, IDENTITY }, + [H_ADST] = { IDENTITY, ADST }, + [V_FLIPADST] = { FLIPADST, IDENTITY }, + [H_FLIPADST] = { IDENTITY, FLIPADST }, +}; + #if !(HAVE_ASM && TRIM_DSP_FUNCTIONS && ( \ ARCH_AARCH64 || \ (ARCH_ARM && (defined(__ARM_NEON) || defined(__APPLE__) || defined(_WIN32))) \ diff --git a/src/itx_1d.h b/src/itx_1d.h index b63d71b02..880ac99a3 100644 --- a/src/itx_1d.h +++ b/src/itx_1d.h @@ -28,31 +28,25 @@ #include #include +#include "src/levels.h" + #ifndef DAV1D_SRC_ITX_1D_H #define DAV1D_SRC_ITX_1D_H +enum Tx1dType { + DCT, + ADST, + IDENTITY, + FLIPADST, + N_TX_1D_TYPES, +}; + #define decl_itx_1d_fn(name) \ void (name)(int32_t *c, ptrdiff_t stride, int min, int max) typedef decl_itx_1d_fn(*itx_1d_fn); -decl_itx_1d_fn(dav1d_inv_dct4_1d_c); -decl_itx_1d_fn(dav1d_inv_dct8_1d_c); -decl_itx_1d_fn(dav1d_inv_dct16_1d_c); -decl_itx_1d_fn(dav1d_inv_dct32_1d_c); -decl_itx_1d_fn(dav1d_inv_dct64_1d_c); - -decl_itx_1d_fn(dav1d_inv_adst4_1d_c); -decl_itx_1d_fn(dav1d_inv_adst8_1d_c); -decl_itx_1d_fn(dav1d_inv_adst16_1d_c); - -decl_itx_1d_fn(dav1d_inv_flipadst4_1d_c); -decl_itx_1d_fn(dav1d_inv_flipadst8_1d_c); -decl_itx_1d_fn(dav1d_inv_flipadst16_1d_c); - -decl_itx_1d_fn(dav1d_inv_identity4_1d_c); -decl_itx_1d_fn(dav1d_inv_identity8_1d_c); -decl_itx_1d_fn(dav1d_inv_identity16_1d_c); -decl_itx_1d_fn(dav1d_inv_identity32_1d_c); +EXTERN const itx_1d_fn dav1d_tx1d_fns[N_TX_SIZES][N_TX_1D_TYPES]; +EXTERN const uint8_t /* enum Tx1dType */ dav1d_tx1d_types[N_TX_TYPES][2]; void dav1d_inv_wht4_1d_c(int32_t *c, ptrdiff_t stride); diff --git a/src/itx_tmpl.c b/src/itx_tmpl.c index 1a37c3d54..ed01fa443 100644 --- a/src/itx_tmpl.c +++ b/src/itx_tmpl.c @@ -29,6 +29,7 @@ #include #include +#include #include #include "common/attributes.h" @@ -36,13 +37,17 @@ #include "src/itx.h" #include "src/itx_1d.h" +#include "src/scan.h" +#include "src/tables.h" static NOINLINE void inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff, - const int eob, const int w, const int h, const int shift, - const itx_1d_fn first_1d_fn, const itx_1d_fn second_1d_fn, - const int has_dconly HIGHBD_DECL_SUFFIX) + const int eob, const /*enum RectTxfmSize*/ int tx, const int shift, + const enum TxfmType txtp HIGHBD_DECL_SUFFIX) { + const TxfmInfo *const t_dim = &dav1d_txfm_dimensions[tx]; + const int w = 4 * t_dim->w, h = 4 * t_dim->h; + const int has_dconly = txtp == DCT_DCT; assert(w >= 4 && w <= 64); assert(h >= 4 && h <= 64); assert(eob >= 0); @@ -64,6 +69,9 @@ inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff, return; } + const uint8_t *const txtps = dav1d_tx1d_types[txtp]; + const itx_1d_fn first_1d_fn = dav1d_tx1d_fns[t_dim->lw][txtps[0]]; + const itx_1d_fn second_1d_fn = dav1d_tx1d_fns[t_dim->lh][txtps[1]]; const int sh = imin(h, 32), sw = imin(w, 32); #if BITDEPTH == 8 const int row_clip_min = INT16_MIN; @@ -76,7 +84,16 @@ inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff, const int col_clip_max = ~col_clip_min; int32_t tmp[64 * 64], *c = tmp; - for (int y = 0; y < sh; y++, c += w) { + int last_nonzero_col; // in first 1d itx + if (txtps[1] == IDENTITY && txtps[0] != IDENTITY) { + last_nonzero_col = imin(sh - 1, eob); + } else if (txtps[0] == IDENTITY && txtps[1] != IDENTITY) { + last_nonzero_col = eob >> (t_dim->lw + 2); + } else { + last_nonzero_col = dav1d_last_nonzero_col_from_eob[tx][eob]; + } + assert(last_nonzero_col < sh); + for (int y = 0; y <= last_nonzero_col; y++, c += w) { if (is_rect2) for (int x = 0; x < sw; x++) c[x] = (coeff[y + x * sh] * 181 + 128) >> 8; @@ -85,6 +102,8 @@ inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff, c[x] = coeff[y + x * sh]; first_1d_fn(c, 1, row_clip_min, row_clip_max); } + if (last_nonzero_col + 1 < sh) + memset(c, 0, sizeof(*c) * (sh - last_nonzero_col - 1) * w); memset(coeff, 0, sizeof(*coeff) * sw * sh); for (int i = 0; i < w * sh; i++) @@ -99,7 +118,7 @@ inv_txfm_add_c(pixel *dst, const ptrdiff_t stride, coef *const coeff, dst[x] = iclip_pixel(dst[x] + ((*c++ + 8) >> 4)); } -#define inv_txfm_fn(type1, type2, w, h, shift, has_dconly) \ +#define inv_txfm_fn(type1, type2, type, pfx, w, h, shift) \ static void \ inv_txfm_add_##type1##_##type2##_##w##x##h##_c(pixel *dst, \ const ptrdiff_t stride, \ @@ -107,57 +126,56 @@ inv_txfm_add_##type1##_##type2##_##w##x##h##_c(pixel *dst, \ const int eob \ HIGHBD_DECL_SUFFIX) \ { \ - inv_txfm_add_c(dst, stride, coeff, eob, w, h, shift, \ - dav1d_inv_##type1##w##_1d_c, dav1d_inv_##type2##h##_1d_c, \ - has_dconly HIGHBD_TAIL_SUFFIX); \ + inv_txfm_add_c(dst, stride, coeff, eob, pfx##TX_##w##X##h, shift, type \ + HIGHBD_TAIL_SUFFIX); \ } -#define inv_txfm_fn64(w, h, shift) \ -inv_txfm_fn(dct, dct, w, h, shift, 1) +#define inv_txfm_fn64(pfx, w, h, shift) \ +inv_txfm_fn(dct, dct, DCT_DCT, pfx, w, h, shift) -#define inv_txfm_fn32(w, h, shift) \ -inv_txfm_fn64(w, h, shift) \ -inv_txfm_fn(identity, identity, w, h, shift, 0) +#define inv_txfm_fn32(pfx, w, h, shift) \ +inv_txfm_fn64(pfx, w, h, shift) \ +inv_txfm_fn(identity, identity, IDTX, pfx, w, h, shift) -#define inv_txfm_fn16(w, h, shift) \ -inv_txfm_fn32(w, h, shift) \ -inv_txfm_fn(adst, dct, w, h, shift, 0) \ -inv_txfm_fn(dct, adst, w, h, shift, 0) \ -inv_txfm_fn(adst, adst, w, h, shift, 0) \ -inv_txfm_fn(dct, flipadst, w, h, shift, 0) \ -inv_txfm_fn(flipadst, dct, w, h, shift, 0) \ -inv_txfm_fn(adst, flipadst, w, h, shift, 0) \ -inv_txfm_fn(flipadst, adst, w, h, shift, 0) \ -inv_txfm_fn(flipadst, flipadst, w, h, shift, 0) \ -inv_txfm_fn(identity, dct, w, h, shift, 0) \ -inv_txfm_fn(dct, identity, w, h, shift, 0) \ +#define inv_txfm_fn16(pfx, w, h, shift) \ +inv_txfm_fn32(pfx, w, h, shift) \ +inv_txfm_fn(adst, dct, ADST_DCT, pfx, w, h, shift) \ +inv_txfm_fn(dct, adst, DCT_ADST, pfx, w, h, shift) \ +inv_txfm_fn(adst, adst, ADST_ADST, pfx, w, h, shift) \ +inv_txfm_fn(dct, flipadst, DCT_FLIPADST, pfx, w, h, shift) \ +inv_txfm_fn(flipadst, dct, FLIPADST_DCT, pfx, w, h, shift) \ +inv_txfm_fn(adst, flipadst, ADST_FLIPADST, pfx, w, h, shift) \ +inv_txfm_fn(flipadst, adst, FLIPADST_ADST, pfx, w, h, shift) \ +inv_txfm_fn(flipadst, flipadst, FLIPADST_FLIPADST, pfx, w, h, shift) \ +inv_txfm_fn(identity, dct, H_DCT, pfx, w, h, shift) \ +inv_txfm_fn(dct, identity, V_DCT, pfx, w, h, shift) \ -#define inv_txfm_fn84(w, h, shift) \ -inv_txfm_fn16(w, h, shift) \ -inv_txfm_fn(identity, flipadst, w, h, shift, 0) \ -inv_txfm_fn(flipadst, identity, w, h, shift, 0) \ -inv_txfm_fn(identity, adst, w, h, shift, 0) \ -inv_txfm_fn(adst, identity, w, h, shift, 0) \ +#define inv_txfm_fn84(pfx, w, h, shift) \ +inv_txfm_fn16(pfx, w, h, shift) \ +inv_txfm_fn(identity, flipadst, H_FLIPADST, pfx, w, h, shift) \ +inv_txfm_fn(flipadst, identity, V_FLIPADST, pfx, w, h, shift) \ +inv_txfm_fn(identity, adst, H_ADST, pfx, w, h, shift) \ +inv_txfm_fn(adst, identity, V_ADST, pfx, w, h, shift) \ -inv_txfm_fn84( 4, 4, 0) -inv_txfm_fn84( 4, 8, 0) -inv_txfm_fn84( 4, 16, 1) -inv_txfm_fn84( 8, 4, 0) -inv_txfm_fn84( 8, 8, 1) -inv_txfm_fn84( 8, 16, 1) -inv_txfm_fn32( 8, 32, 2) -inv_txfm_fn84(16, 4, 1) -inv_txfm_fn84(16, 8, 1) -inv_txfm_fn16(16, 16, 2) -inv_txfm_fn32(16, 32, 1) -inv_txfm_fn64(16, 64, 2) -inv_txfm_fn32(32, 8, 2) -inv_txfm_fn32(32, 16, 1) -inv_txfm_fn32(32, 32, 2) -inv_txfm_fn64(32, 64, 1) -inv_txfm_fn64(64, 16, 2) -inv_txfm_fn64(64, 32, 1) -inv_txfm_fn64(64, 64, 2) +inv_txfm_fn84( , 4, 4, 0) +inv_txfm_fn84(R, 4, 8, 0) +inv_txfm_fn84(R, 4, 16, 1) +inv_txfm_fn84(R, 8, 4, 0) +inv_txfm_fn84( , 8, 8, 1) +inv_txfm_fn84(R, 8, 16, 1) +inv_txfm_fn32(R, 8, 32, 2) +inv_txfm_fn84(R, 16, 4, 1) +inv_txfm_fn84(R, 16, 8, 1) +inv_txfm_fn16( , 16, 16, 2) +inv_txfm_fn32(R, 16, 32, 1) +inv_txfm_fn64(R, 16, 64, 2) +inv_txfm_fn32(R, 32, 8, 2) +inv_txfm_fn32(R, 32, 16, 1) +inv_txfm_fn32( , 32, 32, 2) +inv_txfm_fn64(R, 32, 64, 1) +inv_txfm_fn64(R, 64, 16, 2) +inv_txfm_fn64(R, 64, 32, 1) +inv_txfm_fn64( , 64, 64, 2) #if !(HAVE_ASM && TRIM_DSP_FUNCTIONS && ( \ ARCH_AARCH64 || \ @@ -263,12 +281,16 @@ COLD void bitfn(dav1d_itx_dsp_init)(Dav1dInvTxfmDSPContext *const c, int bpc) { assign_itx_all_fn64(64, 32, R); assign_itx_all_fn64(64, 64, ); + int all_simd = 0; #if HAVE_ASM #if ARCH_AARCH64 || ARCH_ARM - itx_dsp_init_arm(c, bpc); + itx_dsp_init_arm(c, bpc, &all_simd); #endif #if ARCH_X86 - itx_dsp_init_x86(c, bpc); + itx_dsp_init_x86(c, bpc, &all_simd); #endif #endif + + if (!all_simd) + dav1d_init_last_nonzero_col_from_eob_tables(); } diff --git a/src/scan.c b/src/scan.c index 5261ccd3d..6f9dc0369 100644 --- a/src/scan.c +++ b/src/scan.c @@ -28,7 +28,10 @@ #include "config.h" #include "common/attributes.h" +#include "common/intops.h" + #include "src/scan.h" +#include "src/thread.h" static const uint16_t ALIGN(scan_4x4[], 32) = { 0, 4, 1, 2, @@ -297,3 +300,76 @@ const uint16_t *const dav1d_scans[N_RECT_TX_SIZES] = { [RTX_16X64] = scan_16x32, [RTX_64X16] = scan_32x16, }; + +static uint8_t last_nonzero_col_from_eob_4x4[16]; +static uint8_t last_nonzero_col_from_eob_8x8[64]; +static uint8_t last_nonzero_col_from_eob_16x16[256]; +static uint8_t last_nonzero_col_from_eob_32x32[1024]; +static uint8_t last_nonzero_col_from_eob_4x8[32]; +static uint8_t last_nonzero_col_from_eob_8x4[32]; +static uint8_t last_nonzero_col_from_eob_8x16[128]; +static uint8_t last_nonzero_col_from_eob_16x8[128]; +static uint8_t last_nonzero_col_from_eob_16x32[512]; +static uint8_t last_nonzero_col_from_eob_32x16[512]; +static uint8_t last_nonzero_col_from_eob_4x16[64]; +static uint8_t last_nonzero_col_from_eob_16x4[64]; +static uint8_t last_nonzero_col_from_eob_8x32[256]; +static uint8_t last_nonzero_col_from_eob_32x8[256]; + +static COLD void init_tbl(uint8_t *const last_nonzero_col_from_eob, + const uint16_t *const scan, const int w, const int h) +{ + int max_col = 0; + for (int y = 0, n = 0; y < h; y++) { + for (int x = 0; x < w; x++, n++) { + const int rc = scan[n]; + const int rcx = rc & (h - 1); + max_col = imax(max_col, rcx); + last_nonzero_col_from_eob[n] = max_col; + } + } +} + +static COLD void init_internal(void) { + init_tbl(last_nonzero_col_from_eob_4x4, scan_4x4, 4, 4); + init_tbl(last_nonzero_col_from_eob_8x8, scan_8x8, 8, 8); + init_tbl(last_nonzero_col_from_eob_16x16, scan_16x16, 16, 16); + init_tbl(last_nonzero_col_from_eob_32x32, scan_32x32, 32, 32); + init_tbl(last_nonzero_col_from_eob_4x8, scan_4x8, 4, 8); + init_tbl(last_nonzero_col_from_eob_8x4, scan_8x4, 8, 4); + init_tbl(last_nonzero_col_from_eob_8x16, scan_8x16, 8, 16); + init_tbl(last_nonzero_col_from_eob_16x8, scan_16x8, 16, 8); + init_tbl(last_nonzero_col_from_eob_16x32, scan_16x32, 16, 32); + init_tbl(last_nonzero_col_from_eob_32x16, scan_32x16, 32, 16); + init_tbl(last_nonzero_col_from_eob_4x16, scan_4x16, 4, 16); + init_tbl(last_nonzero_col_from_eob_16x4, scan_16x4, 16, 4); + init_tbl(last_nonzero_col_from_eob_8x32, scan_8x32, 8, 32); + init_tbl(last_nonzero_col_from_eob_32x8, scan_32x8, 32, 8); +} + +COLD void dav1d_init_last_nonzero_col_from_eob_tables(void) { + static pthread_once_t initted = PTHREAD_ONCE_INIT; + pthread_once(&initted, init_internal); +} + +const uint8_t *const dav1d_last_nonzero_col_from_eob[N_RECT_TX_SIZES] = { + [ TX_4X4 ] = last_nonzero_col_from_eob_4x4, + [ TX_8X8 ] = last_nonzero_col_from_eob_8x8, + [ TX_16X16] = last_nonzero_col_from_eob_16x16, + [ TX_32X32] = last_nonzero_col_from_eob_32x32, + [ TX_64X64] = last_nonzero_col_from_eob_32x32, + [RTX_4X8 ] = last_nonzero_col_from_eob_4x8, + [RTX_8X4 ] = last_nonzero_col_from_eob_8x4, + [RTX_8X16 ] = last_nonzero_col_from_eob_8x16, + [RTX_16X8 ] = last_nonzero_col_from_eob_16x8, + [RTX_16X32] = last_nonzero_col_from_eob_16x32, + [RTX_32X16] = last_nonzero_col_from_eob_32x16, + [RTX_32X64] = last_nonzero_col_from_eob_32x32, + [RTX_64X32] = last_nonzero_col_from_eob_32x32, + [RTX_4X16 ] = last_nonzero_col_from_eob_4x16, + [RTX_16X4 ] = last_nonzero_col_from_eob_16x4, + [RTX_8X32 ] = last_nonzero_col_from_eob_8x32, + [RTX_32X8 ] = last_nonzero_col_from_eob_32x8, + [RTX_16X64] = last_nonzero_col_from_eob_16x32, + [RTX_64X16] = last_nonzero_col_from_eob_32x16, +}; diff --git a/src/scan.h b/src/scan.h index 09df98877..2bd0b5b84 100644 --- a/src/scan.h +++ b/src/scan.h @@ -33,5 +33,8 @@ #include "src/levels.h" EXTERN const uint16_t *const dav1d_scans[N_RECT_TX_SIZES]; +EXTERN const uint8_t *const dav1d_last_nonzero_col_from_eob[N_RECT_TX_SIZES]; + +void dav1d_init_last_nonzero_col_from_eob_tables(void); #endif /* DAV1D_SRC_SCAN_H */ diff --git a/src/scan.rs b/src/scan.rs index 11a468a57..85ae312cc 100644 --- a/src/scan.rs +++ b/src/scan.rs @@ -1,6 +1,7 @@ use strum::EnumCount; use crate::align::Align32; +use crate::const_fn::const_for; use crate::in_range::InRange; use crate::levels::TxfmSize; @@ -228,3 +229,54 @@ pub static DAV1D_SCANS: [&'static [Scan]; TxfmSize::COUNT] = [ &SCAN_16X32.0, &SCAN_32X16.0, ]; + +const fn init_tbl(scan: &'static [Scan; S], h: u16) -> [u8; S] { + let mut last_nonzero_col_from_eob: [u8; S] = [0; S]; + + let mut max_col: u8 = 0; + const_for!(n in 0..S => { + let rc = scan[n].const_get(); + let rcx = (rc & (h - 1) as u16) as u8; + max_col = if rcx > max_col { rcx } else {max_col }; + last_nonzero_col_from_eob[n] = max_col; + }); + + last_nonzero_col_from_eob +} + +static LAST_NONZERO_COL_FROM_EOB_4X4: [u8; 16] = init_tbl(&SCAN_4X4.0, 4); +static LAST_NONZERO_COL_FROM_EOB_8X8: [u8; 64] = init_tbl(&SCAN_8X8.0, 8); +static LAST_NONZERO_COL_FROM_EOB_16X16: [u8; 256] = init_tbl(&SCAN_16X16.0, 16); +static LAST_NONZERO_COL_FROM_EOB_32X32: [u8; 1024] = init_tbl(&SCAN_32X32.0, 32); +static LAST_NONZERO_COL_FROM_EOB_4X8: [u8; 32] = init_tbl(&SCAN_4X8.0, 8); +static LAST_NONZERO_COL_FROM_EOB_8X4: [u8; 32] = init_tbl(&SCAN_8X4.0, 4); +static LAST_NONZERO_COL_FROM_EOB_8X16: [u8; 128] = init_tbl(&SCAN_8X16.0, 16); +static LAST_NONZERO_COL_FROM_EOB_16X8: [u8; 128] = init_tbl(&SCAN_16X8.0, 8); +static LAST_NONZERO_COL_FROM_EOB_16X32: [u8; 512] = init_tbl(&SCAN_16X32.0, 32); +static LAST_NONZERO_COL_FROM_EOB_32X16: [u8; 512] = init_tbl(&SCAN_32X16.0, 16); +static LAST_NONZERO_COL_FROM_EOB_4X16: [u8; 64] = init_tbl(&SCAN_4X16.0, 16); +static LAST_NONZERO_COL_FROM_EOB_16X4: [u8; 64] = init_tbl(&SCAN_16X4.0, 4); +static LAST_NONZERO_COL_FROM_EOB_8X32: [u8; 256] = init_tbl(&SCAN_8X32.0, 32); +static LAST_NONZERO_COL_FROM_EOB_32X8: [u8; 256] = init_tbl(&SCAN_32X8.0, 8); + +pub static DAV1D_LAST_NONZERO_COL_FROM_EOB: [&'static [u8]; TxfmSize::COUNT] = [ + &LAST_NONZERO_COL_FROM_EOB_4X4, + &LAST_NONZERO_COL_FROM_EOB_8X8, + &LAST_NONZERO_COL_FROM_EOB_16X16, + &LAST_NONZERO_COL_FROM_EOB_32X32, + &LAST_NONZERO_COL_FROM_EOB_32X32, + &LAST_NONZERO_COL_FROM_EOB_4X8, + &LAST_NONZERO_COL_FROM_EOB_8X4, + &LAST_NONZERO_COL_FROM_EOB_8X16, + &LAST_NONZERO_COL_FROM_EOB_16X8, + &LAST_NONZERO_COL_FROM_EOB_16X32, + &LAST_NONZERO_COL_FROM_EOB_32X16, + &LAST_NONZERO_COL_FROM_EOB_32X32, + &LAST_NONZERO_COL_FROM_EOB_32X32, + &LAST_NONZERO_COL_FROM_EOB_4X16, + &LAST_NONZERO_COL_FROM_EOB_16X4, + &LAST_NONZERO_COL_FROM_EOB_8X32, + &LAST_NONZERO_COL_FROM_EOB_32X8, + &LAST_NONZERO_COL_FROM_EOB_16X32, + &LAST_NONZERO_COL_FROM_EOB_32X16, +]; diff --git a/src/x86/itx.h b/src/x86/itx.h index e9aab4587..c96437421 100644 --- a/src/x86/itx.h +++ b/src/x86/itx.h @@ -107,7 +107,9 @@ decl_itx_fns(ssse3); decl_itx_fn(dav1d_inv_txfm_add_wht_wht_4x4_16bpc_avx2); decl_itx_fn(BF(dav1d_inv_txfm_add_wht_wht_4x4, sse2)); -static ALWAYS_INLINE void itx_dsp_init_x86(Dav1dInvTxfmDSPContext *const c, const int bpc) { +static ALWAYS_INLINE void itx_dsp_init_x86(Dav1dInvTxfmDSPContext *const c, + const int bpc, int *const all_simd) +{ #define assign_itx_bpc_fn(pfx, w, h, type, type_enum, bpc, ext) \ c->itxfm_add[pfx##TX_##w##X##h][type_enum] = \ BF_BPC(dav1d_inv_txfm_add_##type##_##w##x##h, bpc, ext) @@ -167,6 +169,7 @@ static ALWAYS_INLINE void itx_dsp_init_x86(Dav1dInvTxfmDSPContext *const c, cons assign_itx1_fn (R, 64, 16, ssse3); assign_itx1_fn (R, 64, 32, ssse3); assign_itx1_fn ( , 64, 64, ssse3); + *all_simd = 1; #endif if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return; @@ -192,6 +195,7 @@ static ALWAYS_INLINE void itx_dsp_init_x86(Dav1dInvTxfmDSPContext *const c, cons assign_itx1_fn (R, 64, 16, sse4); assign_itx1_fn (R, 64, 32, sse4); assign_itx1_fn (, 64, 64, sse4); + *all_simd = 1; } #endif diff --git a/tests/checkasm/itx.c b/tests/checkasm/itx.c index c7cc411ff..b0de65ddd 100644 --- a/tests/checkasm/itx.c +++ b/tests/checkasm/itx.c @@ -130,7 +130,8 @@ static void fwht4_1d(double *const out, const double *const in) static int copy_subcoefs(coef *coeff, const enum RectTxfmSize tx, const enum TxfmType txtp, - const int sw, const int sh, const int subsh) + const int sw, const int sh, const int subsh, + int *const max_eob) { /* copy the topleft coefficients such that the return value (being the * coefficient scantable index for the eob token) guarantees that only @@ -160,6 +161,7 @@ static int copy_subcoefs(coef *coeff, } else if (!eob && (rcx > sub_low || rcy > sub_low)) eob = n; /* lower boundary */ } + *max_eob = n - 1; if (eob) eob += rnd() % (n - eob - 1); @@ -182,7 +184,7 @@ static int copy_subcoefs(coef *coeff, static int ftx(coef *const buf, const enum RectTxfmSize tx, const enum TxfmType txtp, const int w, const int h, - const int subsh, const int bitdepth_max) + const int subsh, int *const max_eob, const int bitdepth_max) { double out[64 * 64], temp[64 * 64]; const double scale = scaling_factors[ctz(w * h) - 4]; @@ -236,7 +238,7 @@ static int ftx(coef *const buf, const enum RectTxfmSize tx, for (int x = 0; x < sw; x++) buf[y * sw + x] = (coef) (out[y * w + x] + 0.5); - return copy_subcoefs(buf, tx, txtp, sw, sh, subsh); + return copy_subcoefs(buf, tx, txtp, sw, sh, subsh, max_eob); } static void check_itxfm_add(Dav1dInvTxfmDSPContext *const c, @@ -272,7 +274,9 @@ static void check_itxfm_add(Dav1dInvTxfmDSPContext *const c, bpc)) { const int bitdepth_max = (1 << bpc) - 1; - const int eob = ftx(coeff[0], tx, txtp, w, h, subsh, bitdepth_max); + int max_eob; + const int eob = ftx(coeff[0], tx, txtp, w, h, subsh, &max_eob, + bitdepth_max); memcpy(coeff[1], coeff[0], sizeof(*coeff)); CLEAR_PIXEL_RECT(c_dst); @@ -295,7 +299,7 @@ static void check_itxfm_add(Dav1dInvTxfmDSPContext *const c, fail(); bench_new(alternate(c_dst, a_dst), a_dst_stride, - alternate(coeff[0], coeff[1]), eob HIGHBD_TAIL_SUFFIX); + alternate(coeff[0], coeff[1]), max_eob HIGHBD_TAIL_SUFFIX); } } report("add_%dx%d", w, h);