From 1c4a14ce94a3fcd2073318eb86027106033a396b Mon Sep 17 00:00:00 2001 From: Aous Naman Date: Sat, 13 Apr 2024 09:45:20 +1000 Subject: [PATCH] avx512 dwt implemented --- src/core/CMakeLists.txt | 9 +- src/core/common/ojph_arch.h | 6 +- src/core/transform/ojph_transform.cpp | 24 +- src/core/transform/ojph_transform_avx.cpp | 9 +- src/core/transform/ojph_transform_avx2.cpp | 2 - src/core/transform/ojph_transform_avx512.cpp | 830 +++++++++++++++++++ src/core/transform/ojph_transform_local.h | 48 +- 7 files changed, 855 insertions(+), 73 deletions(-) create mode 100644 src/core/transform/ojph_transform_avx512.cpp diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 40b9649b..19123a2e 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -18,11 +18,12 @@ file(GLOB TRANSFORM_SSE "transform/*_sse.cpp") file(GLOB TRANSFORM_SSE2 "transform/*_sse2.cpp") file(GLOB TRANSFORM_AVX "transform/*_avx.cpp") file(GLOB TRANSFORM_AVX2 "transform/*_avx2.cpp") +file(GLOB TRANSFORM_AVX512 "transform/*_avx512.cpp") file(GLOB TRANSFORM_WASM "transform/*_wasm.cpp") list(REMOVE_ITEM CODESTREAM ${CODESTREAM_SSE} ${CODESTREAM_SSE2} ${CODESTREAM_AVX} ${CODESTREAM_AVX2} ${CODESTREAM_WASM}) list(REMOVE_ITEM CODING ${CODING_SSSE3} ${CODING_WASM} ${CODING_AVX512}) -list(REMOVE_ITEM TRANSFORM ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2} ${TRANSFORM_WASM}) +list(REMOVE_ITEM TRANSFORM ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2} ${TRANSFORM_AVX512} ${TRANSFORM_WASM}) list(APPEND SOURCES ${CODESTREAM} ${CODING} ${COMMON} ${OTHERS} ${TRANSFORM}) source_group("codestream" FILES ${CODESTREAM}) @@ -42,10 +43,10 @@ if(EMSCRIPTEN) source_group("coding" FILES ${CODING_WASM}) source_group("transform" FILES ${TRANSFORM_WASM}) elseif(NOT OJPH_DISABLE_INTEL_SIMD) - add_library(openjph ${SOURCES} ${CODESTREAM_SSE} ${CODESTREAM_SSE2} ${CODESTREAM_AVX} ${CODESTREAM_AVX2} ${CODING_SSSE3} ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2}) + add_library(openjph ${SOURCES} ${CODESTREAM_SSE} ${CODESTREAM_SSE2} ${CODESTREAM_AVX} ${CODESTREAM_AVX2} ${CODING_SSSE3} ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2} ${TRANSFORM_AVX512}) source_group("codestream" FILES ${CODESTREAM_SSE} ${CODESTREAM_SSE2} ${CODESTREAM_AVX} ${CODESTREAM_AVX2}) source_group("coding" FILES ${CODING_SSSE3}) - source_group("transform" FILES ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2}) + source_group("transform" FILES ${TRANSFORM_SSE} ${TRANSFORM_SSE2} ${TRANSFORM_AVX} ${TRANSFORM_AVX2} ${TRANSFORM_AVX512}) if (OJPH_ENABLE_INTEL_AVX512) target_sources(openjph PRIVATE ${CODING_AVX512}) source_group("coding" FILES ${CODING_AVX512}) @@ -71,6 +72,7 @@ if (MSVC) set_source_files_properties(transform/ojph_colour_avx2.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") set_source_files_properties(transform/ojph_transform_avx.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX") set_source_files_properties(transform/ojph_transform_avx2.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX2") + set_source_files_properties(transform/ojph_transform_avx512.cpp PROPERTIES COMPILE_FLAGS "/arch:AVX512") else() set_source_files_properties(codestream/ojph_codestream_avx.cpp PROPERTIES COMPILE_FLAGS -mavx) set_source_files_properties(codestream/ojph_codestream_avx2.cpp PROPERTIES COMPILE_FLAGS -mavx2) @@ -80,6 +82,7 @@ else() set_source_files_properties(transform/ojph_colour_avx2.cpp PROPERTIES COMPILE_FLAGS -mavx2) set_source_files_properties(transform/ojph_transform_avx.cpp 
PROPERTIES COMPILE_FLAGS -mavx) set_source_files_properties(transform/ojph_transform_avx2.cpp PROPERTIES COMPILE_FLAGS -mavx2) + set_source_files_properties(transform/ojph_transform_avx512.cpp PROPERTIES COMPILE_FLAGS -mavx512f) endif() if (MSVC) diff --git a/src/core/common/ojph_arch.h b/src/core/common/ojph_arch.h index 62b630bb..fa9d077d 100644 --- a/src/core/common/ojph_arch.h +++ b/src/core/common/ojph_arch.h @@ -194,11 +194,7 @@ namespace ojph { //////////////////////////////////////////////////////////////////////////// // constants //////////////////////////////////////////////////////////////////////////// -#ifdef OJPH_ENABLE_INTEL_AVX512 - const ui32 byte_alignment = 64; //64 bytes == 512 bits -#else - const ui32 byte_alignment = 32; //32 bytes == 256 bits -#endif + const ui32 byte_alignment = 64; // 64 bytes == 512 bits const ui32 log_byte_alignment = 31 - count_leading_zeros(byte_alignment); const ui32 object_alignment = 8; diff --git a/src/core/transform/ojph_transform.cpp b/src/core/transform/ojph_transform.cpp index 95ab686c..83eed644 100644 --- a/src/core/transform/ojph_transform.cpp +++ b/src/core/transform/ojph_transform.cpp @@ -145,17 +145,19 @@ namespace ojph { rev_horz_syn = avx2_rev_horz_syn; } - //if (level >= X86_CPU_EXT_LEVEL_AVX512) - //{ - // rev_vert_step = avx512_rev_vert_ana_step; - // rev_horz_ana = avx512_rev_horz_ana; - // rev_horz_syn = avx512_rev_horz_syn; - - // irv_vert_step = avx512_irv_vert_step; - // irv_vert_times_K = avx512_irv_vert_times_K; - // irv_vert_syn_step = avx512_irv_vert_syn_step; - // irv_horz_syn = avx512_irv_horz_syn; - //} +#ifdef OJPH_ENABLE_INTEL_AVX512 + if (level >= X86_CPU_EXT_LEVEL_AVX512) + { + rev_vert_step = avx512_rev_vert_step; + rev_horz_ana = avx512_rev_horz_ana; + rev_horz_syn = avx512_rev_horz_syn; + + irv_vert_step = avx512_irv_vert_step; + irv_vert_times_K = avx512_irv_vert_times_K; + irv_horz_ana = avx512_irv_horz_ana; + irv_horz_syn = avx512_irv_horz_syn; + } +#endif // !OJPH_ENABLE_INTEL_AVX512 #endif // !OJPH_DISABLE_INTEL_SIMD diff --git a/src/core/transform/ojph_transform_avx.cpp b/src/core/transform/ojph_transform_avx.cpp index e7933ff1..08566624 100644 --- a/src/core/transform/ojph_transform_avx.cpp +++ b/src/core/transform/ojph_transform_avx.cpp @@ -88,14 +88,7 @@ namespace ojph { ////////////////////////////////////////////////////////////////////////// void avx_irv_vert_times_K(float K, const line_buf* aug, ui32 repeat) { - __m256 factor = _mm256_set1_ps(K); - float* dst = aug->f32; - int i = (int)repeat; - for (; i > 0; i -= 8, dst += 8) - { - __m256 s = _mm256_load_ps(dst); - _mm256_store_ps(dst, _mm256_mul_ps(factor, s)); - } + avx_multiply_const(aug->f32, K, (int)repeat); } ///////////////////////////////////////////////////////////////////////// diff --git a/src/core/transform/ojph_transform_avx2.cpp b/src/core/transform/ojph_transform_avx2.cpp index 243fe87f..847cd4c4 100644 --- a/src/core/transform/ojph_transform_avx2.cpp +++ b/src/core/transform/ojph_transform_avx2.cpp @@ -514,7 +514,5 @@ namespace ojph { } } - - } // !local } // !ojph diff --git a/src/core/transform/ojph_transform_avx512.cpp b/src/core/transform/ojph_transform_avx512.cpp new file mode 100644 index 00000000..efb7655a --- /dev/null +++ b/src/core/transform/ojph_transform_avx512.cpp @@ -0,0 +1,830 @@ +//***************************************************************************/ +// This software is released under the 2-Clause BSD license, included +// below. 
+// +// Copyright (c) 2019, Aous Naman +// Copyright (c) 2019, Kakadu Software Pty Ltd, Australia +// Copyright (c) 2019, The University of New South Wales, Australia +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +// PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +// TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************/ +// This file is part of the OpenJPH software implementation. +// File: ojph_transform_avx512.cpp +// Author: Aous Naman +// Date: 28 August 2019 +//***************************************************************************/ + +#include <cstdio> + +#include "ojph_defs.h" +#include "ojph_arch.h" +#include "ojph_mem.h" +#include "ojph_params.h" +#include "../codestream/ojph_params_local.h" + +#include "ojph_transform.h" +#include "ojph_transform_local.h" + +#include <immintrin.h> + +namespace ojph { + namespace local { + + ////////////////////////////////////////////////////////////////////////// + // We split multiples of 32 followed by multiples of 16, because + // we assume byte_alignment == 64 + static void avx512_deinterleave(float* dpl, float* dph, float* sp, + int width, bool even) + { + __m512i idx1 = _mm512_set_epi32( + 0x1E, 0x1C, 0x1A, 0x18, 0x16, 0x14, 0x12, 0x10, + 0x0E, 0x0C, 0x0A, 0x08, 0x06, 0x04, 0x02, 0x00 + ); + __m512i idx2 = _mm512_set_epi32( + 0x1F, 0x1D, 0x1B, 0x19, 0x17, 0x15, 0x13, 0x11, + 0x0F, 0x0D, 0x0B, 0x09, 0x07, 0x05, 0x03, 0x01 + ); + if (even) + { + for (; width > 16; width -= 32, sp += 32, dpl += 16, dph += 16) + { + __m512 a = _mm512_load_ps(sp); + __m512 b = _mm512_load_ps(sp + 16); + __m512 c = _mm512_permutex2var_ps(a, idx1, b); + __m512 d = _mm512_permutex2var_ps(a, idx2, b); + _mm512_store_ps(dpl, c); + _mm512_store_ps(dph, d); + } + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) + { + __m256 a = _mm256_load_ps(sp); + __m256 b = _mm256_load_ps(sp + 8); + __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); + __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); + __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); + __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); + _mm256_store_ps(dpl, e); + _mm256_store_ps(dph, f); + } + } + else + { + for (; width > 16; width -= 32, sp += 32, dpl += 16, dph += 16) + { + __m512 a = _mm512_load_ps(sp); + __m512 b = _mm512_load_ps(sp + 16); + __m512 c = 
_mm512_permutex2var_ps(a, idx2, b); + __m512 d = _mm512_permutex2var_ps(a, idx1, b); + _mm512_store_ps(dpl, c); + _mm512_store_ps(dph, d); + } + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) + { + __m256 a = _mm256_load_ps(sp); + __m256 b = _mm256_load_ps(sp + 8); + __m256 c = _mm256_permute2f128_ps(a, b, (2 << 4) | (0)); + __m256 d = _mm256_permute2f128_ps(a, b, (3 << 4) | (1)); + __m256 e = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(2, 0, 2, 0)); + __m256 f = _mm256_shuffle_ps(c, d, _MM_SHUFFLE(3, 1, 3, 1)); + _mm256_store_ps(dpl, f); + _mm256_store_ps(dph, e); + } + } + } + + ////////////////////////////////////////////////////////////////////////// + // We split multiples of 32 followed by multiples of 16, because + // we assume byte_alignment == 64 + static void avx512_interleave(float* dp, float* spl, float* sph, + int width, bool even) + { + __m512i idx1 = _mm512_set_epi32( + 0x17, 0x7, 0x16, 0x6, 0x15, 0x5, 0x14, 0x4, + 0x13, 0x3, 0x12, 0x2, 0x11, 0x1, 0x10, 0x0 + ); + __m512i idx2 = _mm512_set_epi32( + 0x1F, 0xF, 0x1E, 0xE, 0x1D, 0xD, 0x1C, 0xC, + 0x1B, 0xB, 0x1A, 0xA, 0x19, 0x9, 0x18, 0x8 + ); + if (even) + { + for (; width > 16; width -= 32, dp += 32, spl += 16, sph += 16) + { + __m512 a = _mm512_load_ps(spl); + __m512 b = _mm512_load_ps(sph); + __m512 c = _mm512_permutex2var_ps(a, idx1, b); + __m512 d = _mm512_permutex2var_ps(a, idx2, b); + _mm512_store_ps(dp, c); + _mm512_store_ps(dp + 16, d); + } + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) + { + __m256 a = _mm256_load_ps(spl); + __m256 b = _mm256_load_ps(sph); + __m256 c = _mm256_unpacklo_ps(a, b); + __m256 d = _mm256_unpackhi_ps(a, b); + __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); + __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); + _mm256_store_ps(dp, e); + _mm256_store_ps(dp + 8, f); + } + } + else + { + for (; width > 16; width -= 32, dp += 32, spl += 16, sph += 16) + { + __m512 a = _mm512_load_ps(spl); + __m512 b = _mm512_load_ps(sph); + __m512 c = _mm512_permutex2var_ps(b, idx1, a); + __m512 d = _mm512_permutex2var_ps(b, idx2, a); + _mm512_store_ps(dp, c); + _mm512_store_ps(dp + 16, d); + } + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) + { + __m256 a = _mm256_load_ps(spl); + __m256 b = _mm256_load_ps(sph); + __m256 c = _mm256_unpacklo_ps(b, a); + __m256 d = _mm256_unpackhi_ps(b, a); + __m256 e = _mm256_permute2f128_ps(c, d, (2 << 4) | (0)); + __m256 f = _mm256_permute2f128_ps(c, d, (3 << 4) | (1)); + _mm256_store_ps(dp, e); + _mm256_store_ps(dp + 8, f); + } + } + } + + ////////////////////////////////////////////////////////////////////////// + static inline void avx512_multiply_const(float* p, float f, int width) + { + __m512 factor = _mm512_set1_ps(f); + for (; width > 0; width -= 16, p += 16) + { + __m512 s = _mm512_load_ps(p); + _mm512_store_ps(p, _mm512_mul_ps(factor, s)); + } + } + + ////////////////////////////////////////////////////////////////////////// + void avx512_irv_vert_step(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) + { + float a = s->irv.Aatk; + if (synthesis) + a = -a; + + __m512 factor = _mm512_set1_ps(a); + + float* dst = aug->f32; + const float* src1 = sig->f32, * src2 = other->f32; + int i = (int)repeat; + for ( ; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512 s1 = _mm512_load_ps(src1); + __m512 s2 = _mm512_load_ps(src2); + __m512 d = _mm512_load_ps(dst); + d = _mm512_add_ps(d, _mm512_mul_ps(factor, _mm512_add_ps(s1, s2))); + 
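// d += a * (s1 + s2); a was negated above when synthesis is true +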
_mm512_store_ps(dst, d); + } + } + + ////////////////////////////////////////////////////////////////////////// + void avx512_irv_vert_times_K(float K, const line_buf* aug, ui32 repeat) + { + avx512_multiply_const(aug->f32, K, (int)repeat); + } + + ///////////////////////////////////////////////////////////////////////// + void avx512_irv_horz_ana(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) + { + if (width > 1) + { + // split src into ldst and hdst + { + float* dpl = ldst->f32; + float* dph = hdst->f32; + float* sp = src->f32; + int w = (int)width; + avx512_deinterleave(dpl, dph, sp, w, even); + } + + // the actual horizontal transform + float* hp = hdst->f32, * lp = ldst->f32; + ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 h_width = (width + (even ? 0 : 1)) >> 1; // high pass + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = num_steps; j > 0; --j) + { + const lifting_step* s = atk->get_step(j - 1); + const float a = s->irv.Aatk; + + // extension + lp[-1] = lp[0]; + lp[l_width] = lp[l_width - 1]; + // lifting step + const float* sp = lp; + float* dp = hp; + int i = (int)h_width; + __m512 f = _mm512_set1_ps(a); + if (even) + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512 m = _mm512_load_ps(sp); + __m512 n = _mm512_loadu_ps(sp + 1); + __m512 p = _mm512_load_ps(dp); + p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n))); + _mm512_store_ps(dp, p); + } + } + else + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512 m = _mm512_load_ps(sp); + __m512 n = _mm512_loadu_ps(sp - 1); + __m512 p = _mm512_load_ps(dp); + p = _mm512_add_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n))); + _mm512_store_ps(dp, p); + } + } + + // swap buffers + float* t = lp; lp = hp; hp = t; + even = !even; + ui32 w = l_width; l_width = h_width; h_width = w; + } + + { // multiply by K or 1/K + float K = atk->get_K(); + float K_inv = 1.0f / K; + avx512_multiply_const(lp, K_inv, (int)l_width); + avx512_multiply_const(hp, K, (int)h_width); + } + } + else { + if (even) + ldst->f32[0] = src->f32[0]; + else + hdst->f32[0] = src->f32[0] * 2.0f; + } + } + + ////////////////////////////////////////////////////////////////////////// + void avx512_irv_horz_syn(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) + { + if (width > 1) + { + bool ev = even; + float* oth = hsrc->f32, * aug = lsrc->f32; + ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 oth_width = (width + (even ? 
0 : 1)) >> 1; // high pass + + { // multiply by K or 1/K + float K = atk->get_K(); + float K_inv = 1.0f / K; + avx512_multiply_const(aug, K, (int)aug_width); + avx512_multiply_const(oth, K_inv, (int)oth_width); + } + + // the actual horizontal transform + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = 0; j < num_steps; ++j) + { + const lifting_step* s = atk->get_step(j); + const float a = s->irv.Aatk; + + // extension + oth[-1] = oth[0]; + oth[oth_width] = oth[oth_width - 1]; + // lifting step + const float* sp = oth; + float* dp = aug; + int i = (int)aug_width; + __m512 f = _mm512_set1_ps(a); + if (ev) + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512 m = _mm512_load_ps(sp); + __m512 n = _mm512_loadu_ps(sp - 1); + __m512 p = _mm512_load_ps(dp); + p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n))); + _mm512_store_ps(dp, p); + } + } + else + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512 m = _mm512_load_ps(sp); + __m512 n = _mm512_loadu_ps(sp + 1); + __m512 p = _mm512_load_ps(dp); + p = _mm512_sub_ps(p, _mm512_mul_ps(f, _mm512_add_ps(m, n))); + _mm512_store_ps(dp, p); + } + } + + // swap buffers + float* t = aug; aug = oth; oth = t; + ev = !ev; + ui32 w = aug_width; aug_width = oth_width; oth_width = w; + } + + // combine both lsrc and hsrc into dst + avx512_interleave(dst->f32, lsrc->f32, hsrc->f32, (int)width, even); + } + else { + if (even) + dst->f32[0] = lsrc->f32[0]; + else + dst->f32[0] = hsrc->f32[0] * 0.5f; + } + } + + ///////////////////////////////////////////////////////////////////////// + void avx512_rev_vert_step(const lifting_step* s, const line_buf* sig, + const line_buf* other, const line_buf* aug, + ui32 repeat, bool synthesis) + { + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const si32 e = s->rev.Eatk; + __m512i va = _mm512_set1_epi32(a); + __m512i vb = _mm512_set1_epi32(b); + + si32* dst = aug->i32; + const si32* src1 = sig->i32, * src2 = other->i32; + // The general definition of the wavelet in Part 2 is slightly + // different from that in Part 1, although they are mathematically + // equivalent; here, we identify the simpler form from Part 1 and + // employ it + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + else + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + else 
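+ // analysis branch of the 5/3 predict: d -= (s1 + s2) >> 1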
+ for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + else + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + } + else { // general case + int i = (int)repeat; + if (synthesis) + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + else + for (; i > 0; i -= 16, dst += 16, src1 += 16, src2 += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)src1); + __m512i s2 = _mm512_load_si512((__m512i*)src2); + __m512i d = _mm512_load_si512((__m512i*)dst); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dst, d); + } + } + } + + ///////////////////////////////////////////////////////////////////////// + void avx512_rev_horz_ana(const param_atk* atk, const line_buf* ldst, + const line_buf* hdst, const line_buf* src, + ui32 width, bool even) + { + if (width > 1) + { + // split src into ldst and hdst + { + float* dpl = ldst->f32; + float* dph = hdst->f32; + float* sp = src->f32; + int w = (int)width; + avx512_deinterleave(dpl, dph, sp, w, even); + } + + si32* hp = hdst->i32, * lp = ldst->i32; + ui32 l_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 h_width = (width + (even ? 
0 : 1)) >> 1; // high pass + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = num_steps; j > 0; --j) + { + // first lifting step + const lifting_step* s = atk->get_step(j - 1); + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const si32 e = s->rev.Eatk; + __m512i va = _mm512_set1_epi32(a); + __m512i vb = _mm512_set1_epi32(b); + + // extension + lp[-1] = lp[0]; + lp[l_width] = lp[l_width - 1]; + // lifting step + const si32* sp = lp; + si32* dp = hp; + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)h_width; + if (even) + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)h_width; + if (even) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)h_width; + if (even) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else { + // general case + int i = (int)h_width; + if (even) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + 
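// general case, analysis direction: d += (b + a * (s1 + s2)) >> e +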
_mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + + // swap buffers + si32* t = lp; lp = hp; hp = t; + even = !even; + ui32 w = l_width; l_width = h_width; h_width = w; + } + } + else { + if (even) + ldst->i32[0] = src->i32[0]; + else + hdst->i32[0] = src->i32[0] << 1; + } + } + + ////////////////////////////////////////////////////////////////////////// + void avx512_rev_horz_syn(const param_atk* atk, const line_buf* dst, + const line_buf* lsrc, const line_buf* hsrc, + ui32 width, bool even) + { + if (width > 1) + { + bool ev = even; + si32* oth = hsrc->i32, * aug = lsrc->i32; + ui32 aug_width = (width + (even ? 1 : 0)) >> 1; // low pass + ui32 oth_width = (width + (even ? 0 : 1)) >> 1; // high pass + ui32 num_steps = atk->get_num_steps(); + for (ui32 j = 0; j < num_steps; ++j) + { + const lifting_step* s = atk->get_step(j); + const si32 a = s->rev.Aatk; + const si32 b = s->rev.Batk; + const si32 e = s->rev.Eatk; + __m512i va = _mm512_set1_epi32(a); + __m512i vb = _mm512_set1_epi32(b); + + // extension + oth[-1] = oth[0]; + oth[oth_width] = oth[oth_width - 1]; + // lifting step + const si32* sp = oth; + si32* dp = aug; + if (a == 1) + { // 5/3 update and any case with a == 1 + int i = (int)aug_width; + if (ev) + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else + { + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_add_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + } + else if (a == -1 && b == 1 && e == 1) + { // 5/3 predict + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i w = _mm512_srai_epi32(t, e); + d = _mm512_add_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else if (a == -1) + { // any case with a == -1, which is not 5/3 predict + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i 
d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i v = _mm512_sub_epi32(vb, t); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + else { + // general case + int i = (int)aug_width; + if (ev) + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp - 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + else + for (; i > 0; i -= 16, sp += 16, dp += 16) + { + __m512i s1 = _mm512_load_si512((__m512i*)sp); + __m512i s2 = _mm512_loadu_si512((__m512i*)(sp + 1)); + __m512i d = _mm512_load_si512((__m512i*)dp); + __m512i t = _mm512_add_epi32(s1, s2); + __m512i u = _mm512_mullo_epi32(va, t); + __m512i v = _mm512_add_epi32(vb, u); + __m512i w = _mm512_srai_epi32(v, e); + d = _mm512_sub_epi32(d, w); + _mm512_store_si512((__m512i*)dp, d); + } + } + + // swap buffers + si32* t = aug; aug = oth; oth = t; + ev = !ev; + ui32 w = aug_width; aug_width = oth_width; oth_width = w; + } + + // combine both lsrc and hsrc into dst + avx512_interleave(dst->f32, lsrc->f32, hsrc->f32, (int)width, even); + } + else { + if (even) + dst->i32[0] = lsrc->i32[0]; + else + dst->i32[0] = hsrc->i32[0] >> 1; + } + } + + } // !local +} // !ojph diff --git a/src/core/transform/ojph_transform_local.h b/src/core/transform/ojph_transform_local.h index 3ba9e6d0..ec2a2e12 100644 --- a/src/core/transform/ojph_transform_local.h +++ b/src/core/transform/ojph_transform_local.h @@ -221,13 +221,11 @@ namespace ojph { ////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////// - // We split multiples of 16 followed by multiples of 8, because - // we assume byte_alignment == 32 #define AVX_DEINTERLEAVE(dpl, dph, sp, width, even) \ { \ if (even) \ { \ - for (; width > 8; width -= 16, sp += 16, dpl += 8, dph += 8) \ + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) \ { \ __m256 a = _mm256_load_ps(sp); \ __m256 b = _mm256_load_ps(sp + 8); \ @@ -238,19 +236,10 @@ namespace ojph { _mm256_store_ps(dpl, e); \ _mm256_store_ps(dph, f); \ } \ - for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4) \ - { \ - __m128 a = _mm_load_ps(sp); \ - __m128 b = _mm_load_ps(sp + 4); \ - __m128 c = _mm_shuffle_ps(a, b, _MM_SHUFFLE(2, 0, 2, 0)); \ - __m128 d = _mm_shuffle_ps(a, b, _MM_SHUFFLE(3, 1, 3, 1)); \ - _mm_store_ps(dpl, c); \ - _mm_store_ps(dph, d); \ - } \ } \ else \ { \ - for (; width > 8; width -= 16, sp += 16, dpl += 8, dph += 8) \ + for (; width > 0; width -= 16, sp += 16, dpl += 8, dph += 8) \ { \ __m256 a = _mm256_load_ps(sp); \ __m256 b = _mm256_load_ps(sp + 8); \ @@ -261,26 +250,15 @@ namespace ojph { _mm256_store_ps(dpl, f); \ _mm256_store_ps(dph, e); \ } \ - for (; width > 0; width -= 8, sp += 8, dpl += 4, dph += 4) \ - { \ - __m128 a = 
_mm_load_ps(sp); \ - __m128 b = _mm_load_ps(sp + 4); \ - __m128 c = _mm_shuffle_ps(a, b, _MM_SHUFFLE(2, 0, 2, 0)); \ - __m128 d = _mm_shuffle_ps(a, b, _MM_SHUFFLE(3, 1, 3, 1)); \ - _mm_store_ps(dpl, d); \ - _mm_store_ps(dph, c); \ - } \ } \ } ////////////////////////////////////////////////////////////////////////// - // We split multiples of 16 followed by multiples of 8, because - // we assume byte_alignment == 32 #define AVX_INTERLEAVE(dp, spl, sph, width, even) \ { \ if (even) \ { \ - for (; width > 8; width -= 16, dp += 16, spl += 8, sph += 8) \ + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) \ { \ __m256 a = _mm256_load_ps(spl); \ __m256 b = _mm256_load_ps(sph); \ @@ -291,19 +269,10 @@ namespace ojph { _mm256_store_ps(dp, e); \ _mm256_store_ps(dp + 8, f); \ } \ - for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4) \ - { \ - __m128 a = _mm_load_ps(spl); \ - __m128 b = _mm_load_ps(sph); \ - __m128 c = _mm_unpacklo_ps(a, b); \ - __m128 d = _mm_unpackhi_ps(a, b); \ - _mm_store_ps(dp, c); \ - _mm_store_ps(dp + 4, d); \ - } \ } \ else \ { \ - for (; width > 8; width -= 16, dp += 16, spl += 8, sph += 8) \ + for (; width > 0; width -= 16, dp += 16, spl += 8, sph += 8) \ { \ __m256 a = _mm256_load_ps(spl); \ __m256 b = _mm256_load_ps(sph); \ @@ -314,15 +283,6 @@ namespace ojph { _mm256_store_ps(dp, e); \ _mm256_store_ps(dp + 8, f); \ } \ - for (; width > 0; width -= 8, dp += 8, spl += 4, sph += 4) \ - { \ - __m128 a = _mm_load_ps(spl); \ - __m128 b = _mm_load_ps(sph); \ - __m128 c = _mm_unpacklo_ps(b, a); \ - __m128 d = _mm_unpackhi_ps(b, a); \ - _mm_store_ps(dp, c); \ - _mm_store_ps(dp + 4, d); \ - } \ } \ }
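A note for reviewers, not part of the patch: all of the integer branches in avx512_rev_vert_step, avx512_rev_horz_ana, and avx512_rev_horz_syn specialize the same Part-1 reversible lifting update. A scalar sketch of the general form, written with the step coefficients a (rev.Aatk), b (rev.Batk), and shift e (rev.Eatk), is:

    // analysis direction; synthesis subtracts the same correction term
    for (ui32 i = 0; i < repeat; ++i)
      aug[i] += (b + a * (sig[i] + other[i])) >> e;

The a == 1 and a == -1 branches drop the _mm512_mullo_epi32 multiply, and the 5/3 predict branch (a == -1, b == 1, e == 1) also omits the offset b, computing d -= (s1 + s2) >> 1 directly.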