From 5f99c89698229895f2c91cc49ef28376fd49a767 Mon Sep 17 00:00:00 2001 From: Aous Naman Date: Tue, 5 Nov 2024 18:24:27 +1100 Subject: [PATCH] Most of the SIMD is done. No wasm yet. --- src/core/codestream/ojph_codeblock.cpp | 6 +- src/core/codestream/ojph_codeblock_fun.cpp | 41 +- src/core/codestream/ojph_codeblock_fun.h | 6 +- src/core/codestream/ojph_codestream_avx.cpp | 2 +- src/core/codestream/ojph_codestream_avx2.cpp | 79 ++- src/core/codestream/ojph_codestream_gen.cpp | 12 +- src/core/codestream/ojph_codestream_sse.cpp | 3 +- src/core/codestream/ojph_codestream_sse2.cpp | 85 ++- src/core/codestream/ojph_resolution.cpp | 36 +- src/core/coding/ojph_block_encoder_avx2.cpp | 7 +- src/core/coding/ojph_block_encoder_avx512.cpp | 10 +- src/core/transform/ojph_colour.cpp | 10 +- src/core/transform/ojph_colour_local.h | 43 +- src/core/transform/ojph_colour_sse2.cpp | 354 ++++++++-- src/core/transform/ojph_transform.cpp | 38 +- src/core/transform/ojph_transform_avx.cpp | 74 +- src/core/transform/ojph_transform_avx2.cpp | 656 +++++++++++++++++- src/core/transform/ojph_transform_local.h | 76 +- src/core/transform/ojph_transform_sse.cpp | 4 +- src/core/transform/ojph_transform_sse2.cpp | 572 ++++++++++++++- 20 files changed, 1834 insertions(+), 280 deletions(-) diff --git a/src/core/codestream/ojph_codeblock.cpp b/src/core/codestream/ojph_codeblock.cpp index 53d9a6b1..bd76fb3f 100644 --- a/src/core/codestream/ojph_codeblock.cpp +++ b/src/core/codestream/ojph_codeblock.cpp @@ -245,7 +245,7 @@ namespace ojph { cb_size.w); } else - this->codeblock_functions.mem_clear32(dp, cb_size.w * sizeof(ui32)); + this->codeblock_functions.mem_clear(dp, cb_size.w * sizeof(ui32)); } else { @@ -259,9 +259,7 @@ namespace ojph { cb_size.w); } else - this->codeblock_functions.mem_clear64(dp, cb_size.w * sizeof(*dp)); - - + this->codeblock_functions.mem_clear(dp, cb_size.w * sizeof(*dp)); } ++cur_line; diff --git a/src/core/codestream/ojph_codeblock_fun.cpp b/src/core/codestream/ojph_codeblock_fun.cpp index 4474428f..c0b70dc9 100644 --- a/src/core/codestream/ojph_codeblock_fun.cpp +++ b/src/core/codestream/ojph_codeblock_fun.cpp @@ -57,15 +57,10 @@ namespace ojph { { ////////////////////////////////////////////////////////////////////////// - void gen_mem_clear32(si32* addr, size_t count); - void sse_mem_clear32(si32* addr, size_t count); - void avx_mem_clear32(si32* addr, size_t count); - void wasm_mem_clear32(si32* addr, size_t count); - - void gen_mem_clear64(si64* addr, size_t count); - void sse_mem_clear64(si64* addr, size_t count); - void avx_mem_clear64(si64* addr, size_t count); - void wasm_mem_clear64(si64* addr, size_t count); + void gen_mem_clear(void* addr, size_t count); + void sse_mem_clear(void* addr, size_t count); + void avx_mem_clear(void* addr, size_t count); + void wasm_mem_clear(void* addr, size_t count); ////////////////////////////////////////////////////////////////////////// ui32 gen_find_max_val32(ui32* address); @@ -135,7 +130,7 @@ namespace ojph { // Default path, no acceleration. 
We may change this later decode_cb32 = ojph_decode_codeblock32; find_max_val32 = gen_find_max_val32; - mem_clear32 = gen_mem_clear32; + mem_clear = gen_mem_clear; if (reversible) { tx_to_cb32 = gen_rev_tx_to_cb32; tx_from_cb32 = gen_rev_tx_from_cb32; @@ -149,7 +144,6 @@ namespace ojph { decode_cb64 = ojph_decode_codeblock64; find_max_val64 = gen_find_max_val64; - mem_clear64 = gen_mem_clear64; if (reversible) { tx_to_cb64 = gen_rev_tx_to_cb64; tx_from_cb64 = gen_rev_tx_from_cb64; @@ -168,7 +162,7 @@ namespace ojph { // Accelerated functions for INTEL/AMD CPUs #ifndef OJPH_DISABLE_SSE if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_SSE) - mem_clear32 = sse_mem_clear32; + mem_clear = sse_mem_clear; #endif // !OJPH_DISABLE_SSE #ifndef OJPH_DISABLE_SSE2 @@ -182,6 +176,16 @@ namespace ojph { tx_to_cb32 = sse2_irv_tx_to_cb32; tx_from_cb32 = sse2_irv_tx_from_cb32; } + find_max_val64 = sse2_find_max_val64; + if (reversible) { + tx_to_cb64 = sse2_rev_tx_to_cb64; + tx_from_cb64 = sse2_rev_tx_from_cb64; + } + else + { + tx_to_cb64 = NULL; + tx_from_cb64 = NULL; + } } #endif // !OJPH_DISABLE_SSE2 @@ -192,7 +196,7 @@ namespace ojph { #ifndef OJPH_DISABLE_AVX if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX) - mem_clear32 = avx_mem_clear32; + mem_clear = avx_mem_clear; #endif // !OJPH_DISABLE_AVX #ifndef OJPH_DISABLE_AVX2 @@ -208,6 +212,17 @@ namespace ojph { } encode_cb32 = ojph_encode_codeblock_avx2; decode_cb32 = ojph_decode_codeblock_avx2; + + find_max_val64 = avx2_find_max_val64; + if (reversible) { + tx_to_cb64 = avx2_rev_tx_to_cb64; + tx_from_cb64 = avx2_rev_tx_from_cb64; + } + else + { + tx_to_cb64 = NULL; + tx_from_cb64 = NULL; + } } #endif // !OJPH_DISABLE_AVX2 diff --git a/src/core/codestream/ojph_codeblock_fun.h b/src/core/codestream/ojph_codeblock_fun.h index 03b3b243..67fbc2b7 100644 --- a/src/core/codestream/ojph_codeblock_fun.h +++ b/src/core/codestream/ojph_codeblock_fun.h @@ -48,8 +48,7 @@ namespace ojph { namespace local { // define function signature simple memory clearing - typedef void (*mem_clear_fun32)(si32* addr, size_t count); - typedef void (*mem_clear_fun64)(si64* addr, size_t count); + typedef void (*mem_clear_fun)(void* addr, size_t count); // define function signature for max value finding typedef ui32 (*find_max_val_fun32)(ui32* addr); @@ -96,8 +95,7 @@ namespace ojph { void init(bool reversible); // a pointer to the max value finding function - mem_clear_fun32 mem_clear32; - mem_clear_fun64 mem_clear64; + mem_clear_fun mem_clear; // a pointer to the max value finding function find_max_val_fun32 find_max_val32; diff --git a/src/core/codestream/ojph_codestream_avx.cpp b/src/core/codestream/ojph_codestream_avx.cpp index 22405c7e..4c6d678d 100644 --- a/src/core/codestream/ojph_codestream_avx.cpp +++ b/src/core/codestream/ojph_codestream_avx.cpp @@ -42,7 +42,7 @@ namespace ojph { namespace local { ////////////////////////////////////////////////////////////////////////// - void avx_mem_clear32(si32* addr, size_t count) + void avx_mem_clear(void* addr, size_t count) { float* p = (float*)addr; __m256 zero = _mm256_setzero_ps(); diff --git a/src/core/codestream/ojph_codestream_avx2.cpp b/src/core/codestream/ojph_codestream_avx2.cpp index bd849b59..c01e0718 100644 --- a/src/core/codestream/ojph_codestream_avx2.cpp +++ b/src/core/codestream/ojph_codestream_avx2.cpp @@ -55,6 +55,18 @@ namespace ojph { return t; } + ////////////////////////////////////////////////////////////////////////// + ui64 avx2_find_max_val64(ui64* address) + { + __m128i x0 = _mm_loadu_si128((__m128i*)address); 
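
// Review note: the OR-reduction below stands in for a true max. The
// caller only uses the position of the most significant set bit of the
// result (to derive the number of significant bit-planes), and OR
// preserves the MSB of the true maximum, just as in the 32-bit
// find_max_val variants. Hedged scalar equivalent (name hypothetical):
//
//   ui64 find_max_val64_scalar(const ui64* p) {
//     return p[0] | p[1] | p[2] | p[3];  // keeps the highest set bit
//   }
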
+ __m128i x1 = _mm_loadu_si128((__m128i*)address + 1); + x0 = _mm_or_si128(x0, x1); + x1 = _mm_shuffle_epi32(x0, 0xEE); // x1 = x0[2,3,2,3] + x0 = _mm_or_si128(x0, x1); + ui64 t = (ui64)_mm_extract_epi64(x0, 0); + return t; + } + ////////////////////////////////////////////////////////////////////////// void avx2_rev_tx_to_cb32(const void *sp, ui32 *dp, ui32 K_max, float delta_inv, ui32 count, ui32* max_val) @@ -78,7 +90,7 @@ namespace ojph { } _mm256_storeu_si256((__m256i*)max_val, tmax); } - + ////////////////////////////////////////////////////////////////////////// void avx2_irv_tx_to_cb32(const void *sp, ui32 *dp, ui32 K_max, float delta_inv, ui32 count, ui32* max_val) @@ -115,11 +127,11 @@ namespace ojph { si32 *p = (si32*)dp; for (ui32 i = 0; i < count; i += 8, sp += 8, p += 8) { - __m256i v = _mm256_load_si256((__m256i*)sp); - __m256i val = _mm256_and_si256(v, m1); - val = _mm256_srli_epi32(val, (int)shift); - val = _mm256_sign_epi32(val, v); - _mm256_storeu_si256((__m256i*)p, val); + __m256i v = _mm256_load_si256((__m256i*)sp); + __m256i val = _mm256_and_si256(v, m1); + val = _mm256_srli_epi32(val, (int)shift); + val = _mm256_sign_epi32(val, v); + _mm256_storeu_si256((__m256i*)p, val); } } @@ -142,5 +154,58 @@ namespace ojph { _mm256_storeu_ps(p, valf); } } + + ////////////////////////////////////////////////////////////////////////// + void avx2_rev_tx_to_cb64(const void *sp, ui64 *dp, ui32 K_max, + float delta_inv, ui32 count, ui64* max_val) + { + ojph_unused(delta_inv); + + // convert to sign and magnitude and keep max_val + ui32 shift = 63 - K_max; + __m256i m0 = _mm256_set1_epi64x(0x8000000000000000LL); + __m256i zero = _mm256_setzero_si256(); + __m256i one = _mm256_set1_epi64x(1); + __m256i tmax = _mm256_loadu_si256((__m256i*)max_val); + __m256i *p = (__m256i*)sp; + for (ui32 i = 0; i < count; i += 4, p += 1, dp += 4) + { + __m256i v = _mm256_loadu_si256(p); + __m256i sign = _mm256_cmpgt_epi64(zero, v); + __m256i val = _mm256_xor_si256(v, sign); // negate 1's complement + __m256i ones = _mm256_and_si256(sign, one); + val = _mm256_add_epi64(val, ones); // 2's complement + sign = _mm256_and_si256(sign, m0); + val = _mm256_slli_epi64(val, (int)shift); + tmax = _mm256_or_si256(tmax, val); + val = _mm256_or_si256(val, sign); + _mm256_storeu_si256((__m256i*)dp, val); + } + _mm256_storeu_si256((__m256i*)max_val, tmax); + } + + ////////////////////////////////////////////////////////////////////////// + void avx2_rev_tx_from_cb64(const ui64 *sp, void *dp, ui32 K_max, + float delta, ui32 count) + { + ojph_unused(delta); + + ui32 shift = 63 - K_max; + __m256i m1 = _mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL); + __m256i zero = _mm256_setzero_si256(); + __m256i one = _mm256_set1_epi64x(1); + si64 *p = (si64*)dp; + for (ui32 i = 0; i < count; i += 4, sp += 4, p += 4) + { + __m256i v = _mm256_load_si256((__m256i*)sp); + __m256i val = _mm256_and_si256(v, m1); + val = _mm256_srli_epi64(val, (int)shift); + __m256i sign = _mm256_cmpgt_epi64(zero, v); + val = _mm256_xor_si256(val, sign); // negate 1's complement + __m256i ones = _mm256_and_si256(sign, one); + val = _mm256_add_epi64(val, ones); // 2's complement + _mm256_storeu_si256((__m256i*)p, val); + } + } } -} \ No newline at end of file +} diff --git a/src/core/codestream/ojph_codestream_gen.cpp b/src/core/codestream/ojph_codestream_gen.cpp index 50fc878d..cdc72c6e 100644 --- a/src/core/codestream/ojph_codestream_gen.cpp +++ b/src/core/codestream/ojph_codestream_gen.cpp @@ -42,17 +42,11 @@ namespace ojph { namespace local { 
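
// Review note: the avx2_rev_tx_to_cb64 / avx2_rev_tx_from_cb64 pair in
// ojph_codestream_avx2.cpp above implements the 64-bit sign-magnitude
// conversion for the block coder: negative samples are negated via one's
// complement (XOR with the sign mask) plus one, the magnitude is shifted
// up to just below bit 63, the sign is placed in bit 63, and the shifted
// magnitudes are OR-accumulated into max_val. Hedged scalar sketch of
// the forward direction (name hypothetical):
//
//   void rev_tx_to_cb64_scalar(const si64* sp, ui64* dp, ui32 K_max,
//                              ui32 count, ui64* max_val) {
//     ui32 shift = 63 - K_max;
//     for (ui32 i = 0; i < count; ++i) {
//       si64 v = sp[i];
//       ui64 mag = (ui64)(v < 0 ? -v : v) << shift;
//       *max_val |= mag;   // only the MSB position matters downstream
//       dp[i] = mag | (v < 0 ? 0x8000000000000000ULL : 0);
//     }
//   }
//
// The hunk below folds the two mem_clear variants into a single entry
// point: gen_mem_clear takes a void* and a byte count (callers now pass
// width * sizeof(sample)) and clears eight bytes per iteration; trailing
// bytes are rounded up to a full write, which presumes the same buffer
// padding the SSE/AVX clears already rely on.
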
////////////////////////////////////////////////////////////////////////// - void gen_mem_clear32(si32* addr, size_t count) - { - for (size_t i = 0; i < count; i += 4) - *addr++ = 0; - } - - ////////////////////////////////////////////////////////////////////////// - void gen_mem_clear64(si64* addr, size_t count) + void gen_mem_clear(void* addr, size_t count) { + si64* p = (si64*)addr; for (size_t i = 0; i < count; i += 8) - *addr++ = 0; + *p++ = 0; } ////////////////////////////////////////////////////////////////////////// diff --git a/src/core/codestream/ojph_codestream_sse.cpp b/src/core/codestream/ojph_codestream_sse.cpp index 99082aaa..6a31cbd6 100644 --- a/src/core/codestream/ojph_codestream_sse.cpp +++ b/src/core/codestream/ojph_codestream_sse.cpp @@ -42,13 +42,12 @@ namespace ojph { namespace local { ////////////////////////////////////////////////////////////////////////// - void sse_mem_clear32(si32* addr, size_t count) + void sse_mem_clear(void* addr, size_t count) { float* p = (float*)addr; __m128 zero = _mm_setzero_ps(); for (size_t i = 0; i < count; i += 16, p += 4) _mm_storeu_ps(p, zero); } - } } \ No newline at end of file diff --git a/src/core/codestream/ojph_codestream_sse2.cpp b/src/core/codestream/ojph_codestream_sse2.cpp index 145db822..738f24b0 100644 --- a/src/core/codestream/ojph_codestream_sse2.cpp +++ b/src/core/codestream/ojph_codestream_sse2.cpp @@ -58,6 +58,21 @@ namespace ojph { // return t; } + ////////////////////////////////////////////////////////////////////////// + ui64 sse2_find_max_val64(ui64* address) + { + __m128i x1, x0 = _mm_loadu_si128((__m128i*)address); + x1 = _mm_shuffle_epi32(x0, 0xEE); // x1 = x0[2,3,2,3] + x0 = _mm_or_si128(x0, x1); + _mm_storeu_si128((__m128i*)address, x0); + return *address; + // A single movd t, xmm0 can do the trick, but it is not available + // in SSE2 intrinsics. 
extract_epi32 is available in sse4.1 + // ui32 t = (ui32)_mm_extract_epi16(x0, 0); + // t |= (ui32)_mm_extract_epi16(x0, 1) << 16; + // return t; + } + ////////////////////////////////////////////////////////////////////////// void sse2_rev_tx_to_cb32(const void *sp, ui32 *dp, ui32 K_max, float delta_inv, ui32 count, ui32* max_val) @@ -129,14 +144,14 @@ namespace ojph { si32 *p = (si32*)dp; for (ui32 i = 0; i < count; i += 4, sp += 4, p += 4) { - __m128i v = _mm_load_si128((__m128i*)sp); - __m128i val = _mm_and_si128(v, m1); - val = _mm_srli_epi32(val, (int)shift); - __m128i sign = _mm_cmplt_epi32(v, zero); - val = _mm_xor_si128(val, sign); // negate 1's complement - __m128i ones = _mm_and_si128(sign, one); - val = _mm_add_epi32(val, ones); // 2's complement - _mm_storeu_si128((__m128i*)p, val); + __m128i v = _mm_load_si128((__m128i*)sp); + __m128i val = _mm_and_si128(v, m1); + val = _mm_srli_epi32(val, (int)shift); + __m128i sign = _mm_cmplt_epi32(v, zero); + val = _mm_xor_si128(val, sign); // negate 1's complement + __m128i ones = _mm_and_si128(sign, one); + val = _mm_add_epi32(val, ones); // 2's complement + _mm_storeu_si128((__m128i*)p, val); } } @@ -159,5 +174,59 @@ namespace ojph { _mm_storeu_ps(p, valf); } } + + ////////////////////////////////////////////////////////////////////////// + void sse2_rev_tx_to_cb64(const void *sp, ui64 *dp, ui32 K_max, + float delta_inv, ui32 count, ui64* max_val) + { + ojph_unused(delta_inv); + + // convert to sign and magnitude and keep max_val + ui32 shift = 63 - K_max; + __m128i m0 = _mm_set1_epi64x(0x8000000000000000LL); + __m128i zero = _mm_setzero_si128(); + __m128i one = _mm_set1_epi64x(1); + __m128i tmax = _mm_loadu_si128((__m128i*)max_val); + __m128i *p = (__m128i*)sp; + for (ui32 i = 0; i < count; i += 2, p += 1, dp += 2) + { + __m128i v = _mm_loadu_si128(p); + __m128i sign = _mm_cmplt_epi32(v, zero); + sign = _mm_shuffle_epi32(sign, 0xF5); // sign = sign[1,1,3,3]; + __m128i val = _mm_xor_si128(v, sign); // negate 1's complement + __m128i ones = _mm_and_si128(sign, one); + val = _mm_add_epi64(val, ones); // 2's complement + sign = _mm_and_si128(sign, m0); + val = _mm_slli_epi64(val, (int)shift); + tmax = _mm_or_si128(tmax, val); + val = _mm_or_si128(val, sign); + _mm_storeu_si128((__m128i*)dp, val); + } + _mm_storeu_si128((__m128i*)max_val, tmax); + } + + ////////////////////////////////////////////////////////////////////////// + void sse2_rev_tx_from_cb64(const ui64 *sp, void *dp, ui32 K_max, + float delta, ui32 count) + { + ojph_unused(delta); + ui32 shift = 63 - K_max; + __m128i m1 = _mm_set1_epi64x(0x7FFFFFFFFFFFFFFFLL); + __m128i zero = _mm_setzero_si128(); + __m128i one = _mm_set1_epi64x(1); + si64 *p = (si64*)dp; + for (ui32 i = 0; i < count; i += 2, sp += 2, p += 2) + { + __m128i v = _mm_load_si128((__m128i*)sp); + __m128i val = _mm_and_si128(v, m1); + val = _mm_srli_epi64(val, (int)shift); + __m128i sign = _mm_cmplt_epi32(v, zero); + sign = _mm_shuffle_epi32(sign, 0xF5); // sign = sign[1,1,3,3]; + val = _mm_xor_si128(val, sign); // negate 1's complement + __m128i ones = _mm_and_si128(sign, one); + val = _mm_add_epi64(val, ones); // 2's complement + _mm_storeu_si128((__m128i*)p, val); + } + } } } \ No newline at end of file diff --git a/src/core/codestream/ojph_resolution.cpp b/src/core/codestream/ojph_resolution.cpp index fb4efdfe..bcb27c98 100644 --- a/src/core/codestream/ojph_resolution.cpp +++ b/src/core/codestream/ojph_resolution.cpp @@ -708,8 +708,8 @@ namespace ojph { rev_horz_syn(atk, aug->line, child_res->pull_line(), 
bands[1].pull_line(), width, horz_even); else - memcpy(aug->line->i32, child_res->pull_line()->i32, - width * sizeof(si32)); + memcpy(aug->line->p, child_res->pull_line()->p, + width * (aug->line->flags & line_buf::LFT_SIZE_MASK)); aug->active = true; vert_even = !vert_even; ++cur_line; @@ -720,8 +720,8 @@ namespace ojph { rev_horz_syn(atk, sig->line, bands[2].pull_line(), bands[3].pull_line(), width, horz_even); else - memcpy(sig->line->i32, bands[2].pull_line()->i32, - width * sizeof(si32)); + memcpy(sig->line->p, bands[2].pull_line()->p, + width * (sig->line->flags & line_buf::LFT_SIZE_MASK)); sig->active = true; vert_even = !vert_even; ++cur_line; @@ -759,8 +759,8 @@ namespace ojph { rev_horz_syn(atk, aug->line, child_res->pull_line(), bands[1].pull_line(), width, horz_even); else - memcpy(aug->line->i32, child_res->pull_line()->i32, - width * sizeof(si32)); + memcpy(aug->line->p, child_res->pull_line()->p, + width * (aug->line->flags & line_buf::LFT_SIZE_MASK)); } else { @@ -768,11 +768,21 @@ namespace ojph { rev_horz_syn(atk, aug->line, bands[2].pull_line(), bands[3].pull_line(), width, horz_even); else - memcpy(aug->line->i32, bands[2].pull_line()->i32, - width * sizeof(si32)); - si32* sp = aug->line->i32; - for (ui32 i = width; i > 0; --i) - *sp++ >>= 1; + memcpy(aug->line->p, bands[2].pull_line()->p, + width * (aug->line->flags & line_buf::LFT_SIZE_MASK)); + if (aug->line->flags & line_buf::LFT_32BIT) + { + si32* sp = aug->line->i32; + for (ui32 i = width; i > 0; --i) + *sp++ >>= 1; + } + else + { + assert(aug->line->flags & line_buf::LFT_64BIT); + si64* sp = aug->line->i64; + for (ui32 i = width; i > 0; --i) + *sp++ >>= 1; + } } return aug->line; } @@ -880,8 +890,8 @@ namespace ojph { rev_horz_syn(atk, aug->line, child_res->pull_line(), bands[1].pull_line(), width, horz_even); else - memcpy(aug->line->i32, child_res->pull_line()->i32, - width * sizeof(si32)); + memcpy(aug->line->p, child_res->pull_line()->p, + width * (aug->line->flags & line_buf::LFT_SIZE_MASK)); return aug->line; } else diff --git a/src/core/coding/ojph_block_encoder_avx2.cpp b/src/core/coding/ojph_block_encoder_avx2.cpp index d579f83a..6f3db34e 100644 --- a/src/core/coding/ojph_block_encoder_avx2.cpp +++ b/src/core/coding/ojph_block_encoder_avx2.cpp @@ -64,8 +64,8 @@ namespace ojph { // index is (c_q << 8) + (rho << 4) + eps // data is (cwd << 8) + (cwd_len << 4) + eps // table 0 is for the initial line of quads - static ui32 vlc_tbl0[2048] = { 0 }; - static ui32 vlc_tbl1[2048] = { 0 }; + static ui32 vlc_tbl0[2048]; + static ui32 vlc_tbl1[2048]; //UVLC encoding static ui32 ulvc_cwd_pre[33]; @@ -220,6 +220,9 @@ namespace ojph { ///////////////////////////////////////////////////////////////////////// bool initialize_tables_avx2() { if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX2) { + memset(vlc_tbl0, 0, 2048 * sizeof(ui32)); + memset(vlc_tbl1, 0, 2048 * sizeof(ui32)); + bool result; result = vlc_init_tables(); result = result && uvlc_init_tables(); diff --git a/src/core/coding/ojph_block_encoder_avx512.cpp b/src/core/coding/ojph_block_encoder_avx512.cpp index 9df0e8ef..f0c7438b 100644 --- a/src/core/coding/ojph_block_encoder_avx512.cpp +++ b/src/core/coding/ojph_block_encoder_avx512.cpp @@ -64,8 +64,8 @@ namespace ojph { // index is (c_q << 8) + (rho << 4) + eps // data is (cwd << 8) + (cwd_len << 4) + eps // table 0 is for the initial line of quads - static ui32 vlc_tbl0[2048] = { 0 }; - static ui32 vlc_tbl1[2048] = { 0 }; + static ui32 vlc_tbl0[2048]; + static ui32 vlc_tbl1[2048]; //UVLC encoding static 
ui32 ulvc_cwd_pre[33]; @@ -219,7 +219,11 @@ namespace ojph { ///////////////////////////////////////////////////////////////////////// bool initialize_tables() { - if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX512) { + if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX512) + { + memset(vlc_tbl0, 0, 2048 * sizeof(ui32)); + memset(vlc_tbl1, 0, 2048 * sizeof(ui32)); + bool result; result = vlc_init_tables(); result = result && uvlc_init_tables(); diff --git a/src/core/transform/ojph_colour.cpp b/src/core/transform/ojph_colour.cpp index a72cd3d4..6289ae13 100644 --- a/src/core/transform/ojph_colour.cpp +++ b/src/core/transform/ojph_colour.cpp @@ -109,8 +109,6 @@ namespace ojph { #if !defined(OJPH_ENABLE_WASM_SIMD) || !defined(OJPH_EMSCRIPTEN) - // cnvrt_si32_to_si32_shftd = gen_cnvrt_si32_to_si32_shftd; - // cnvrt_si32_to_si32_nlt_type3 = gen_cnvrt_si32_to_si32_nlt_type3; rev_convert = gen_rev_convert; rev_convert_nlt_type3 = gen_rev_convert_nlt_type3; cnvrt_si32_to_float_shftd = gen_cnvrt_si32_to_float_shftd; @@ -141,12 +139,12 @@ namespace ojph { #ifndef OJPH_DISABLE_SSE2 if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_SSE2) { + rev_convert = sse2_rev_convert; + rev_convert_nlt_type3 = sse2_rev_convert_nlt_type3; cnvrt_float_to_si32_shftd = sse2_cnvrt_float_to_si32_shftd; cnvrt_float_to_si32 = sse2_cnvrt_float_to_si32; - // cnvrt_si32_to_si32_shftd = sse2_cnvrt_si32_to_si32_shftd; - // cnvrt_si32_to_si32_nlt_type3 = sse2_cnvrt_si32_to_si32_nlt_type3; - // rct_forward = sse2_rct_forward; - // rct_backward = sse2_rct_backward; + rct_forward = sse2_rct_forward; + rct_backward = sse2_rct_backward; } #endif // !OJPH_DISABLE_SSE2 diff --git a/src/core/transform/ojph_colour_local.h b/src/core/transform/ojph_colour_local.h index 08e99a92..5314c53b 100644 --- a/src/core/transform/ojph_colour_local.h +++ b/src/core/transform/ojph_colour_local.h @@ -167,21 +167,26 @@ namespace ojph { ////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////// - void sse2_cnvrt_si32_to_si32_shftd(const si32 *sp, si32 *dp, int shift, - ui32 width); + void sse2_rev_convert( + const line_buf *src_line, const ui32 src_line_offset, + line_buf *dst_line, const ui32 dst_line_offset, + si64 shift, ui32 width); ////////////////////////////////////////////////////////////////////////// - void sse2_cnvrt_si32_to_si32_nlt_type3(const si32 *sp, si32 *dp, - int shift, ui32 width); - + void sse2_rev_convert_nlt_type3( + const line_buf *src_line, const ui32 src_line_offset, + line_buf *dst_line, const ui32 dst_line_offset, + si64 shift, ui32 width); ////////////////////////////////////////////////////////////////////////// - void sse2_rct_forward(const si32 *r, const si32 *g, const si32 *b, - si32 *y, si32 *cb, si32 *cr, ui32 repeat); + void sse2_rct_forward( + const line_buf *r, const line_buf *g, const line_buf *b, + line_buf *y, line_buf *cb, line_buf *cr, ui32 repeat); ////////////////////////////////////////////////////////////////////////// - void sse2_rct_backward(const si32 *y, const si32 *cb, const si32 *cr, - si32 *r, si32 *g, si32 *b, ui32 repeat); + void sse2_rct_backward( + const line_buf *y, const line_buf *cb, const line_buf *cr, + line_buf *r, line_buf *g, line_buf *b, ui32 repeat); ////////////////////////////////////////////////////////////////////////// // @@ -232,12 +237,14 @@ namespace ojph { int shift, ui32 width); ////////////////////////////////////////////////////////////////////////// - void avx2_rct_forward(const si32 *r, 
const si32 *g, const si32 *b, - si32 *y, si32 *cb, si32 *cr, ui32 repeat); + void avx2_rct_forward( + const line_buf *r, const line_buf *g, const line_buf *b, + line_buf *y, line_buf *cb, line_buf *cr, ui32 repeat); ////////////////////////////////////////////////////////////////////////// - void avx2_rct_backward(const si32 *y, const si32 *cb, const si32 *cr, - si32 *r, si32 *g, si32 *b, ui32 repeat); + void avx2_rct_backward( + const line_buf *y, const line_buf *cb, const line_buf *cr, + line_buf *r, line_buf *g, line_buf *b, ui32 repeat); ////////////////////////////////////////////////////////////////////////// // @@ -272,12 +279,14 @@ namespace ojph { int shift, ui32 width); ////////////////////////////////////////////////////////////////////////// - void wasm_rct_forward(const si32 *r, const si32 *g, const si32 *b, - si32 *y, si32 *cb, si32 *cr, ui32 repeat); + void wasm_rct_forward( + const line_buf *r, const line_buf *g, const line_buf *b, + line_buf *y, line_buf *cb, line_buf *cr, ui32 repeat); ////////////////////////////////////////////////////////////////////////// - void wasm_rct_backward(const si32 *y, const si32 *cb, const si32 *cr, - si32 *r, si32 *g, si32 *b, ui32 repeat); + void wasm_rct_backward( + const line_buf *y, const line_buf *cb, const line_buf *cr, + line_buf *r, line_buf *g, line_buf *b, ui32 repeat); ////////////////////////////////////////////////////////////////////////// void wasm_ict_forward(const float *r, const float *g, const float *b, diff --git a/src/core/transform/ojph_colour_sse2.cpp b/src/core/transform/ojph_colour_sse2.cpp index c50c091e..3829f6a5 100644 --- a/src/core/transform/ojph_colour_sse2.cpp +++ b/src/core/transform/ojph_colour_sse2.cpp @@ -39,6 +39,7 @@ #include "ojph_defs.h" #include "ojph_arch.h" +#include "ojph_mem.h" #include "ojph_colour.h" #include @@ -46,6 +47,118 @@ namespace ojph { namespace local { + ///////////////////////////////////////////////////////////////////////// + // https://github.com/seung-lab/dijkstra3d/blob/master/libdivide.h + static inline __m128i sse2_mm_srai_epi64(__m128i a, int amt, __m128i m) + { + // note than m must be obtained using + // __m128i ve = _mm_set1_epi64x(1ULL << (63 - amt)); + __m128i x = _mm_srli_epi64(a, amt); + x = _mm_xor_si128(x, m); + __m128i result = _mm_sub_epi64(x, m); + return result; + } + + ////////////////////////////////////////////////////////////////////////// + static inline __m128i sse2_cvtlo_epi32_epi64(__m128i a, __m128i zero) + { + __m128i s, t; + s = _mm_unpacklo_epi32(a, zero); // missing extended -ve + t = _mm_cmplt_epi32(a, zero); // get -ve + t = _mm_unpacklo_epi32(zero, t); + s = _mm_or_si128(t, s); // put -ve + return s; + } + + ////////////////////////////////////////////////////////////////////////// + static inline __m128i sse2_cvthi_epi32_epi64(__m128i a, __m128i zero) + { + __m128i s, t; + s = _mm_unpackhi_epi32(a, zero); // missing extended -ve + t = _mm_cmplt_epi32(a, zero); // get -ve + t = _mm_unpackhi_epi32(zero, t); + s = _mm_or_si128(t, s); // put -ve + return s; + } + + ////////////////////////////////////////////////////////////////////////// + void sse2_rev_convert(const line_buf *src_line, + const ui32 src_line_offset, + line_buf *dst_line, + const ui32 dst_line_offset, + si64 shift, ui32 width) + { + if (src_line->flags & line_buf::LFT_32BIT) + { + if (dst_line->flags & line_buf::LFT_32BIT) + { + const si32 *sp = src_line->i32 + src_line_offset; + si32 *dp = dst_line->i32 + dst_line_offset; + si32 s = (si32)shift; + for (ui32 i = width; i > 0; --i) + 
*dp++ = *sp++ + s; + } + else + { + const si32 *sp = src_line->i32 + src_line_offset; + si64 *dp = dst_line->i64 + dst_line_offset; + for (ui32 i = width; i > 0; --i) + *dp++ = *sp++ + shift; + } + } + else + { + assert(src_line->flags | line_buf::LFT_64BIT); + assert(dst_line->flags | line_buf::LFT_32BIT); + const si64 *sp = src_line->i64 + src_line_offset; + si32 *dp = dst_line->i32 + dst_line_offset; + for (ui32 i = width; i > 0; --i) + *dp++ = (si32)(*sp++ + shift); + } + } + + ////////////////////////////////////////////////////////////////////////// + void sse2_rev_convert_nlt_type3(const line_buf *src_line, + const ui32 src_line_offset, + line_buf *dst_line, + const ui32 dst_line_offset, + si64 shift, ui32 width) + { + if (src_line->flags & line_buf::LFT_32BIT) + { + if (dst_line->flags & line_buf::LFT_32BIT) + { + const si32 *sp = src_line->i32 + src_line_offset; + si32 *dp = dst_line->i32 + dst_line_offset; + si32 s = (si32)shift; + for (ui32 i = width; i > 0; --i) { + const si32 v = *sp++; + *dp++ = v >= 0 ? v : (- v - s); + } + } + else + { + const si32 *sp = src_line->i32 + src_line_offset; + si64 *dp = dst_line->i64 + dst_line_offset; + for (ui32 i = width; i > 0; --i) { + const si64 v = *sp++; + *dp++ = v >= 0 ? v : (- v - shift); + } + } + } + else + { + assert(src_line->flags | line_buf::LFT_64BIT); + assert(dst_line->flags | line_buf::LFT_32BIT); + const si64 *sp = src_line->i64 + src_line_offset; + si32 *dp = dst_line->i32 + dst_line_offset; + for (ui32 i = width; i > 0; --i) { + const si64 v = *sp++; + *dp++ = (si32)(v >= 0 ? v : (- v - shift)); + } + } + } + ////////////////////////////////////////////////////////////////////////// void sse2_cnvrt_float_to_si32_shftd(const float *sp, si32 *dp, float mul, ui32 width) @@ -80,80 +193,199 @@ namespace ojph { _MM_SET_ROUNDING_MODE(rounding_mode); } - ////////////////////////////////////////////////////////////////////////// - void sse2_cnvrt_si32_to_si32_shftd(const si32 *sp, si32 *dp, int shift, - ui32 width) + void sse2_rct_forward(const line_buf *r, + const line_buf *g, + const line_buf *b, + line_buf *y, line_buf *cb, line_buf *cr, + ui32 repeat) { - __m128i sh = _mm_set1_epi32(shift); - for (int i = (width + 3) >> 2; i > 0; --i, sp+=4, dp+=4) + assert((y->flags & line_buf::LFT_REVERSIBLE) && + (cb->flags & line_buf::LFT_REVERSIBLE) && + (cr->flags & line_buf::LFT_REVERSIBLE) && + (r->flags & line_buf::LFT_REVERSIBLE) && + (g->flags & line_buf::LFT_REVERSIBLE) && + (b->flags & line_buf::LFT_REVERSIBLE)); + + if (y->flags & line_buf::LFT_32BIT) { - __m128i s = _mm_loadu_si128((__m128i*)sp); - s = _mm_add_epi32(s, sh); - _mm_storeu_si128((__m128i*)dp, s); - } - } + assert((y->flags & line_buf::LFT_32BIT) && + (cb->flags & line_buf::LFT_32BIT) && + (cr->flags & line_buf::LFT_32BIT) && + (r->flags & line_buf::LFT_32BIT) && + (g->flags & line_buf::LFT_32BIT) && + (b->flags & line_buf::LFT_32BIT)); + const si32 *rp = r->i32, * gp = g->i32, * bp = b->i32; + si32 *yp = y->i32, * cbp = cb->i32, * crp = cr->i32; + for (int i = (repeat + 3) >> 2; i > 0; --i) + { + __m128i mr = _mm_load_si128((__m128i*)rp); + __m128i mg = _mm_load_si128((__m128i*)gp); + __m128i mb = _mm_load_si128((__m128i*)bp); + __m128i t = _mm_add_epi32(mr, mb); + t = _mm_add_epi32(t, _mm_slli_epi32(mg, 1)); + _mm_store_si128((__m128i*)yp, _mm_srai_epi32(t, 2)); + t = _mm_sub_epi32(mb, mg); + _mm_store_si128((__m128i*)cbp, t); + t = _mm_sub_epi32(mr, mg); + _mm_store_si128((__m128i*)crp, t); - 
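
// Review note: the reversible colour transform (RCT) computed above, and
// inverted in sse2_rct_backward below, is, per sample:
//
//   forward:  Y = (R + 2*G + B) >> 2;    Cb = B - G;    Cr = R - G;
//   inverse:  G = Y - ((Cb + Cr) >> 2);  B = Cb + G;    R = Cr + G;
//
// Hedged scalar reference (name hypothetical):
//
//   static inline void rct_forward_scalar(si64 r, si64 g, si64 b,
//                                         si64& y, si64& cb, si64& cr) {
//     y  = (r + 2 * g + b) >> 2;  // floor division, matches srai
//     cb = b - g;
//     cr = r - g;
//   }
//
// The 64-bit branch that follows widens si32 samples to si64 before the
// same arithmetic. SSE2 lacks a sign-extending 32-to-64 conversion
// (_mm_cvtepi32_epi64 is SSE4.1), hence the sse2_cvtlo_epi32_epi64 and
// sse2_cvthi_epi32_epi64 helpers: unpack against zero, build the sign
// mask with _mm_cmplt_epi32, and OR it into the upper 32-bit halves.
// The 64-bit arithmetic right shift is likewise emulated by
// sse2_mm_srai_epi64 with the precomputed constant v2 = 1 << (63 - 2).
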
////////////////////////////////////////////////////////////////////////// - void sse2_cnvrt_si32_to_si32_nlt_type3(const si32* sp, si32* dp, - int shift, ui32 width) - { - __m128i sh = _mm_set1_epi32(-shift); - __m128i zero = _mm_setzero_si128(); - for (int i = (width + 3) >> 2; i > 0; --i, sp += 4, dp += 4) + rp += 4; gp += 4; bp += 4; + yp += 4; cbp += 4; crp += 4; + } + } + else { - __m128i s = _mm_loadu_si128((__m128i*)sp); - __m128i c = _mm_cmplt_epi32(s, zero); // 0xFFFFFFFF for -ve value - __m128i v_m_sh = _mm_sub_epi32(sh, s); // - shift - value - v_m_sh = _mm_and_si128(c, v_m_sh); // keep only - shift - value - s = _mm_andnot_si128(c, s); // keep only +ve or 0 - s = _mm_or_si128(s, v_m_sh); // combine - _mm_storeu_si128((__m128i*)dp, s); + assert((y->flags & line_buf::LFT_64BIT) && + (cb->flags & line_buf::LFT_64BIT) && + (cr->flags & line_buf::LFT_64BIT) && + (r->flags & line_buf::LFT_32BIT) && + (g->flags & line_buf::LFT_32BIT) && + (b->flags & line_buf::LFT_32BIT)); + __m128i zero = _mm_setzero_si128(); + __m128i v2 = _mm_set1_epi64x(1ULL << (63 - 2)); + const si32 *rp = r->i32, *gp = g->i32, *bp = b->i32; + si64 *yp = y->i64, *cbp = cb->i64, *crp = cr->i64; + for (int i = (repeat + 3) >> 2; i > 0; --i) + { + __m128i mr32 = _mm_load_si128((__m128i*)rp); + __m128i mg32 = _mm_load_si128((__m128i*)gp); + __m128i mb32 = _mm_load_si128((__m128i*)bp); + __m128i mr, mg, mb, t; + mr = sse2_cvtlo_epi32_epi64(mr32, zero); + mg = sse2_cvtlo_epi32_epi64(mg32, zero); + mb = sse2_cvtlo_epi32_epi64(mb32, zero); + + t = _mm_add_epi64(mr, mb); + t = _mm_add_epi64(t, _mm_slli_epi64(mg, 1)); + _mm_store_si128((__m128i*)yp, sse2_mm_srai_epi64(t, 2, v2)); + t = _mm_sub_epi64(mb, mg); + _mm_store_si128((__m128i*)cbp, t); + t = _mm_sub_epi64(mr, mg); + _mm_store_si128((__m128i*)crp, t); + + yp += 2; cbp += 2; crp += 2; + + mr = sse2_cvthi_epi32_epi64(mr32, zero); + mg = sse2_cvthi_epi32_epi64(mg32, zero); + mb = sse2_cvthi_epi32_epi64(mb32, zero); + + t = _mm_add_epi64(mr, mb); + t = _mm_add_epi64(t, _mm_slli_epi64(mg, 1)); + _mm_store_si128((__m128i*)yp, sse2_mm_srai_epi64(t, 2, v2)); + t = _mm_sub_epi64(mb, mg); + _mm_store_si128((__m128i*)cbp, t); + t = _mm_sub_epi64(mr, mg); + _mm_store_si128((__m128i*)crp, t); + + rp += 4; gp += 4; bp += 4; + yp += 2; cbp += 2; crp += 2; + } } } ////////////////////////////////////////////////////////////////////////// - void sse2_rct_forward(const si32 *r, const si32 *g, const si32 *b, - si32 *y, si32 *cb, si32 *cr, ui32 repeat) + void sse2_rct_backward(const line_buf *y, + const line_buf *cb, + const line_buf *cr, + line_buf *r, line_buf *g, line_buf *b, + ui32 repeat) { - for (int i = (repeat + 3) >> 2; i > 0; --i) + assert((y->flags & line_buf::LFT_REVERSIBLE) && + (cb->flags & line_buf::LFT_REVERSIBLE) && + (cr->flags & line_buf::LFT_REVERSIBLE) && + (r->flags & line_buf::LFT_REVERSIBLE) && + (g->flags & line_buf::LFT_REVERSIBLE) && + (b->flags & line_buf::LFT_REVERSIBLE)); + + if (y->flags & line_buf::LFT_32BIT) { - __m128i mr = _mm_load_si128((__m128i*)r); - __m128i mg = _mm_load_si128((__m128i*)g); - __m128i mb = _mm_load_si128((__m128i*)b); - __m128i t = _mm_add_epi32(mr, mb); - t = _mm_add_epi32(t, _mm_slli_epi32(mg, 1)); - _mm_store_si128((__m128i*)y, _mm_srai_epi32(t, 2)); - t = _mm_sub_epi32(mb, mg); - _mm_store_si128((__m128i*)cb, t); - t = _mm_sub_epi32(mr, mg); - _mm_store_si128((__m128i*)cr, t); - - r += 4; g += 4; b += 4; - y += 4; cb += 4; cr += 4; - } - } + assert((y->flags & line_buf::LFT_32BIT) && + (cb->flags & line_buf::LFT_32BIT) && + 
(cr->flags & line_buf::LFT_32BIT) && + (r->flags & line_buf::LFT_32BIT) && + (g->flags & line_buf::LFT_32BIT) && + (b->flags & line_buf::LFT_32BIT)); + const si32 *yp = y->i32, *cbp = cb->i32, *crp = cr->i32; + si32 *rp = r->i32, *gp = g->i32, *bp = b->i32; + for (int i = (repeat + 3) >> 2; i > 0; --i) + { + __m128i my = _mm_load_si128((__m128i*)yp); + __m128i mcb = _mm_load_si128((__m128i*)cbp); + __m128i mcr = _mm_load_si128((__m128i*)crp); - ////////////////////////////////////////////////////////////////////////// - void sse2_rct_backward(const si32 *y, const si32 *cb, const si32 *cr, - si32 *r, si32 *g, si32 *b, ui32 repeat) - { - for (int i = (repeat + 3) >> 2; i > 0; --i) + __m128i t = _mm_add_epi32(mcb, mcr); + t = _mm_sub_epi32(my, _mm_srai_epi32(t, 2)); + _mm_store_si128((__m128i*)gp, t); + __m128i u = _mm_add_epi32(mcb, t); + _mm_store_si128((__m128i*)bp, u); + u = _mm_add_epi32(mcr, t); + _mm_store_si128((__m128i*)rp, u); + + yp += 4; cbp += 4; crp += 4; + rp += 4; gp += 4; bp += 4; + } + } + else { - __m128i my = _mm_load_si128((__m128i*)y); - __m128i mcb = _mm_load_si128((__m128i*)cb); - __m128i mcr = _mm_load_si128((__m128i*)cr); - - __m128i t = _mm_add_epi32(mcb, mcr); - t = _mm_sub_epi32(my, _mm_srai_epi32(t, 2)); - _mm_store_si128((__m128i*)g, t); - __m128i u = _mm_add_epi32(mcb, t); - _mm_store_si128((__m128i*)b, u); - u = _mm_add_epi32(mcr, t); - _mm_store_si128((__m128i*)r, u); - - y += 4; cb += 4; cr += 4; - r += 4; g += 4; b += 4; + assert((y->flags & line_buf::LFT_64BIT) && + (cb->flags & line_buf::LFT_64BIT) && + (cr->flags & line_buf::LFT_64BIT) && + (r->flags & line_buf::LFT_32BIT) && + (g->flags & line_buf::LFT_32BIT) && + (b->flags & line_buf::LFT_32BIT)); + __m128i v2 = _mm_set1_epi64x(1ULL << (63 - 2)); + __m128i low_bits = _mm_set_epi64x(0, 0xFFFFFFFFFFFFFFFFLL); + const si64 *yp = y->i64, *cbp = cb->i64, *crp = cr->i64; + si32 *rp = r->i32, *gp = g->i32, *bp = b->i32; + for (int i = (repeat + 3) >> 2; i > 0; --i) + { + __m128i my, mcb, mcr, tr, tg, tb; + my = _mm_load_si128((__m128i*)yp); + mcb = _mm_load_si128((__m128i*)cbp); + mcr = _mm_load_si128((__m128i*)crp); + + tg = _mm_add_epi64(mcb, mcr); + tg = _mm_sub_epi64(my, sse2_mm_srai_epi64(tg, 2, v2)); + tb = _mm_add_epi64(mcb, tg); + tr = _mm_add_epi64(mcr, tg); + + __m128i mr, mg, mb; + mr = _mm_shuffle_epi32(tr, _MM_SHUFFLE(0, 0, 2, 0)); + mr = _mm_and_si128(low_bits, mr); + mg = _mm_shuffle_epi32(tg, _MM_SHUFFLE(0, 0, 2, 0)); + mg = _mm_and_si128(low_bits, mg); + mb = _mm_shuffle_epi32(tb, _MM_SHUFFLE(0, 0, 2, 0)); + mb = _mm_and_si128(low_bits, mb); + + yp += 2; cbp += 2; crp += 2; + + my = _mm_load_si128((__m128i*)yp); + mcb = _mm_load_si128((__m128i*)cbp); + mcr = _mm_load_si128((__m128i*)crp); + + tg = _mm_add_epi64(mcb, mcr); + tg = _mm_sub_epi64(my, sse2_mm_srai_epi64(tg, 2, v2)); + tb = _mm_add_epi64(mcb, tg); + tr = _mm_add_epi64(mcr, tg); + + tr = _mm_shuffle_epi32(tr, _MM_SHUFFLE(2, 0, 0, 0)); + tr = _mm_andnot_si128(low_bits, tr); + mr = _mm_or_si128(mr, tr); + tg = _mm_shuffle_epi32(tg, _MM_SHUFFLE(2, 0, 0, 0)); + tg = _mm_andnot_si128(low_bits, tg); + mg = _mm_or_si128(mg, tg); + tb = _mm_shuffle_epi32(tb, _MM_SHUFFLE(2, 0, 0, 0)); + tb = _mm_andnot_si128(low_bits, tb); + mb = _mm_or_si128(mb, tb); + + _mm_store_si128((__m128i*)rp, mr); + _mm_store_si128((__m128i*)gp, mg); + _mm_store_si128((__m128i*)bp, mb); + + yp += 2; cbp += 2; crp += 2; + rp += 4; gp += 4; bp += 4; + } } } diff --git a/src/core/transform/ojph_transform.cpp b/src/core/transform/ojph_transform.cpp index 
32189e56..c4313ab2 100644 --- a/src/core/transform/ojph_transform.cpp +++ b/src/core/transform/ojph_transform.cpp @@ -127,14 +127,14 @@ namespace ojph { } #endif // !OJPH_DISABLE_SSE - // #ifndef OJPH_DISABLE_SSE2 - // if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_SSE2) - // { - // rev_vert_step = sse2_rev_vert_step; - // rev_horz_ana = sse2_rev_horz_ana; - // rev_horz_syn = sse2_rev_horz_syn; - // } - // #endif // !OJPH_DISABLE_SSE2 + #ifndef OJPH_DISABLE_SSE2 + if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_SSE2) + { + rev_vert_step = sse2_rev_vert_step; + rev_horz_ana = sse2_rev_horz_ana; + rev_horz_syn = sse2_rev_horz_syn; + } + #endif // !OJPH_DISABLE_SSE2 #ifndef OJPH_DISABLE_AVX if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX) @@ -146,14 +146,14 @@ namespace ojph { } #endif // !OJPH_DISABLE_AVX - // #ifndef OJPH_DISABLE_AVX2 - // if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX2) - // { - // rev_vert_step = avx2_rev_vert_step; - // rev_horz_ana = avx2_rev_horz_ana; - // rev_horz_syn = avx2_rev_horz_syn; - // } - // #endif // !OJPH_DISABLE_AVX2 + #ifndef OJPH_DISABLE_AVX2 + if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX2) + { + rev_vert_step = avx2_rev_vert_step; + rev_horz_ana = avx2_rev_horz_ana; + rev_horz_syn = avx2_rev_horz_syn; + } + #endif // !OJPH_DISABLE_AVX2 #if (defined(OJPH_ARCH_X86_64) && !defined(OJPH_DISABLE_AVX512)) if (get_cpu_ext_level() >= X86_CPU_EXT_LEVEL_AVX512) @@ -194,6 +194,7 @@ namespace ojph { #if !defined(OJPH_ENABLE_WASM_SIMD) || !defined(OJPH_EMSCRIPTEN) ///////////////////////////////////////////////////////////////////////// + static void gen_rev_vert_step32(const lifting_step* s, const line_buf* sig, const line_buf* other, const line_buf* aug, ui32 repeat, bool synthesis) @@ -245,6 +246,7 @@ namespace ojph { } ///////////////////////////////////////////////////////////////////////// + static void gen_rev_vert_step64(const lifting_step* s, const line_buf* sig, const line_buf* other, const line_buf* aug, ui32 repeat, bool synthesis) @@ -319,6 +321,7 @@ namespace ojph { } ///////////////////////////////////////////////////////////////////////// + static void gen_rev_horz_ana32(const param_atk* atk, const line_buf* ldst, const line_buf* hdst, const line_buf* src, ui32 width, bool even) @@ -397,6 +400,7 @@ namespace ojph { } ///////////////////////////////////////////////////////////////////////// + static void gen_rev_horz_ana64(const param_atk* atk, const line_buf* ldst, const line_buf* hdst, const line_buf* src, ui32 width, bool even) @@ -495,6 +499,7 @@ namespace ojph { } ////////////////////////////////////////////////////////////////////////// + static void gen_rev_horz_syn32(const param_atk* atk, const line_buf* dst, const line_buf* lsrc, const line_buf* hsrc, ui32 width, bool even) @@ -573,6 +578,7 @@ namespace ojph { } ////////////////////////////////////////////////////////////////////////// + static void gen_rev_horz_syn64(const param_atk* atk, const line_buf* dst, const line_buf* lsrc, const line_buf* hsrc, ui32 width, bool even) diff --git a/src/core/transform/ojph_transform_avx.cpp b/src/core/transform/ojph_transform_avx.cpp index 08566624..4e5b82e7 100644 --- a/src/core/transform/ojph_transform_avx.cpp +++ b/src/core/transform/ojph_transform_avx.cpp @@ -61,6 +61,76 @@ namespace ojph { } } + ////////////////////////////////////////////////////////////////////////// + static inline + void avx_deinterleave32(float* dpl, float* dph, float* sp, + int width, bool even) + { + if (even) + { + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 
8) + { + __m256 a = _mm256_load_ps(sp); + __m256 b = _mm256_load_ps(sp + 8); + __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); + __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); + __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); + __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); + _mm256_store_ps(dpl, e); + _mm256_store_ps(dph, f); + } + } + else + { + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) + { + __m256 a = _mm256_load_ps(sp); + __m256 b = _mm256_load_ps(sp + 8); + __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); + __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); + __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); + __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); + _mm256_store_ps(dpl, f); + _mm256_store_ps(dph, e); + } + } + } + + ////////////////////////////////////////////////////////////////////////// + static inline + void avx_interleave32(float* dp, float* spl, float* sph, + int width, bool even) + { + if (even) + { + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) + { + __m256 a = _mm256_load_ps(spl); + __m256 b = _mm256_load_ps(sph); + __m256 c = _mm256_unpacklo_ps(a, b); + __m256 d = _mm256_unpackhi_ps(a, b); + __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); + __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); + _mm256_store_ps(dp, e); + _mm256_store_ps(dp + 8, f); + } + } + else + { + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) + { + __m256 a = _mm256_load_ps(spl); + __m256 b = _mm256_load_ps(sph); + __m256 c = _mm256_unpacklo_ps(b, a); + __m256 d = _mm256_unpackhi_ps(b, a); + __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); + __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); + _mm256_store_ps(dp, e); + _mm256_store_ps(dp + 8, f); + } + } + } + ////////////////////////////////////////////////////////////////////////// void avx_irv_vert_step(const lifting_step* s, const line_buf* sig, const line_buf* other, const line_buf* aug, @@ -104,7 +174,7 @@ namespace ojph { float* dph = hdst->f32; float* sp = src->f32; int w = (int)width; - AVX_DEINTERLEAVE(dpl, dph, sp, w, even); + avx_deinterleave32(dpl, dph, sp, w, even); } // the actual horizontal transform @@ -238,7 +308,7 @@ namespace ojph { float* spl = lsrc->f32; float* sph = hsrc->f32; int w = (int)width; - AVX_INTERLEAVE(dp, spl, sph, w, even); + avx_interleave32(dp, spl, sph, w, even); } } else { diff --git a/src/core/transform/ojph_transform_avx2.cpp b/src/core/transform/ojph_transform_avx2.cpp index 847cd4c4..76a4dd71 100644 --- a/src/core/transform/ojph_transform_avx2.cpp +++ b/src/core/transform/ojph_transform_avx2.cpp @@ -52,13 +52,95 @@ namespace ojph { namespace local { ///////////////////////////////////////////////////////////////////////// - void avx2_rev_vert_step(const lifting_step* s, const line_buf* sig, - const line_buf* other, const line_buf* aug, - ui32 repeat, bool synthesis) + // https://github.com/seung-lab/dijkstra3d/blob/master/libdivide.h + static inline + __m256i avx2_mm256_srai_epi64(__m256i a, int amt, __m256i m) + { + // note than m must be obtained using + // __m256i ve = _mm256_set1_epi64x(1ULL << (63 - amt)); + __m256i x = _mm256_srli_epi64(a, amt); + x = _mm256_xor_si256(x, m); + __m256i result = _mm256_sub_epi64(x, m); + return result; + } + + ////////////////////////////////////////////////////////////////////////// + static inline + void avx2_deinterleave32(float* dpl, float* dph, float* sp, int width) + { + for (; width > 0; width -= 
16, sp += 16, dpl += 8, dph += 8) + { + __m256 a = _mm256_load_ps(sp); + __m256 b = _mm256_load_ps(sp + 8); + __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); + __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); + __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); + __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); + _mm256_store_ps(dpl, e); + _mm256_store_ps(dph, f); + } + } + + ////////////////////////////////////////////////////////////////////////// + static inline + void avx2_interleave32(float* dp, float* spl, float* sph, int width) + { + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) + { + __m256 a = _mm256_load_ps(spl); + __m256 b = _mm256_load_ps(sph); + __m256 c = _mm256_unpacklo_ps(a, b); + __m256 d = _mm256_unpackhi_ps(a, b); + __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); + __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); + _mm256_store_ps(dp, e); + _mm256_store_ps(dp + 8, f); + } + } + + ////////////////////////////////////////////////////////////////////////// + static inline + void avx2_deinterleave64(double* dpl, double* dph, double* sp, int width) + { + for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4) + { + __m256d a = _mm256_load_pd(sp); + __m256d b = _mm256_load_pd(sp + 4); + __m256d c = _mm256_permute2f128_pd(a, b, (2 << 4) | (0)); + __m256d d = _mm256_permute2f128_pd(a, b, (3 << 4) | (1)); + __m256d e = _mm256_shuffle_pd(c, d, 0x0); + __m256d f = _mm256_shuffle_pd(c, d, 0xF); + _mm256_store_pd(dpl, e); + _mm256_store_pd(dph, f); + } + } + + ////////////////////////////////////////////////////////////////////////// + static inline + void avx2_interleave64(double* dp, double* spl, double* sph, int width) + { + for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4) + { + __m256d a = _mm256_load_pd(spl); + __m256d b = _mm256_load_pd(sph); + __m256d c = _mm256_unpacklo_pd(a, b); + __m256d d = _mm256_unpackhi_pd(a, b); + __m256d e = _mm256_permute2f128_pd(c, d, (2 << 4) | (0)); + __m256d f = _mm256_permute2f128_pd(c, d, (3 << 4) | (1)); + _mm256_store_pd(dp, e); + _mm256_store_pd(dp + 4, f); + } + } + + ///////////////////////////////////////////////////////////////////////// + static + void avx2_rev_vert_step32(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) { const si32 a = s->rev.Aatk; const si32 b = s->rev.Batk; - const si32 e = s->rev.Eatk; + const ui8 e = s->rev.Eatk; __m256i va = _mm256_set1_epi32(a); __m256i vb = _mm256_set1_epi32(b); @@ -181,19 +263,174 @@ namespace ojph { } ///////////////////////////////////////////////////////////////////////// - void avx2_rev_horz_ana(const param_atk* atk, const line_buf* ldst, - const line_buf* hdst, const line_buf* src, - ui32 width, bool even) + static + void avx2_rev_vert_step64(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) + { + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const ui8 e = s->rev.Eatk; + __m256i va = _mm256_set1_epi64x(a); + __m256i vb = _mm256_set1_epi64x(b); + __m256i ve = _mm256_set1_epi64x(1ULL << (63 - e)); + + si64* dst = aug->i64; + const si64* src1 = sig->i64, * src2 = other->i64; + // The general definition of the wavelet in Part 2 is slightly + // different to part 2, although they are mathematically equivalent + // here, we identify the simpler form from Part 1 and employ them + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)repeat; + 
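
// Review note: AVX2 has no 64-bit arithmetic right shift
// (_mm256_srai_epi64 is AVX-512), so avx2_mm256_srai_epi64 emulates it:
// with m = 1ULL << (63 - amt) precomputed in ve, it computes
// ((v >> amt) ^ m) - m, which sign-extends from the bit where the old
// sign bit lands after the logical shift. Hedged scalar sketch (name
// hypothetical):
//
//   static inline si64 srai64_scalar(si64 a, int amt) {
//     ui64 m = 1ULL << (63 - amt);    // where the sign bit lands
//     ui64 x = ((ui64)a >> amt) ^ m;  // logical shift, flip that bit
//     return (si64)(x - m);           // the borrow propagates the sign
//   }
//
// The branches of this function mirror the established 32-bit version:
// a == 1 covers the 5/3 update, adjusting d by (b + s1 + s2) >> e
// (subtracted for synthesis, added for analysis); a == -1, b == 1,
// e == 1 is the 5/3 predict; the general case multiplies by Aatk first.
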
if (synthesis) + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + else + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + else + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + else + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + } + else { // general case + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dst, d); + } + else + for (; i > 0; i -= 4, dst += 4, src1 += 4, src2 += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)src1); + __m256i s2 = _mm256_load_si256((__m256i*)src2); + __m256i d = _mm256_load_si256((__m256i*)dst); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, 
w); + _mm256_store_si256((__m256i*)dst, d); + } + } + } + + ///////////////////////////////////////////////////////////////////////// + void avx2_rev_vert_step(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) + { + if (((sig != NULL) && (sig->flags & line_buf::LFT_32BIT)) || + ((aug != NULL) && (aug->flags & line_buf::LFT_32BIT)) || + ((other != NULL) && (other->flags & line_buf::LFT_32BIT))) + { + assert((sig == NULL || sig->flags & line_buf::LFT_32BIT) && + (other == NULL || other->flags & line_buf::LFT_32BIT) && + (aug == NULL || aug->flags & line_buf::LFT_32BIT)); + avx2_rev_vert_step32(s, sig, other, aug, repeat, synthesis); + } + else + { + assert((sig == NULL || sig->flags & line_buf::LFT_64BIT) && + (other == NULL || other->flags & line_buf::LFT_64BIT) && + (aug == NULL || aug->flags & line_buf::LFT_64BIT)); + avx2_rev_vert_step64(s, sig, other, aug, repeat, synthesis); + } + } + + ///////////////////////////////////////////////////////////////////////// + static + void avx2_rev_horz_ana32(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) { if (width > 1) { // combine both lsrc and hsrc into dst { - float* dpl = ldst->f32; - float* dph = hdst->f32; - float* sp = src->f32; + float* dpl = even ? ldst->f32 : hdst->f32; + float* dph = even ? hdst->f32 : ldst->f32; + float* sp = src->f32; int w = (int)width; - AVX_DEINTERLEAVE(dpl, dph, sp, w, even); + avx2_deinterleave32(dpl, dph, sp, w); } si32* hp = hdst->i32, * lp = ldst->i32; @@ -206,7 +443,7 @@ namespace ojph { const lifting_step* s = atk->get_step(j - 1); const si32 a = s->rev.Aatk; const si32 b = s->rev.Batk; - const si32 e = s->rev.Eatk; + const ui8 e = s->rev.Eatk; __m256i va = _mm256_set1_epi32(a); __m256i vb = _mm256_set1_epi32(b); @@ -346,11 +583,201 @@ namespace ojph { hdst->i32[0] = src->i32[0] << 1; } } + + ///////////////////////////////////////////////////////////////////////// + static + void avx2_rev_horz_ana64(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) + { + if (width > 1) + { + // combine both lsrc and hsrc into dst + { + double* dpl = (double*)(even ? ldst->p : hdst->p); + double* dph = (double*)(even ? hdst->p : ldst->p); + double* sp = (double*)src->p; + int w = (int)width; + avx2_deinterleave64(dpl, dph, sp, w); + } + + si64* hp = hdst->i64, * lp = ldst->i64; + ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 h_width = (width + (even ? 
0 : 1)) >> 1; // high pass + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = num_steps; j > 0; --j) + { + // first lifting step + const lifting_step* s = atk->get_step(j - 1); + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const ui8 e = s->rev.Eatk; + __m256i va = _mm256_set1_epi64x(a); + __m256i vb = _mm256_set1_epi64x(b); + __m256i ve = _mm256_set1_epi64x(1ULL << (63 - e)); + + // extension + lp[-1] = lp[0]; + lp[l_width] = lp[l_width - 1]; + // lifting step + const si64* sp = lp; + si64* dp = hp; + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)h_width; + if (even) + { + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else + { + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)h_width; + if (even) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)h_width; + if (even) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else { + // general case + int i = (int)h_width; + if (even) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + 
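
// Review note: _mm256_mullo_epi64, used in this general case and in
// avx2_rev_vert_step64, is an AVX-512DQ/VL intrinsic rather than plain
// AVX2, so as written this path appears to assume those extensions at
// compile time. A pure-AVX2 fallback is known (after Agner Fog's vector
// class library); hedged sketch, name hypothetical:
//
//   static inline __m256i mullo_epi64_avx2(__m256i a, __m256i b) {
//     __m256i bswap = _mm256_shuffle_epi32(b, 0xB1);  // swap 32-bit halves
//     __m256i cross = _mm256_mullo_epi32(a, bswap);   // lo*hi and hi*lo
//     __m256i zero  = _mm256_setzero_si256();
//     __m256i sum   = _mm256_hadd_epi32(cross, zero); // add cross terms
//     sum = _mm256_shuffle_epi32(sum, 0x73);          // into upper halves
//     __m256i lolo  = _mm256_mul_epu32(a, b);         // full 64-bit lo*lo
//     return _mm256_add_epi64(lolo, sum);             // assemble low 64
//   }
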
__m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + + // swap buffers + si64* t = lp; lp = hp; hp = t; + even = !even; + ui32 w = l_width; l_width = h_width; h_width = w; + } + } + else { + if (even) + ldst->i64[0] = src->i64[0]; + else + hdst->i64[0] = src->i64[0] << 1; + } + } + + ///////////////////////////////////////////////////////////////////////// + void avx2_rev_horz_ana(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) + { + if (src->flags & line_buf::LFT_32BIT) + { + assert((ldst == NULL || ldst->flags & line_buf::LFT_32BIT) && + (hdst == NULL || hdst->flags & line_buf::LFT_32BIT)); + avx2_rev_horz_ana32(atk, ldst, hdst, src, width, even); + } + else + { + assert((ldst == NULL || ldst->flags & line_buf::LFT_64BIT) && + (hdst == NULL || hdst->flags & line_buf::LFT_64BIT) && + (src == NULL || src->flags & line_buf::LFT_64BIT)); + avx2_rev_horz_ana64(atk, ldst, hdst, src, width, even); + } + } ////////////////////////////////////////////////////////////////////////// - void avx2_rev_horz_syn(const param_atk* atk, const line_buf* dst, - const line_buf* lsrc, const line_buf* hsrc, - ui32 width, bool even) + static + void avx2_rev_horz_syn32(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) { if (width > 1) { @@ -364,7 +791,7 @@ namespace ojph { const lifting_step* s = atk->get_step(j); const si32 a = s->rev.Aatk; const si32 b = s->rev.Batk; - const si32 e = s->rev.Eatk; + const ui8 e = s->rev.Eatk; __m256i va = _mm256_set1_epi32(a); __m256i vb = _mm256_set1_epi32(b); @@ -499,11 +926,11 @@ namespace ojph { // combine both lsrc and hsrc into dst { - float* dp = dst->f32; - float* spl = lsrc->f32; - float* sph = hsrc->f32; + float* dp = dst->f32; + float* spl = even ? lsrc->f32 : hsrc->f32; + float* sph = even ? hsrc->f32 : lsrc->f32; int w = (int)width; - AVX_INTERLEAVE(dp, spl, sph, w, even); + avx2_interleave32(dp, spl, sph, w); } } else { @@ -514,5 +941,194 @@ namespace ojph { } } + ////////////////////////////////////////////////////////////////////////// + static + void avx2_rev_horz_syn64(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) + { + if (width > 1) + { + bool ev = even; + si64* oth = hsrc->i64, * aug = lsrc->i64; + ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 oth_width = (width + (even ? 
0 : 1)) >> 1; // high pass + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = 0; j < num_steps; ++j) + { + const lifting_step* s = atk->get_step(j); + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const ui8 e = s->rev.Eatk; + __m256i va = _mm256_set1_epi64x(a); + __m256i vb = _mm256_set1_epi64x(b); + __m256i ve = _mm256_set1_epi64x(1ULL << (63 - e)); + + // extension + oth[-1] = oth[0]; + oth[oth_width] = oth[oth_width - 1]; + // lifting step + const si64* sp = oth; + si64* dp = aug; + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)aug_width; + if (ev) + { + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else + { + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_add_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i w = avx2_mm256_srai_epi64(t, e, ve); + d = _mm256_add_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i v = _mm256_sub_epi64(vb, t); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + else { + // general case + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp - 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + __m256i w = 
avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + else + for (; i > 0; i -= 4, sp += 4, dp += 4) + { + __m256i s1 = _mm256_load_si256((__m256i*)sp); + __m256i s2 = _mm256_loadu_si256((__m256i*)(sp + 1)); + __m256i d = _mm256_load_si256((__m256i*)dp); + __m256i t = _mm256_add_epi64(s1, s2); + __m256i u = _mm256_mullo_epi64(va, t); + __m256i v = _mm256_add_epi64(vb, u); + __m256i w = avx2_mm256_srai_epi64(v, e, ve); + d = _mm256_sub_epi64(d, w); + _mm256_store_si256((__m256i*)dp, d); + } + } + + // swap buffers + si64* t = aug; aug = oth; oth = t; + ev = !ev; + ui32 w = aug_width; aug_width = oth_width; oth_width = w; + } + + // combine both lsrc and hsrc into dst + { + double* dp = (double*)dst->p; + double* spl = (double*)(even ? lsrc->p : hsrc->p); + double* sph = (double*)(even ? hsrc->p : lsrc->p); + int w = (int)width; + avx2_interleave64(dp, spl, sph, w); + } + } + else { + if (even) + dst->i64[0] = lsrc->i64[0]; + else + dst->i64[0] = hsrc->i64[0] >> 1; + } + } + + ///////////////////////////////////////////////////////////////////////// + void avx2_rev_horz_syn(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) + { + if (dst->flags & line_buf::LFT_32BIT) + { + assert((lsrc == NULL || lsrc->flags & line_buf::LFT_32BIT) && + (hsrc == NULL || hsrc->flags & line_buf::LFT_32BIT)); + avx2_rev_horz_syn32(atk, dst, lsrc, hsrc, width, even); + } + else + { + assert((dst == NULL || dst->flags & line_buf::LFT_64BIT) && + (lsrc == NULL || lsrc->flags & line_buf::LFT_64BIT) && + (hsrc == NULL || hsrc->flags & line_buf::LFT_64BIT)); + avx2_rev_horz_syn64(atk, dst, lsrc, hsrc, width, even); + } + } + } // !local } // !ojph diff --git a/src/core/transform/ojph_transform_local.h b/src/core/transform/ojph_transform_local.h index c139ca00..5406124c 100644 --- a/src/core/transform/ojph_transform_local.h +++ b/src/core/transform/ojph_transform_local.h @@ -112,7 +112,7 @@ namespace ojph { ////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////// - #define SSE_DEINTERLEAVE(dpl, dph, sp, width, even) \ + #define SSE_DEINTERLEAVE32(dpl, dph, sp, width, even) \ { \ if (even) \ for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4) \ @@ -134,10 +134,10 @@ namespace ojph { _mm_store_ps(dpl, d); \ _mm_store_ps(dph, c); \ } \ - } + } ////////////////////////////////////////////////////////////////////////// - #define SSE_INTERLEAVE(dp, spl, sph, width, even) \ + #define SSE_INTERLEAVE32(dp, spl, sph, width, even) \ { \ if (even) \ for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4) \ @@ -219,76 +219,6 @@ namespace ojph { // ////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////// - // Supporting macros - ////////////////////////////////////////////////////////////////////////// - - ////////////////////////////////////////////////////////////////////////// - #define AVX_DEINTERLEAVE(dpl, dph, sp, width, even) \ - { \ - if (even) \ - { \ - for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) \ - { \ - __m256 a = _mm256_load_ps(sp); \ - __m256 b = _mm256_load_ps(sp + 8); \ - __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); \ - __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); \ - __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); \ - __m256 f = _mm256_shuffle_ps(c, d, 
_MM_SHUFFLE(3, 1, 3, 1)); \
-          _mm256_store_ps(dpl, e); \
-          _mm256_store_ps(dph, f); \
-        } \
-      } \
-      else \
-      { \
-        for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) \
-        { \
-          __m256 a = _mm256_load_ps(sp); \
-          __m256 b = _mm256_load_ps(sp + 8); \
-          __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); \
-          __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); \
-          __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); \
-          __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); \
-          _mm256_store_ps(dpl, f); \
-          _mm256_store_ps(dph, e); \
-        } \
-      } \
-    }
-
-    //////////////////////////////////////////////////////////////////////////
-    #define AVX_INTERLEAVE(dp, spl, sph, width, even) \
-    { \
-      if (even) \
-      { \
-        for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) \
-        { \
-          __m256 a = _mm256_load_ps(spl); \
-          __m256 b = _mm256_load_ps(sph); \
-          __m256 c = _mm256_unpacklo_ps(a, b); \
-          __m256 d = _mm256_unpackhi_ps(a, b); \
-          __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); \
-          __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); \
-          _mm256_store_ps(dp, e); \
-          _mm256_store_ps(dp + 8, f); \
-        } \
-      } \
-      else \
-      { \
-        for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) \
-        { \
-          __m256 a = _mm256_load_ps(spl); \
-          __m256 b = _mm256_load_ps(sph); \
-          __m256 c = _mm256_unpacklo_ps(b, a); \
-          __m256 d = _mm256_unpackhi_ps(b, a); \
-          __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); \
-          __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); \
-          _mm256_store_ps(dp, e); \
-          _mm256_store_ps(dp + 8, f); \
-        } \
-      } \
-    }
-
     //////////////////////////////////////////////////////////////////////////
     // Irreversible functions
     //////////////////////////////////////////////////////////////////////////
diff --git a/src/core/transform/ojph_transform_sse.cpp b/src/core/transform/ojph_transform_sse.cpp
index 897a1939..e878746d 100644
--- a/src/core/transform/ojph_transform_sse.cpp
+++ b/src/core/transform/ojph_transform_sse.cpp
@@ -104,7 +104,7 @@ namespace ojph {
         float* dph = hdst->f32;
         float* sp = src->f32;
         int w = (int)width;
-        SSE_DEINTERLEAVE(dpl, dph, sp, w, even);
+        SSE_DEINTERLEAVE32(dpl, dph, sp, w, even);
       }
 
       // the actual horizontal transform
@@ -238,7 +238,7 @@ namespace ojph {
         float* spl = lsrc->f32;
         float* sph = hsrc->f32;
         int w = (int)width;
-        SSE_INTERLEAVE(dp, spl, sph, w, even);
+        SSE_INTERLEAVE32(dp, spl, sph, w, even);
       }
     }
     else {
diff --git a/src/core/transform/ojph_transform_sse2.cpp b/src/core/transform/ojph_transform_sse2.cpp
index 8328842a..21e0409a 100644
--- a/src/core/transform/ojph_transform_sse2.cpp
+++ b/src/core/transform/ojph_transform_sse2.cpp
@@ -52,13 +52,80 @@ namespace ojph {
   namespace local {
 
     /////////////////////////////////////////////////////////////////////////
-    void sse2_rev_vert_step(const lifting_step* s, const line_buf* sig,
-                            const line_buf* other, const line_buf* aug,
-                            ui32 repeat, bool synthesis)
+    // https://github.com/seung-lab/dijkstra3d/blob/master/libdivide.h
+    static inline __m128i sse2_mm_srai_epi64(__m128i a, int amt, __m128i m)
+    {
+      // note that m must be obtained using
+      // __m128i ve = _mm_set1_epi64x(1ULL << (63 - amt));
+      __m128i x = _mm_srli_epi64(a, amt);
+      x = _mm_xor_si128(x, m);
+      __m128i result = _mm_sub_epi64(x, m);
+      return result;
+    }
+
+    //////////////////////////////////////////////////////////////////////////
+    static inline
+    void sse2_deinterleave64(double* dpl, double* dph, double* sp,
+                             int width, bool even)
+    {
+      if (even)
+        for (; width > 0; width -= 4, sp += 4, dpl 
+= 2, dph += 2)
+        {
+          __m128d a = _mm_load_pd(sp);
+          __m128d b = _mm_load_pd(sp + 2);
+          __m128d c = _mm_shuffle_pd(a, b, 0);
+          __m128d d = _mm_shuffle_pd(a, b, 3);
+          _mm_store_pd(dpl, c);
+          _mm_store_pd(dph, d);
+        }
+      else
+        for (; width > 0; width -= 4, sp += 4, dpl += 2, dph += 2)
+        {
+          __m128d a = _mm_load_pd(sp);
+          __m128d b = _mm_load_pd(sp + 2);
+          __m128d c = _mm_shuffle_pd(a, b, 0);
+          __m128d d = _mm_shuffle_pd(a, b, 3);
+          _mm_store_pd(dpl, d);
+          _mm_store_pd(dph, c);
+        }
+    }
+
+    //////////////////////////////////////////////////////////////////////////
+    static inline
+    void sse2_interleave64(double* dp, double* spl, double* sph,
+                           int width, bool even)
+    {
+      if (even)
+        for (; width > 0; width -= 4, dp += 4, spl += 2, sph += 2)
+        {
+          __m128d a = _mm_load_pd(spl);
+          __m128d b = _mm_load_pd(sph);
+          __m128d c = _mm_unpacklo_pd(a, b);
+          __m128d d = _mm_unpackhi_pd(a, b);
+          _mm_store_pd(dp, c);
+          _mm_store_pd(dp + 2, d);
+        }
+      else
+        for (; width > 0; width -= 4, dp += 4, spl += 2, sph += 2)
+        {
+          __m128d a = _mm_load_pd(spl);
+          __m128d b = _mm_load_pd(sph);
+          __m128d c = _mm_unpacklo_pd(b, a);
+          __m128d d = _mm_unpackhi_pd(b, a);
+          _mm_store_pd(dp, c);
+          _mm_store_pd(dp + 2, d);
+        }
+    }
+
+    /////////////////////////////////////////////////////////////////////////
+    static
+    void sse2_rev_vert_step32(const lifting_step* s, const line_buf* sig,
+                              const line_buf* other, const line_buf* aug,
+                              ui32 repeat, bool synthesis)
     {
       const si32 a = s->rev.Aatk;
       const si32 b = s->rev.Batk;
-      const si32 e = s->rev.Eatk;
+      const ui8 e = s->rev.Eatk;
       __m128i vb = _mm_set1_epi32(b);
 
       si32* dst = aug->i32;
@@ -162,9 +229,143 @@ namespace ojph {
     }
 
     /////////////////////////////////////////////////////////////////////////
-    void sse2_rev_horz_ana(const param_atk* atk, const line_buf* ldst,
-                           const line_buf* hdst, const line_buf* src,
-                           ui32 width, bool even)
+    static
+    void sse2_rev_vert_step64(const lifting_step* s, const line_buf* sig,
+                              const line_buf* other, const line_buf* aug,
+                              ui32 repeat, bool synthesis)
+    {
+      const si64 a = s->rev.Aatk;
+      const si64 b = s->rev.Batk;
+      const ui8 e = s->rev.Eatk;
+      __m128i vb = _mm_set1_epi64x(b);
+      __m128i ve = _mm_set1_epi64x(1ULL << (63 - e));
+
+      si64* dst = aug->i64;
+      const si64* src1 = sig->i64, * src2 = other->i64;
+      // The general definition of the wavelet in Part 2 is slightly
+      // different to that in Part 1; although they are mathematically
+      // equivalent here, we identify the simpler Part 1 form and employ it
+      if (a == 1)
+      { // 5/3 update and any case with a == 1
+        int i = (int)repeat;
+        if (synthesis)
+          for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2)
+          {
+            __m128i s1 = _mm_load_si128((__m128i*)src1);
+            __m128i s2 = _mm_load_si128((__m128i*)src2);
+            __m128i d = _mm_load_si128((__m128i*)dst);
+            __m128i t = _mm_add_epi64(s1, s2);
+            __m128i v = _mm_add_epi64(vb, t);
+            __m128i w = sse2_mm_srai_epi64(v, e, ve);
+            d = _mm_sub_epi64(d, w);
+            _mm_store_si128((__m128i*)dst, d);
+          }
+        else
+          for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2)
+          {
+            __m128i s1 = _mm_load_si128((__m128i*)src1);
+            __m128i s2 = _mm_load_si128((__m128i*)src2);
+            __m128i d = _mm_load_si128((__m128i*)dst);
+            __m128i t = _mm_add_epi64(s1, s2);
+            __m128i v = _mm_add_epi64(vb, t);
+            __m128i w = sse2_mm_srai_epi64(v, e, ve);
+            d = _mm_add_epi64(d, w);
+            _mm_store_si128((__m128i*)dst, d);
+          }
+      }
+      else if (a == -1 && b == 1 && e == 1)
+      { // 5/3 predict
+        int i = (int)repeat;
+        if (synthesis)
+          for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2)
+          {
+            __m128i s1 = 
_mm_load_si128((__m128i*)src1); + __m128i s2 = _mm_load_si128((__m128i*)src2); + __m128i d = _mm_load_si128((__m128i*)dst); + __m128i t = _mm_add_epi64(s1, s2); + __m128i w = sse2_mm_srai_epi64(t, e, ve); + d = _mm_add_epi64(d, w); + _mm_store_si128((__m128i*)dst, d); + } + else + for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2) + { + __m128i s1 = _mm_load_si128((__m128i*)src1); + __m128i s2 = _mm_load_si128((__m128i*)src2); + __m128i d = _mm_load_si128((__m128i*)dst); + __m128i t = _mm_add_epi64(s1, s2); + __m128i w = sse2_mm_srai_epi64(t, e, ve); + d = _mm_sub_epi64(d, w); + _mm_store_si128((__m128i*)dst, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2) + { + __m128i s1 = _mm_load_si128((__m128i*)src1); + __m128i s2 = _mm_load_si128((__m128i*)src2); + __m128i d = _mm_load_si128((__m128i*)dst); + __m128i t = _mm_add_epi64(s1, s2); + __m128i v = _mm_sub_epi64(vb, t); + __m128i w = sse2_mm_srai_epi64(v, e, ve); + d = _mm_sub_epi64(d, w); + _mm_store_si128((__m128i*)dst, d); + } + else + for (; i > 0; i -= 2, dst += 2, src1 += 2, src2 += 2) + { + __m128i s1 = _mm_load_si128((__m128i*)src1); + __m128i s2 = _mm_load_si128((__m128i*)src2); + __m128i d = _mm_load_si128((__m128i*)dst); + __m128i t = _mm_add_epi64(s1, s2); + __m128i v = _mm_sub_epi64(vb, t); + __m128i w = sse2_mm_srai_epi64(v, e, ve); + d = _mm_add_epi64(d, w); + _mm_store_si128((__m128i*)dst, d); + } + } + else { // general case + // 64bit multiplication is not supported in sse2 + if (synthesis) + for (ui32 i = repeat; i > 0; --i) + *dst++ -= (b + a * (*src1++ + *src2++)) >> e; + else + for (ui32 i = repeat; i > 0; --i) + *dst++ += (b + a * (*src1++ + *src2++)) >> e; + } + } + + ///////////////////////////////////////////////////////////////////////// + void sse2_rev_vert_step(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) + { + if (((sig != NULL) && (sig->flags & line_buf::LFT_32BIT)) || + ((aug != NULL) && (aug->flags & line_buf::LFT_32BIT)) || + ((other != NULL) && (other->flags & line_buf::LFT_32BIT))) + { + assert((sig == NULL || sig->flags & line_buf::LFT_32BIT) && + (other == NULL || other->flags & line_buf::LFT_32BIT) && + (aug == NULL || aug->flags & line_buf::LFT_32BIT)); + sse2_rev_vert_step32(s, sig, other, aug, repeat, synthesis); + } + else + { + assert((sig == NULL || sig->flags & line_buf::LFT_64BIT) && + (other == NULL || other->flags & line_buf::LFT_64BIT) && + (aug == NULL || aug->flags & line_buf::LFT_64BIT)); + sse2_rev_vert_step64(s, sig, other, aug, repeat, synthesis); + } + } + + ///////////////////////////////////////////////////////////////////////// + static + void sse2_rev_horz_ana32(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) { if (width > 1) { @@ -174,7 +375,7 @@ namespace ojph { float* dph = hdst->f32; float* sp = src->f32; int w = (int)width; - SSE_DEINTERLEAVE(dpl, dph, sp, w, even); + SSE_DEINTERLEAVE32(dpl, dph, sp, w, even); } si32* hp = hdst->i32, * lp = ldst->i32; @@ -187,7 +388,7 @@ namespace ojph { const lifting_step* s = atk->get_step(j - 1); const si32 a = s->rev.Aatk; const si32 b = s->rev.Batk; - const si32 e = s->rev.Eatk; + const ui8 e = s->rev.Eatk; __m128i vb = _mm_set1_epi32(b); // extension @@ -284,9 +485,7 @@ namespace ojph { } else { // general case - // 32bit multiplication is not supported in sse2; 
we need sse4.1,
-          // where we can use _mm_mullo_epi32, which multiplies
-          // 32bit x 32bit, keeping the LSBs
+          // 32bit multiplication is not supported in sse2; it needs sse4.1.
           if (even)
             for (ui32 i = h_width; i > 0; --i, sp++, dp++)
               *dp += (b + a * (sp[0] + sp[1])) >> e;
@@ -308,11 +507,181 @@ namespace ojph {
           hdst->i32[0] = src->i32[0] << 1;
       }
     }
+
+    /////////////////////////////////////////////////////////////////////////
+    static
+    void sse2_rev_horz_ana64(const param_atk* atk, const line_buf* ldst,
+                             const line_buf* hdst, const line_buf* src,
+                             ui32 width, bool even)
+    {
+      if (width > 1)
+      {
+        // combine both lsrc and hsrc into dst
+        {
+          double* dpl = (double*)ldst->p;
+          double* dph = (double*)hdst->p;
+          double* sp = (double*)src->p;
+          int w = (int)width;
+          sse2_deinterleave64(dpl, dph, sp, w, even);
+        }
+
+        si64* hp = hdst->i64, * lp = ldst->i64;
+        ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass
+        ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass
+        ui32 num_steps = atk->get_num_steps();
+        for (ui32 j = num_steps; j > 0; --j)
+        {
+          // first lifting step
+          const lifting_step* s = atk->get_step(j - 1);
+          const si32 a = s->rev.Aatk;
+          const si32 b = s->rev.Batk;
+          const ui8 e = s->rev.Eatk;
+          __m128i vb = _mm_set1_epi64x(b);
+          __m128i ve = _mm_set1_epi64x(1ULL << (63 - e));
+
+          // extension
+          lp[-1] = lp[0];
+          lp[l_width] = lp[l_width - 1];
+          // lifting step
+          const si64* sp = lp;
+          si64* dp = hp;
+          if (a == 1)
+          { // 5/3 update and any case with a == 1
+            int i = (int)h_width;
+            if (even)
+            {
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_add_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            }
+            else
+            {
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_add_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            }
+          }
+          else if (a == -1 && b == 1 && e == 1)
+          { // 5/3 predict
+            int i = (int)h_width;
+            if (even)
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i w = sse2_mm_srai_epi64(t, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            else
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i w = sse2_mm_srai_epi64(t, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+          }
+          else if (a == -1)
+          { // any case with a == -1, which is not 5/3 predict
+            int i = (int)h_width;
+            if (even)
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_sub_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            else
+              
for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_sub_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+          }
+          else {
+            // general case
+            // 64bit multiplication is not supported in sse2; it needs
+            // avx-512's _mm_mullo_epi64, so we keep this rarely used
+            // general case in scalar code
+            if (even)
+              for (ui32 i = h_width; i > 0; --i, sp++, dp++)
+                *dp += (b + a * (sp[0] + sp[1])) >> e;
+            else
+              for (ui32 i = h_width; i > 0; --i, sp++, dp++)
+                *dp += (b + a * (sp[-1] + sp[0])) >> e;
+          }
+
+          // swap buffers
+          si64* t = lp; lp = hp; hp = t;
+          even = !even;
+          ui32 w = l_width; l_width = h_width; h_width = w;
+        }
+      }
+      else {
+        if (even)
+          ldst->i64[0] = src->i64[0];
+        else
+          hdst->i64[0] = src->i64[0] << 1;
+      }
+    }
+
+    /////////////////////////////////////////////////////////////////////////
+    void sse2_rev_horz_ana(const param_atk* atk, const line_buf* ldst,
+                           const line_buf* hdst, const line_buf* src,
+                           ui32 width, bool even)
+    {
+      if (src->flags & line_buf::LFT_32BIT)
+      {
+        assert((ldst == NULL || ldst->flags & line_buf::LFT_32BIT) &&
+               (hdst == NULL || hdst->flags & line_buf::LFT_32BIT));
+        sse2_rev_horz_ana32(atk, ldst, hdst, src, width, even);
+      }
+      else
+      {
+        assert((ldst == NULL || ldst->flags & line_buf::LFT_64BIT) &&
+               (hdst == NULL || hdst->flags & line_buf::LFT_64BIT) &&
+               (src == NULL || src->flags & line_buf::LFT_64BIT));
+        sse2_rev_horz_ana64(atk, ldst, hdst, src, width, even);
+      }
+    }
 
     //////////////////////////////////////////////////////////////////////////
-    void sse2_rev_horz_syn(const param_atk* atk, const line_buf* dst,
-                           const line_buf* lsrc, const line_buf* hsrc,
-                           ui32 width, bool even)
+    void sse2_rev_horz_syn32(const param_atk* atk, const line_buf* dst,
+                             const line_buf* lsrc, const line_buf* hsrc,
+                             ui32 width, bool even)
     {
       if (width > 1)
       {
@@ -326,7 +695,7 @@ namespace ojph {
         const lifting_step* s = atk->get_step(j);
         const si32 a = s->rev.Aatk;
         const si32 b = s->rev.Batk;
-        const si32 e = s->rev.Eatk;
+        const ui8 e = s->rev.Eatk;
         __m128i vb = _mm_set1_epi32(b);
 
         // extension
@@ -446,7 +815,7 @@ namespace ojph {
         float* spl = lsrc->f32;
         float* sph = hsrc->f32;
         int w = (int)width;
-        SSE_INTERLEAVE(dp, spl, sph, w, even);
+        SSE_INTERLEAVE32(dp, spl, sph, w, even);
       }
     }
     else {
@@ -457,5 +826,174 @@ namespace ojph {
       }
     }
 
+    //////////////////////////////////////////////////////////////////////////
+    void sse2_rev_horz_syn64(const param_atk* atk, const line_buf* dst,
+                             const line_buf* lsrc, const line_buf* hsrc,
+                             ui32 width, bool even)
+    {
+      if (width > 1)
+      {
+        bool ev = even;
+        si64* oth = hsrc->i64, * aug = lsrc->i64;
+        ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass
+        ui32 oth_width = (width + (even ? 
0 : 1)) >> 1; // high pass
+        ui32 num_steps = atk->get_num_steps();
+        for (ui32 j = 0; j < num_steps; ++j)
+        {
+          const lifting_step* s = atk->get_step(j);
+          const si32 a = s->rev.Aatk;
+          const si32 b = s->rev.Batk;
+          const ui8 e = s->rev.Eatk;
+          __m128i vb = _mm_set1_epi64x(b);
+          __m128i ve = _mm_set1_epi64x(1ULL << (63 - e));
+
+          // extension
+          oth[-1] = oth[0];
+          oth[oth_width] = oth[oth_width - 1];
+          // lifting step
+          const si64* sp = oth;
+          si64* dp = aug;
+          if (a == 1)
+          { // 5/3 update and any case with a == 1
+            int i = (int)aug_width;
+            if (ev)
+            {
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_add_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            }
+            else
+            {
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_add_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            }
+          }
+          else if (a == -1 && b == 1 && e == 1)
+          { // 5/3 predict
+            int i = (int)aug_width;
+            if (ev)
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i w = sse2_mm_srai_epi64(t, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            else
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i w = sse2_mm_srai_epi64(t, e, ve);
+                d = _mm_add_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+          }
+          else if (a == -1)
+          { // any case with a == -1, which is not 5/3 predict
+            int i = (int)aug_width;
+            if (ev)
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp - 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_sub_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+            else
+              for (; i > 0; i -= 2, sp += 2, dp += 2)
+              {
+                __m128i s1 = _mm_load_si128((__m128i*)sp);
+                __m128i s2 = _mm_loadu_si128((__m128i*)(sp + 1));
+                __m128i d = _mm_load_si128((__m128i*)dp);
+                __m128i t = _mm_add_epi64(s1, s2);
+                __m128i v = _mm_sub_epi64(vb, t);
+                __m128i w = sse2_mm_srai_epi64(v, e, ve);
+                d = _mm_sub_epi64(d, w);
+                _mm_store_si128((__m128i*)dp, d);
+              }
+          }
+          else {
+            // general case
+            // 64bit multiplication is not supported in sse2; it needs
+            // avx-512's _mm_mullo_epi64, so we keep this rarely used
+            // general case in scalar code
+            if (ev)
+              for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
+                *dp -= (b + a * (sp[-1] + sp[0])) >> e;
+            else
+              for (ui32 i = aug_width; i > 0; --i, sp++, dp++)
+                *dp -= (b + a * (sp[0] + sp[1])) >> e;
+          }
+
+          // swap buffers
+          si64* t = aug; aug = oth; oth = t;
+          ev = !ev;
+          ui32 w = aug_width; aug_width = oth_width; oth_width = w;
+        }
+
+        // combine both lsrc and hsrc into dst
+        {
+          double* dp 
= (double*)dst->p; + double* spl = (double*)lsrc->p; + double* sph = (double*)hsrc->p; + int w = (int)width; + sse2_interleave64(dp, spl, sph, w, even); + } + } + else { + if (even) + dst->i64[0] = lsrc->i64[0]; + else + dst->i64[0] = hsrc->i64[0] >> 1; + } + } + + ///////////////////////////////////////////////////////////////////////// + void sse2_rev_horz_syn(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) + { + if (dst->flags & line_buf::LFT_32BIT) + { + assert((lsrc == NULL || lsrc->flags & line_buf::LFT_32BIT) && + (hsrc == NULL || hsrc->flags & line_buf::LFT_32BIT)); + sse2_rev_horz_syn32(atk, dst, lsrc, hsrc, width, even); + } + else + { + assert((dst == NULL || dst->flags & line_buf::LFT_64BIT) && + (lsrc == NULL || lsrc->flags & line_buf::LFT_64BIT) && + (hsrc == NULL || hsrc->flags & line_buf::LFT_64BIT)); + sse2_rev_horz_syn64(atk, dst, lsrc, hsrc, width, even); + } + } + } // !local } // !ojph
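
Reviewer notes follow; the sketches below are illustrative only and are not part of the patch.

Note on the branch structure above: every rev_vert/rev_horz variant vectorizes the same scalar recurrence, dst[i] -=/+= (b + a * (s1[i] + s2[i])) >> e, and the three special-cased branches follow from fixing the lifting coefficient a. A minimal scalar sketch (names are illustrative):

#include <cstdint>

// One reversible lifting update; the SIMD branches vectorize exactly
// this, choosing += for analysis and -= for synthesis.
static inline int64_t lifting_term(int64_t s1, int64_t s2,
                                   int64_t a, int64_t b, int e)
{
  return (b + a * (s1 + s2)) >> e;
}

// Specializations used by the branches above:
//   a ==  1              : (b + (s1 + s2)) >> e     (5/3 update: b=2, e=2)
//   a == -1, b==1, e==1  : (1 - (s1 + s2)) >> 1, which equals
//                          -((s1 + s2) >> 1) for every integer sum, so
//                          the 5/3 predict branch drops b and instead
//                          flips the add/sub applied to the destination
//   a == -1              : (b - (s1 + s2)) >> e
// Only the remaining general case needs a genuine 64-bit multiply.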
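
Note on the emulated 64-bit arithmetic shift: SSE2 and AVX2 provide no 64-bit arithmetic right shift (it only arrives with AVX-512), so sse2_mm_srai_epi64 and avx2_mm256_srai_epi64 shift logically and then repair the sign with an xor/subtract against the relocated sign-bit mask. A scalar model of the trick with a self check (hypothetical names; assumes the compiler's >> on signed values is arithmetic, which mainstream ABIs use and C++20 guarantees):

#include <cassert>
#include <cstdint>

// Scalar model of sse2_mm_srai_epi64(a, amt, m), where
// m = 1ull << (63 - amt) is the bit the sign lands on after the
// logical shift.
static int64_t srai64_model(int64_t a, int amt)
{
  uint64_t m = 1ull << (63 - amt);
  uint64_t x = (uint64_t)a >> amt;  // logical shift; top amt bits are 0
  x ^= m;                           // toggle the relocated sign bit
  return (int64_t)(x - m);          // sign 0: (x + m) - m = x
                                    // sign 1: x - 2m mod 2^64, i.e. the
                                    // top amt bits all become 1
}

int main()
{
  for (int amt = 1; amt < 63; ++amt) {
    assert(srai64_model(-1000000007, amt) == (-1000000007LL >> amt));
    assert(srai64_model(+1000000007, amt) == (+1000000007LL >> amt));
  }
  return 0;
}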
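
Caveat on the avx2 general-case branches: _mm256_mullo_epi64 is an AVX-512DQ + AVX-512VL intrinsic, so ojph_transform_avx2.cpp as patched will not build with only AVX2 enabled; the sse2 file avoids the same trap by falling back to scalar code. If the AVX2 path is meant to stay vectorized, one option is the classic partial-product emulation below (a sketch only; the helper name is hypothetical):

#include <immintrin.h>

// Low 64 bits of a 64x64-bit product per lane, using AVX2 only.
// With a = aH*2^32 + aL and b = bH*2^32 + bL:
//   low64(a*b) = aL*bL + ((aL*bH + aH*bL) << 32)
static inline __m256i avx2_mm256_mullo_epi64(__m256i a, __m256i b)
{
  __m256i bswap = _mm256_shuffle_epi32(b, 0xB1);    // swap bH and bL
  __m256i cross = _mm256_mullo_epi32(a, bswap);     // aL*bH, aH*bL (32 bit)
  __m256i zero  = _mm256_setzero_si256();
  __m256i csum  = _mm256_hadd_epi32(cross, zero);   // aL*bH + aH*bL
  __m256i chi   = _mm256_shuffle_epi32(csum, 0x73); // sums to high halves
  __m256i low   = _mm256_mul_epu32(a, b);           // aL*bL as 64 bit
  return _mm256_add_epi64(low, chi);                // assemble low 64 bits
}

Signed and unsigned multiplication agree in the low 64 bits, so the same helper would serve the si64 samples used by the lifting steps.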
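
Note on the polyphase helpers: sse2_deinterleave64 / sse2_interleave64 and their avx2 counterparts perform the usual even/odd split, with the 'even' flag recording whether the first sample of the line belongs to the low-pass band; moving si64 data through double lanes (_mm_load_pd, _mm_shuffle_pd) is safe because the values are only copied, never computed on. A scalar model (illustrative; the SIMD versions additionally round the trip count up to whole vectors):

#include <cstdint>

// Scalar model of sse2_deinterleave64: even-indexed samples feed the
// low-pass line and odd-indexed samples the high-pass line; when the
// line starts on an odd coordinate (even == false) the roles swap.
static void deinterleave64_model(int64_t* dpl, int64_t* dph,
                                 const int64_t* sp, int width, bool even)
{
  for (int i = 0; i < width; ++i) {
    if ((i & 1) == 0)
      *(even ? dpl++ : dph++) = sp[i];
    else
      *(even ? dph++ : dpl++) = sp[i];
  }
}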