--- doc/config.txt.orig 2021-06-20 15:05:49.000000000 -0600 +++ doc/config.txt 2021-06-23 19:59:29.902142132 -0600 @@ -420,6 +420,7 @@ NTL_AVOID_BRANCHING=off NTL_GF2X_NOINLINE=off NTL_GF2X_ALTCODE=off NTL_GF2X_ALTCODE1=off +NTL_LOADTIME_CPU=off GMP_INCDIR=$(GMP_PREFIX)/include GMP_LIBDIR=$(GMP_PREFIX)/lib @@ -734,6 +735,10 @@ NTL_GF2X_ALTCODE1=off # Yet another alternative implementation for GF2X multiplication. +NTL_LOADTIME_CPU=off + +# switch to check CPU characteristics at load time and use routines +# optimized for the executing CPU. ########## More GMP Options: --- include/NTL/config.h.orig 2021-06-20 15:05:49.000000000 -0600 +++ include/NTL/config.h 2021-06-23 19:59:29.903142133 -0600 @@ -549,6 +549,19 @@ to be defined. Of course, to unset a f #error "NTL_SAFE_VECTORS defined but not NTL_STD_CXX11 or NTL_STD_CXX14" #endif +#if 0 +#define NTL_LOADTIME_CPU + +/* + * With this flag enabled, detect advanced CPU features at load time instead + * of at compile time. This flag is intended for distributions, so that they + * can compile for the lowest common denominator CPU, but still support newer + * CPUs. + * + * This flag is useful only on x86_64 platforms with gcc 4.8 or later. + */ + +#endif --- include/NTL/ctools.h.orig 2021-06-20 15:05:49.000000000 -0600 +++ include/NTL/ctools.h 2021-06-23 19:59:29.904142134 -0600 @@ -518,6 +518,155 @@ char *_ntl_make_aligned(char *p, long al // this should be big enough to satisfy any SIMD instructions, // and it should also be as big as a cache line +/* Determine CPU characteristics at runtime */ +#ifdef NTL_LOADTIME_CPU +#if !defined(__x86_64__) +#error Runtime CPU support is only available on x86_64. +#endif +#ifndef __GNUC__ +#error Runtime CPU support is only available with GCC. +#endif +#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 6) +#error Runtime CPU support is only available with GCC 4.6 or later. 
+#endif + +#include +#ifndef bit_SSSE3 +#define bit_SSSE3 (1 << 9) +#endif +#ifndef bit_PCLMUL +#define bit_PCLMUL (1 << 1) +#endif +#ifndef bit_AVX +#define bit_AVX (1 << 28) +#endif +#ifndef bit_FMA +#define bit_FMA (1 << 12) +#endif +#ifndef bit_AVX2 +#define bit_AVX2 (1 << 5) +#endif + +#define BASE_FUNC(type,name) static type name##_base +#define TARGET_FUNC(arch,suffix,type,name) \ + static type __attribute__((target (arch))) name##_##suffix +#define SSSE3_FUNC(type,name) TARGET_FUNC("ssse3",ssse3,type,name) +#define PCLMUL_FUNC(type,name) TARGET_FUNC("pclmul,ssse3",pclmul,type,name) +#define AVX_FUNC(type,name) TARGET_FUNC("avx,pclmul,ssse3",avx,type,name) +#define FMA_FUNC(type,name) TARGET_FUNC("fma,avx,pclmul,ssse3",fma,type,name) +#define AVX2_FUNC(type,name) TARGET_FUNC("avx2,fma,avx,pclmul,ssse3",avx2,type,name) +#define SSSE3_RESOLVER(st,type,name,params) \ + extern "C" { \ + static type (*resolve_##name(void)) params { \ + if (__builtin_expect(have_avx2, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(7, &eax, &ebx, &ecx, &edx)) { \ + have_avx2 = ((ebx & bit_AVX2) != 0); \ + } else { \ + have_avx2 = 0; \ + } \ + } \ + if (__builtin_expect(have_ssse3, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \ + have_ssse3 = ((ecx & bit_SSSE3) != 0); \ + } else { \ + have_ssse3 = 0; \ + } \ + } \ + if (have_avx2) return &name##_avx2; \ + if (have_ssse3) return &name##_ssse3; \ + return &name##_base; \ + } \ + } \ + st type __attribute__((ifunc ("resolve_" #name))) name params +#define PCLMUL_RESOLVER(st,type,name,params) \ + extern "C" { \ + static type (*resolve_##name(void)) params { \ + if (__builtin_expect(have_pclmul, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \ + have_pclmul = ((ecx & bit_PCLMUL) != 0); \ + have_avx = ((ecx & bit_AVX) != 0); \ + have_fma = ((ecx & bit_FMA) != 0); \ + } else { \ + have_pclmul = 0; \ + have_avx = 0; \ + have_fma = 0; \ + } \ + } \ + if (have_avx) return &name##_avx; \ + if (have_pclmul) return &name##_pclmul; \ + return &name##_base; \ + } \ + } \ + st type __attribute__((ifunc ("resolve_" #name))) name params +#define AVX_RESOLVER(st,type,name,params) \ + extern "C" { \ + static type (*resolve_##name(void)) params { \ + if (__builtin_expect(have_pclmul, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \ + have_pclmul = ((ecx & bit_PCLMUL) != 0); \ + have_avx = ((ecx & bit_AVX) != 0); \ + have_fma = ((ecx & bit_FMA) != 0); \ + } else { \ + have_pclmul = 0; \ + have_avx = 0; \ + have_fma = 0; \ + } \ + } \ + return have_avx ? &name##_avx : &name##_base; \ + } \ + } \ + st type __attribute__((ifunc ("resolve_" #name))) name params +#define FMA_RESOLVER(st,type,name,params) \ + extern "C" { \ + static type (*resolve_##name(void)) params { \ + if (__builtin_expect(have_pclmul, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \ + have_pclmul = ((ecx & bit_PCLMUL) != 0); \ + have_avx = ((ecx & bit_AVX) != 0); \ + have_fma = ((ecx & bit_FMA) != 0); \ + } else { \ + have_pclmul = 0; \ + have_avx = 0; \ + have_fma = 0; \ + } \ + } \ + return have_fma ? 
&name##_fma : &name##_avx; \ + } \ + } \ + st type __attribute__((ifunc ("resolve_" #name))) name params +#define AVX2_RESOLVER(st,type,name,params) \ + extern "C" { \ + static type (*resolve_##name(void)) params { \ + if (__builtin_expect(have_avx2, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(7, &eax, &ebx, &ecx, &edx)) { \ + have_avx2 = ((ebx & bit_AVX2) != 0); \ + } else { \ + have_avx2 = 0; \ + } \ + } \ + if (__builtin_expect(have_pclmul, 0) < 0) { \ + unsigned int eax, ebx, ecx, edx; \ + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { \ + have_pclmul = ((ecx & bit_PCLMUL) != 0); \ + have_avx = ((ecx & bit_AVX) != 0); \ + have_fma = ((ecx & bit_FMA) != 0); \ + } else { \ + have_pclmul = 0; \ + have_avx = 0; \ + have_fma = 0; \ + } \ + } \ + return have_avx2 ? &name##_avx2 : &name##_fma; \ + } \ + } \ + st type __attribute__((ifunc ("resolve_" #name))) name params +#endif #ifdef NTL_HAVE_BUILTIN_CLZL --- include/NTL/MatPrime.h.orig 2021-06-20 15:05:49.000000000 -0600 +++ include/NTL/MatPrime.h 2021-06-23 19:59:29.904142134 -0600 @@ -20,7 +20,7 @@ NTL_OPEN_NNS -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) #define NTL_MatPrime_NBITS (23) #else #define NTL_MatPrime_NBITS NTL_SP_NBITS --- include/NTL/REPORT_ALL_FEATURES.h.orig 2021-06-20 15:05:49.000000000 -0600 +++ include/NTL/REPORT_ALL_FEATURES.h 2021-06-23 19:59:29.905142135 -0600 @@ -63,3 +63,6 @@ std::cerr << "NTL_HAVE_KMA\n"; #endif +#ifdef NTL_LOADTIME_CPU + std::cerr << "NTL_LOADTIME_CPU\n"; +#endif --- src/cfile.orig 2021-06-20 15:05:49.000000000 -0600 +++ src/cfile 2021-06-23 19:59:29.906142136 -0600 @@ -449,6 +449,19 @@ to be defined. Of course, to unset a f #endif +#if @{NTL_LOADTIME_CPU} +#define NTL_LOADTIME_CPU + +/* + * With this flag enabled, detect advanced CPU features at load time instead + * of at compile time. This flag is intended for distributions, so that they + * can compile for the lowest common denominator CPU, but still support newer + * CPUs. + * + * This flag is useful only on x86_64 platforms with gcc 4.8 or later. 
+ */
+
+#endif
 
 
 #if @{NTL_CRT_ALTCODE}
--- src/DispSettings.cpp.orig	2021-06-20 15:05:49.000000000 -0600
+++ src/DispSettings.cpp	2021-06-23 19:59:29.906142136 -0600
@@ -192,6 +192,9 @@ cout << "Performance Options:\n";
    cout << "NTL_RANDOM_AES256CTR\n";
 #endif
 
+#ifdef NTL_LOADTIME_CPU
+   cout << "NTL_LOADTIME_CPU\n";
+#endif
 
    cout << "***************************/\n";
    cout << "\n\n";
--- src/DoConfig.orig	2021-06-20 15:05:49.000000000 -0600
+++ src/DoConfig	2021-06-23 19:59:29.907142137 -0600
@@ -1,6 +1,7 @@
 # This is a perl script, invoked from a shell
 
 use warnings;  # this doesn't work on older versions of perl
+use Config;
 
 
 system("echo '*** CompilerOutput.log ***' > CompilerOutput.log");
@@ -92,6 +93,7 @@ system("echo '*** CompilerOutput.log ***
 'NTL_GF2X_NOINLINE' => 'off',
 'NTL_GF2X_ALTCODE' => 'off',
 'NTL_GF2X_ALTCODE1' => 'off',
+'NTL_LOADTIME_CPU' => 'off',
 
 'NTL_RANDOM_AES256CTR' => 'off',
 
@@ -176,6 +178,14 @@ if ($MakeVal{'CXXFLAGS'} =~ '-march=') {
    $MakeFlag{'NATIVE'} = 'off';
 }
 
+# special processing: NTL_LOADTIME_CPU on x86/x86_64 only and => NTL_GF2X_NOINLINE
+
+if ($ConfigFlag{'NTL_LOADTIME_CPU'} eq 'on') {
+   if ($Config{archname} !~ /x86_64/) {
+      die "Error: NTL_LOADTIME_CPU currently only available with x86_64...sorry\n";
+   }
+   $ConfigFlag{'NTL_GF2X_NOINLINE'} = 'on';
+}
 
 # some special MakeVal values that are determined by SHARED
--- src/GF2EX.cpp.orig	2021-06-20 15:05:48.000000000 -0600
+++ src/GF2EX.cpp	2021-06-23 19:59:29.908142138 -0600
@@ -801,7 +801,7 @@ void mul(GF2EX& c, const GF2EX& a, const
    if (GF2E::WordLength() <= 1) use_kron_mul = true;
 
-#if (defined(NTL_GF2X_LIB) && defined(NTL_HAVE_PCLMUL))
+#if (defined(NTL_GF2X_LIB) && (defined(NTL_HAVE_PCLMUL) || defined(NTL_LOADTIME_CPU)))
    // With gf2x library and pclmul, KronMul is better in a larger range, but
    // it is very hard to characterize that range.  The following is very conservative.
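
For readers unfamiliar with the mechanism the ctools.h resolver macros above expand to, the following is a minimal standalone sketch (not part of the patch) of GCC's target-attribute/ifunc dispatch on x86_64. The names add, add_base, add_avx2 and resolve_add are illustrative only; the CPUID probe mirrors the one used by the macros.

// build with: g++ -O2 ifunc_sketch.cpp   (gcc >= 4.6, x86_64, ELF platform assumed)
#include <cpuid.h>
#include <cstdio>

#ifndef bit_AVX2
#define bit_AVX2 (1 << 5)
#endif

static int add_base(int a, int b) { return a + b; }                       // lowest-common-denominator build
static int __attribute__((target("avx2"))) add_avx2(int a, int b)         // same routine compiled for AVX2
{ return a + b; }

extern "C" {
// The resolver runs once, when the dynamic linker binds the symbol
// (i.e. at load time, before main), and picks an implementation.
static int (*resolve_add(void))(int, int)
{
   unsigned int eax, ebx, ecx, edx;
   if (__get_cpuid(7, &eax, &ebx, &ecx, &edx) && (ebx & bit_AVX2))
      return &add_avx2;
   return &add_base;
}
}

// Callers see an ordinary function; the indirection is paid only once.
int __attribute__((ifunc("resolve_add"))) add(int a, int b);

int main() { std::printf("%d\n", add(2, 3)); }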
--- src/GF2X1.cpp.orig 2021-06-20 15:05:48.000000000 -0600 +++ src/GF2X1.cpp 2021-06-23 19:59:29.910142141 -0600 @@ -18,7 +18,7 @@ // simple scaling factor for some crossover points: // we use a lower crossover of the underlying multiplication // is faster -#if (defined(NTL_GF2X_LIB) || defined(NTL_HAVE_PCLMUL)) +#if (defined(NTL_GF2X_LIB) || defined(NTL_HAVE_PCLMUL) || defined(NTL_LOADTIME_CPU)) #define XOVER_SCALE (1L) #else #define XOVER_SCALE (2L) --- src/GF2X.cpp.orig 2021-06-20 15:05:48.000000000 -0600 +++ src/GF2X.cpp 2021-06-23 19:59:29.911142142 -0600 @@ -27,6 +27,22 @@ pclmul_mul1 (unsigned long *c, unsigned _mm_storeu_si128((__m128i*)c, _mm_clmulepi64_si128(aa, bb, 0)); } +#elif defined(NTL_LOADTIME_CPU) + +#include + +static int have_pclmul = -1; +static int have_avx = -1; +static int have_fma = -1; + +#define NTL_INLINE inline + +#define pclmul_mul1(c,a,b) do { \ + __m128i aa = _mm_setr_epi64( _mm_cvtsi64_m64(a), _mm_cvtsi64_m64(0)); \ + __m128i bb = _mm_setr_epi64( _mm_cvtsi64_m64(b), _mm_cvtsi64_m64(0)); \ + _mm_storeu_si128((__m128i*)(c), _mm_clmulepi64_si128(aa, bb, 0)); \ +} while (0) + #else @@ -556,6 +572,27 @@ void add(GF2X& x, const GF2X& a, const G +#ifdef NTL_LOADTIME_CPU + +BASE_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b) +{ + NTL_EFF_BB_MUL_CODE0 +} + +PCLMUL_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b) +{ + pclmul_mul1(c, a, b); +} + +AVX_FUNC(void,mul1)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b) +{ + pclmul_mul1(c, a, b); +} + +PCLMUL_RESOLVER(static,void,mul1,(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)); + +#else + static NTL_INLINE void mul1(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b) { @@ -568,6 +605,7 @@ NTL_EFF_BB_MUL_CODE0 } +#endif #ifdef NTL_GF2X_NOINLINE @@ -592,6 +630,51 @@ NTL_EFF_BB_MUL_CODE0 #endif +#ifdef NTL_LOADTIME_CPU + +BASE_FUNC(void,Mul1) +(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a) +{ + NTL_EFF_BB_MUL_CODE1 +} + +PCLMUL_FUNC(void,Mul1) +(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a) +{ + long i; + unsigned long carry, prod[2]; + + carry = 0; + for (i = 0; i < sb; i++) { + pclmul_mul1(prod, bp[i], a); + cp[i] = carry ^ prod[0]; + carry = prod[1]; + } + + cp[sb] = carry; +} + +AVX_FUNC(void,Mul1) +(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a) +{ + long i; + unsigned long carry, prod[2]; + + carry = 0; + for (i = 0; i < sb; i++) { + pclmul_mul1(prod, bp[i], a); + cp[i] = carry ^ prod[0]; + carry = prod[1]; + } + + cp[sb] = carry; +} + +PCLMUL_RESOLVER(static,void,Mul1, + (_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)); + +#else + static void Mul1(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a) { @@ -620,6 +703,53 @@ NTL_EFF_BB_MUL_CODE1 // warning #13200: No EMMS instruction before return } +#endif + +#ifdef NTL_LOADTIME_CPU + +BASE_FUNC(void,AddMul1) +(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a) +{ + NTL_EFF_BB_MUL_CODE2 +} + +PCLMUL_FUNC(void,AddMul1) +(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a) +{ + long i; + unsigned long carry, prod[2]; + + carry = 0; + for (i = 0; i < sb; i++) { + pclmul_mul1(prod, bp[i], a); + cp[i] ^= carry ^ prod[0]; + carry = prod[1]; + } + + cp[sb] ^= carry; +} + +AVX_FUNC(void,AddMul1) +(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a) +{ + long i; + unsigned long carry, prod[2]; + + carry = 0; + for (i = 0; i < sb; i++) { + pclmul_mul1(prod, bp[i], a); + cp[i] ^= carry ^ prod[0]; + carry = prod[1]; + } + + cp[sb] ^= carry; +} + +PCLMUL_RESOLVER(static,void,AddMul1, 
+                (_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a));
+
+#else
+
 static
 void AddMul1(_ntl_ulong *cp, const _ntl_ulong* bp, long sb, _ntl_ulong a)
 {
@@ -648,6 +778,52 @@ NTL_EFF_BB_MUL_CODE2
 
 }
 
+#endif
+
+#ifdef NTL_LOADTIME_CPU
+
+BASE_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+   NTL_EFF_SHORT_BB_MUL_CODE1
+}
+
+PCLMUL_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+   long i;
+   unsigned long carry, prod[2];
+
+   carry = 0;
+   for (i = 0; i < sb; i++) {
+      pclmul_mul1(prod, bp[i], a);
+      cp[i] = carry ^ prod[0];
+      carry = prod[1];
+   }
+
+   cp[sb] = carry;
+}
+
+AVX_FUNC(void,Mul1_short)
+(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
+{
+   long i;
+   unsigned long carry, prod[2];
+
+   carry = 0;
+   for (i = 0; i < sb; i++) {
+      pclmul_mul1(prod, bp[i], a);
+      cp[i] = carry ^ prod[0];
+      carry = prod[1];
+   }
+
+   cp[sb] = carry;
+}
+
+PCLMUL_RESOLVER(static,void,Mul1_short,
+                (_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a));
+
+#else
 
 static
 void Mul1_short(_ntl_ulong *cp, const _ntl_ulong *bp, long sb, _ntl_ulong a)
@@ -677,9 +853,29 @@ NTL_EFF_SHORT_BB_MUL_CODE1
 // warning #13200: No EMMS instruction before return
 
 }
 
+#endif
 
+#ifdef NTL_LOADTIME_CPU
+BASE_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+   NTL_EFF_HALF_BB_MUL_CODE0
+}
+
+PCLMUL_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+   pclmul_mul1(c, a, b);
+}
+
+AVX_FUNC(void,mul_half)(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
+{
+   pclmul_mul1(c, a, b);
+}
+
+PCLMUL_RESOLVER(static,void,mul_half,(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b));
+
+#else
 
 static
 void mul_half(_ntl_ulong *c, _ntl_ulong a, _ntl_ulong b)
@@ -694,6 +890,7 @@ NTL_EFF_HALF_BB_MUL_CODE0
 
 }
 
+#endif
 
 // mul2...mul8 hard-code 2x2...8x8 word multiplies.
 // I adapted these routines from LiDIA (except mul3, see below).
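
The PCLMUL variants of Mul1, AddMul1 and Mul1_short above all follow the same word-by-word carry-less multiplication pattern: each 64x64 -> 128-bit clmul produces a low half that is combined with the carry from the previous word and a high half that carries into the next word. A self-contained sketch of that pattern (illustrative only, not NTL's API; assumes a PCLMUL-capable x86_64 CPU and compilation with -mpclmul):

// build with: g++ -O2 -mpclmul mul1_sketch.cpp
#include <wmmintrin.h>   // _mm_clmulepi64_si128
#include <emmintrin.h>   // __m128i, _mm_set_epi64x, _mm_storeu_si128
#include <cstdio>

// one 64x64 -> 128 bit carry-less product, stored into c[0..1]
static void clmul_word(unsigned long *c, unsigned long a, unsigned long b)
{
   __m128i aa = _mm_set_epi64x(0, (long long) a);
   __m128i bb = _mm_set_epi64x(0, (long long) b);
   _mm_storeu_si128((__m128i *) c, _mm_clmulepi64_si128(aa, bb, 0));
}

// c[0..sb] = (bp[0..sb-1]) * a over GF(2)[x], mirroring the Mul1 loop
static void mul1_words(unsigned long *cp, const unsigned long *bp, long sb, unsigned long a)
{
   unsigned long carry = 0, prod[2];
   for (long i = 0; i < sb; i++) {
      clmul_word(prod, bp[i], a);
      cp[i] = carry ^ prod[0];   // low half, plus carry from the previous word
      carry = prod[1];           // high half carries into the next word
   }
   cp[sb] = carry;
}

int main()
{
   unsigned long b[2] = { 0x3UL, 0x5UL }, c[3];
   mul1_words(c, b, 2, 0x7UL);   // small GF(2)[x] example
   std::printf("%lx %lx %lx\n", c[0], c[1], c[2]);
}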
@@ -1611,6 +1808,77 @@ static const _ntl_ulong sqrtab[256] = { +#ifdef NTL_LOADTIME_CPU + +BASE_FUNC(void,sqr)(GF2X& c, const GF2X& a) +{ + long sa = a.xrep.length(); + if (sa <= 0) { + clear(c); + return; + } + + c.xrep.SetLength(sa << 1); + _ntl_ulong *cp = c.xrep.elts(); + const _ntl_ulong *ap = a.xrep.elts(); + + for (long i = sa-1; i >= 0; i--) { + _ntl_ulong *c = cp + (i << 1); + _ntl_ulong a = ap[i]; + _ntl_ulong hi, lo; + + NTL_BB_SQR_CODE + + c[0] = lo; + c[1] = hi; + } + + c.normalize(); + return; +} + +PCLMUL_FUNC(void,sqr)(GF2X& c, const GF2X& a) +{ + long sa = a.xrep.length(); + if (sa <= 0) { + clear(c); + return; + } + + c.xrep.SetLength(sa << 1); + _ntl_ulong *cp = c.xrep.elts(); + const _ntl_ulong *ap = a.xrep.elts(); + + for (long i = sa-1; i >= 0; i--) + pclmul_mul1 (cp + (i << 1), ap[i], ap[i]); + + c.normalize(); + return; +} + +AVX_FUNC(void,sqr)(GF2X& c, const GF2X& a) +{ + long sa = a.xrep.length(); + if (sa <= 0) { + clear(c); + return; + } + + c.xrep.SetLength(sa << 1); + _ntl_ulong *cp = c.xrep.elts(); + const _ntl_ulong *ap = a.xrep.elts(); + + for (long i = sa-1; i >= 0; i--) + pclmul_mul1 (cp + (i << 1), ap[i], ap[i]); + + c.normalize(); + return; +} + +PCLMUL_RESOLVER(,void,sqr,(GF2X& c, const GF2X& a)); + +#else + static inline void sqr1(_ntl_ulong *c, _ntl_ulong a) { @@ -1651,6 +1919,7 @@ void sqr(GF2X& c, const GF2X& a) return; } +#endif void LeftShift(GF2X& c, const GF2X& a, long n) --- src/InitSettings.cpp.orig 2021-06-20 15:05:49.000000000 -0600 +++ src/InitSettings.cpp 2021-06-23 19:59:29.912142143 -0600 @@ -190,6 +190,11 @@ int main() cout << "NTL_RANGE_CHECK=0\n"; #endif +#ifdef NTL_LOADTIME_CPU + cout << "NTL_LOADTIME_CPU=1\n"; +#else + cout << "NTL_LOADTIME_CPU=0\n"; +#endif // the following are not actual config flags, but help --- src/mat_lzz_p.cpp.orig 2021-06-20 15:05:48.000000000 -0600 +++ src/mat_lzz_p.cpp 2021-06-23 19:59:29.915142146 -0600 @@ -9,6 +9,15 @@ #ifdef NTL_HAVE_AVX #include +#define AVX_ACTIVE 1 +#elif defined(NTL_LOADTIME_CPU) +#include +#define AVX_ACTIVE have_avx + +static int have_pclmul = -1; +static int have_avx = -1; +static int have_fma = -1; +static int have_avx2 = -1; #endif NTL_START_IMPL @@ -634,7 +643,7 @@ void mul(mat_zz_p& X, const mat_zz_p& A, #ifdef NTL_HAVE_LL_TYPE -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) #define MAX_DBL_INT ((1L << NTL_DOUBLE_PRECISION)-1) // max int representable exactly as a double @@ -648,10 +657,12 @@ void mul(mat_zz_p& X, const mat_zz_p& A, // MUL_ADD(a, b, c): a += b*c +#define FMA_MUL_ADD(a, b, c) a = _mm256_fmadd_pd(b, c, a) +#define AVX_MUL_ADD(a, b, c) a = _mm256_add_pd(a, _mm256_mul_pd(b, c)) #ifdef NTL_HAVE_FMA -#define MUL_ADD(a, b, c) a = _mm256_fmadd_pd(b, c, a) +#define MUL_ADD(a, b, c) FMA_MUL_ADD(a, b, c) #else -#define MUL_ADD(a, b, c) a = _mm256_add_pd(a, _mm256_mul_pd(b, c)) +#define MUL_ADD(a, b, c) AVX_MUL_ADD(a, b, c) #endif @@ -931,6 +942,94 @@ void muladd3_by_16(double *x, const doub #else +#if defined(NTL_LOADTIME_CPU) + +AVX_FUNC(void,muladd1_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec, bvec; + + + __m256d acc0=_mm256_load_pd(x + 0*4); + __m256d acc1=_mm256_load_pd(x + 1*4); + __m256d acc2=_mm256_load_pd(x + 2*4); + __m256d acc3=_mm256_load_pd(x + 3*4); + __m256d acc4=_mm256_load_pd(x + 4*4); + __m256d acc5=_mm256_load_pd(x + 5*4); + __m256d acc6=_mm256_load_pd(x + 6*4); + __m256d acc7=_mm256_load_pd(x + 7*4); + + + for (long i = 0; i < n; i++) { + avec = _mm256_broadcast_sd(a); a++; + + + bvec = 
_mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc0, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc1, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc2, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc3, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc4, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc5, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc6, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc7, avec, bvec); + } + + + _mm256_store_pd(x + 0*4, acc0); + _mm256_store_pd(x + 1*4, acc1); + _mm256_store_pd(x + 2*4, acc2); + _mm256_store_pd(x + 3*4, acc3); + _mm256_store_pd(x + 4*4, acc4); + _mm256_store_pd(x + 5*4, acc5); + _mm256_store_pd(x + 6*4, acc6); + _mm256_store_pd(x + 7*4, acc7); +} + +FMA_FUNC(void,muladd1_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec, bvec; + + + __m256d acc0=_mm256_load_pd(x + 0*4); + __m256d acc1=_mm256_load_pd(x + 1*4); + __m256d acc2=_mm256_load_pd(x + 2*4); + __m256d acc3=_mm256_load_pd(x + 3*4); + __m256d acc4=_mm256_load_pd(x + 4*4); + __m256d acc5=_mm256_load_pd(x + 5*4); + __m256d acc6=_mm256_load_pd(x + 6*4); + __m256d acc7=_mm256_load_pd(x + 7*4); + + + for (long i = 0; i < n; i++) { + avec = _mm256_broadcast_sd(a); a++; + + + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc0, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc1, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc2, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc3, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc4, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc5, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc6, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc7, avec, bvec); + } + + + _mm256_store_pd(x + 0*4, acc0); + _mm256_store_pd(x + 1*4, acc1); + _mm256_store_pd(x + 2*4, acc2); + _mm256_store_pd(x + 3*4, acc3); + _mm256_store_pd(x + 4*4, acc4); + _mm256_store_pd(x + 5*4, acc5); + _mm256_store_pd(x + 6*4, acc6); + _mm256_store_pd(x + 7*4, acc7); +} + +FMA_RESOLVER(static,void,muladd1_by_32, + (double *x, const double *a, const double *b, long n)); + +#else static void muladd1_by_32(double *x, const double *a, const double *b, long n) @@ -973,6 +1072,167 @@ void muladd1_by_32(double *x, const doub _mm256_store_pd(x + 7*4, acc7); } +#endif + +#ifdef NTL_LOADTIME_CPU + +AVX_FUNC(void,muladd2_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec0, avec1, bvec; + __m256d acc00, acc01, acc02, acc03; + __m256d acc10, acc11, acc12, acc13; + + + // round 0 + + acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec); + bvec = 
_mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec); + } + + + _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13); + + // round 1 + + acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec); + } + + + _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13); + +} + +FMA_FUNC(void,muladd2_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec0, avec1, bvec; + __m256d acc00, acc01, acc02, acc03; + __m256d acc10, acc11, acc12, acc13; + + + // round 0 + + acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); + } + + + _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13); + + // 
round 1 + + acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); + } + + + _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13); + +} + +FMA_RESOLVER(static,void,muladd2_by_32, + (double *x, const double *a, const double *b, long n)); + +#else + static void muladd2_by_32(double *x, const double *a, const double *b, long n) { @@ -1049,6 +1309,212 @@ void muladd2_by_32(double *x, const doub } +#endif + +#ifdef NTL_LOADTIME_CPU +FMA_FUNC(void,muladd3_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec0, avec1, avec2, bvec; + __m256d acc00, acc01, acc02, acc03; + __m256d acc10, acc11, acc12, acc13; + __m256d acc20, acc21, acc22, acc23; + + + // round 0 + + acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ); + + acc20=_mm256_load_pd(x + 0*4 + 2*MAT_BLK_SZ); + acc21=_mm256_load_pd(x + 1*4 + 2*MAT_BLK_SZ); + acc22=_mm256_load_pd(x + 2*4 + 2*MAT_BLK_SZ); + acc23=_mm256_load_pd(x + 3*4 + 2*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec); + } + + + _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03); + + 
_mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13); + + _mm256_store_pd(x + 0*4 + 2*MAT_BLK_SZ, acc20); + _mm256_store_pd(x + 1*4 + 2*MAT_BLK_SZ, acc21); + _mm256_store_pd(x + 2*4 + 2*MAT_BLK_SZ, acc22); + _mm256_store_pd(x + 3*4 + 2*MAT_BLK_SZ, acc23); + + // round 1 + + acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ); + + acc20=_mm256_load_pd(x + 4*4 + 2*MAT_BLK_SZ); + acc21=_mm256_load_pd(x + 5*4 + 2*MAT_BLK_SZ); + acc22=_mm256_load_pd(x + 6*4 + 2*MAT_BLK_SZ); + acc23=_mm256_load_pd(x + 7*4 + 2*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec); + } + + + _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13); + + _mm256_store_pd(x + 4*4 + 2*MAT_BLK_SZ, acc20); + _mm256_store_pd(x + 5*4 + 2*MAT_BLK_SZ, acc21); + _mm256_store_pd(x + 6*4 + 2*MAT_BLK_SZ, acc22); + _mm256_store_pd(x + 7*4 + 2*MAT_BLK_SZ, acc23); + +} + +AVX2_FUNC(void,muladd3_by_32) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec0, avec1, avec2, bvec; + __m256d acc00, acc01, acc02, acc03; + __m256d acc10, acc11, acc12, acc13; + __m256d acc20, acc21, acc22, acc23; + + + // round 0 + + acc00=_mm256_load_pd(x + 0*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 1*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 2*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 3*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 0*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 1*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 2*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 3*4 + 1*MAT_BLK_SZ); + + acc20=_mm256_load_pd(x + 0*4 + 2*MAT_BLK_SZ); + acc21=_mm256_load_pd(x + 1*4 + 2*MAT_BLK_SZ); + acc22=_mm256_load_pd(x + 2*4 + 2*MAT_BLK_SZ); + acc23=_mm256_load_pd(x + 3*4 + 2*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); 
FMA_MUL_ADD(acc20, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec); + } + + + _mm256_store_pd(x + 0*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 1*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 2*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 3*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 0*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 1*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 2*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 3*4 + 1*MAT_BLK_SZ, acc13); + + _mm256_store_pd(x + 0*4 + 2*MAT_BLK_SZ, acc20); + _mm256_store_pd(x + 1*4 + 2*MAT_BLK_SZ, acc21); + _mm256_store_pd(x + 2*4 + 2*MAT_BLK_SZ, acc22); + _mm256_store_pd(x + 3*4 + 2*MAT_BLK_SZ, acc23); + + // round 1 + + acc00=_mm256_load_pd(x + 4*4 + 0*MAT_BLK_SZ); + acc01=_mm256_load_pd(x + 5*4 + 0*MAT_BLK_SZ); + acc02=_mm256_load_pd(x + 6*4 + 0*MAT_BLK_SZ); + acc03=_mm256_load_pd(x + 7*4 + 0*MAT_BLK_SZ); + + acc10=_mm256_load_pd(x + 4*4 + 1*MAT_BLK_SZ); + acc11=_mm256_load_pd(x + 5*4 + 1*MAT_BLK_SZ); + acc12=_mm256_load_pd(x + 6*4 + 1*MAT_BLK_SZ); + acc13=_mm256_load_pd(x + 7*4 + 1*MAT_BLK_SZ); + + acc20=_mm256_load_pd(x + 4*4 + 2*MAT_BLK_SZ); + acc21=_mm256_load_pd(x + 5*4 + 2*MAT_BLK_SZ); + acc22=_mm256_load_pd(x + 6*4 + 2*MAT_BLK_SZ); + acc23=_mm256_load_pd(x + 7*4 + 2*MAT_BLK_SZ); + + for (long i = 0; i < n; i++) { + avec0 = _mm256_broadcast_sd(&a[i]); + avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); + avec2 = _mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]); + + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4+MAT_BLK_SZ/2]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec); + } + + + _mm256_store_pd(x + 4*4 + 0*MAT_BLK_SZ, acc00); + _mm256_store_pd(x + 5*4 + 0*MAT_BLK_SZ, acc01); + _mm256_store_pd(x + 6*4 + 0*MAT_BLK_SZ, acc02); + _mm256_store_pd(x + 7*4 + 0*MAT_BLK_SZ, acc03); + + _mm256_store_pd(x + 4*4 + 1*MAT_BLK_SZ, acc10); + _mm256_store_pd(x + 5*4 + 1*MAT_BLK_SZ, acc11); + _mm256_store_pd(x + 6*4 + 1*MAT_BLK_SZ, acc12); + _mm256_store_pd(x + 7*4 + 1*MAT_BLK_SZ, acc13); + + _mm256_store_pd(x + 4*4 + 2*MAT_BLK_SZ, acc20); + _mm256_store_pd(x + 5*4 + 2*MAT_BLK_SZ, acc21); + _mm256_store_pd(x + 6*4 + 2*MAT_BLK_SZ, acc22); + _mm256_store_pd(x + 7*4 + 2*MAT_BLK_SZ, acc23); + +} + +AVX2_RESOLVER(static,void,muladd3_by_32, + (double *x, const double *a, const double *b, long n)); + +#else + // NOTE: this makes things slower on an AVX1 platform --- not enough registers // it could be faster on AVX2/FMA, where there should be enough registers static @@ -1150,6 +1616,75 @@ void muladd3_by_32(double *x, const doub } +#endif + +#ifdef NTL_LOADTIME_CPU + +AVX_FUNC(void,muladd1_by_16) +(double *x, const double *a, const 
double *b, long n) +{ + __m256d avec, bvec; + + + __m256d acc0=_mm256_load_pd(x + 0*4); + __m256d acc1=_mm256_load_pd(x + 1*4); + __m256d acc2=_mm256_load_pd(x + 2*4); + __m256d acc3=_mm256_load_pd(x + 3*4); + + + for (long i = 0; i < n; i++) { + avec = _mm256_broadcast_sd(a); a++; + + + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc0, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc1, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc2, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; AVX_MUL_ADD(acc3, avec, bvec); + b += 16; + } + + + _mm256_store_pd(x + 0*4, acc0); + _mm256_store_pd(x + 1*4, acc1); + _mm256_store_pd(x + 2*4, acc2); + _mm256_store_pd(x + 3*4, acc3); +} + +FMA_FUNC(void,muladd1_by_16) +(double *x, const double *a, const double *b, long n) +{ + __m256d avec, bvec; + + + __m256d acc0=_mm256_load_pd(x + 0*4); + __m256d acc1=_mm256_load_pd(x + 1*4); + __m256d acc2=_mm256_load_pd(x + 2*4); + __m256d acc3=_mm256_load_pd(x + 3*4); + + + for (long i = 0; i < n; i++) { + avec = _mm256_broadcast_sd(a); a++; + + + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc0, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc1, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc2, avec, bvec); + bvec = _mm256_load_pd(b); b += 4; FMA_MUL_ADD(acc3, avec, bvec); + b += 16; + } + + + _mm256_store_pd(x + 0*4, acc0); + _mm256_store_pd(x + 1*4, acc1); + _mm256_store_pd(x + 2*4, acc2); + _mm256_store_pd(x + 3*4, acc3); +} + +FMA_RESOLVER(static,void,muladd1_by_16, + (double *x, const double *a, const double *b, long n)); + +#else + static void muladd1_by_16(double *x, const double *a, const double *b, long n) { @@ -1180,10 +1715,11 @@ void muladd1_by_16(double *x, const doub _mm256_store_pd(x + 3*4, acc3); } +#endif -static -void muladd2_by_16(double *x, const double *a, const double *b, long n) +static void __attribute__((target ("avx,pclmul"))) +muladd2_by_16(double *x, const double *a, const double *b, long n) { __m256d avec0, avec1, bvec; __m256d acc00, acc01, acc02, acc03; @@ -1206,10 +1742,10 @@ void muladd2_by_16(double *x, const doub avec0 = _mm256_broadcast_sd(&a[i]); avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); MUL_ADD(acc00, avec0, bvec); MUL_ADD(acc10, avec1, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); MUL_ADD(acc01, avec0, bvec); MUL_ADD(acc11, avec1, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); MUL_ADD(acc02, avec0, bvec); MUL_ADD(acc12, avec1, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); MUL_ADD(acc03, avec0, bvec); MUL_ADD(acc13, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); AVX_MUL_ADD(acc00, avec0, bvec); AVX_MUL_ADD(acc10, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); AVX_MUL_ADD(acc01, avec0, bvec); AVX_MUL_ADD(acc11, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); AVX_MUL_ADD(acc02, avec0, bvec); AVX_MUL_ADD(acc12, avec1, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); AVX_MUL_ADD(acc03, avec0, bvec); AVX_MUL_ADD(acc13, avec1, bvec); } @@ -1226,8 +1762,8 @@ void muladd2_by_16(double *x, const doub } -static -void muladd3_by_16(double *x, const double *a, const double *b, long n) +static void __attribute__((target("fma,pclmul"))) +muladd3_by_16(double *x, const double *a, const double *b, long n) { __m256d avec0, avec1, avec2, bvec; __m256d acc00, acc01, acc02, acc03; @@ -1257,10 +1793,10 @@ void muladd3_by_16(double *x, const doub avec1 = _mm256_broadcast_sd(&a[i+MAT_BLK_SZ]); avec2 = 
_mm256_broadcast_sd(&a[i+2*MAT_BLK_SZ]); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); MUL_ADD(acc00, avec0, bvec); MUL_ADD(acc10, avec1, bvec); MUL_ADD(acc20, avec2, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); MUL_ADD(acc01, avec0, bvec); MUL_ADD(acc11, avec1, bvec); MUL_ADD(acc21, avec2, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); MUL_ADD(acc02, avec0, bvec); MUL_ADD(acc12, avec1, bvec); MUL_ADD(acc22, avec2, bvec); - bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); MUL_ADD(acc03, avec0, bvec); MUL_ADD(acc13, avec1, bvec); MUL_ADD(acc23, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+0*4]); FMA_MUL_ADD(acc00, avec0, bvec); FMA_MUL_ADD(acc10, avec1, bvec); FMA_MUL_ADD(acc20, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+1*4]); FMA_MUL_ADD(acc01, avec0, bvec); FMA_MUL_ADD(acc11, avec1, bvec); FMA_MUL_ADD(acc21, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+2*4]); FMA_MUL_ADD(acc02, avec0, bvec); FMA_MUL_ADD(acc12, avec1, bvec); FMA_MUL_ADD(acc22, avec2, bvec); + bvec = _mm256_load_pd(&b[i*MAT_BLK_SZ+3*4]); FMA_MUL_ADD(acc03, avec0, bvec); FMA_MUL_ADD(acc13, avec1, bvec); FMA_MUL_ADD(acc23, avec2, bvec); } @@ -1289,6 +1825,29 @@ void muladd3_by_16(double *x, const doub +#ifdef NTL_LOADTIME_CPU +static inline +void muladd_all_by_32(long first, long last, double *x, const double *a, const double *b, long n) +{ + long i = first; + + if (have_fma) { + // process three rows at a time + for (; i <= last-3; i+=3) + muladd3_by_32(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + for (; i < last; i++) + muladd1_by_32(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + } else { + // process only two rows at a time: not enough registers :-( + for (; i <= last-2; i+=2) + muladd2_by_32(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + for (; i < last; i++) + muladd1_by_32(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + } +} + +#else + static inline void muladd_all_by_32(long first, long last, double *x, const double *a, const double *b, long n) { @@ -1308,6 +1867,30 @@ void muladd_all_by_32(long first, long l #endif } +#endif + +#ifdef NTL_LOADTIME_CPU + +static inline +void muladd_all_by_16(long first, long last, double *x, const double *a, const double *b, long n) +{ + long i = first; + if (have_fma) { + // processing three rows at a time is faster + for (; i <= last-3; i+=3) + muladd3_by_16(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + for (; i < last; i++) + muladd1_by_16(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + } else { + // process only two rows at a time: not enough registers :-( + for (; i <= last-2; i+=2) + muladd2_by_16(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + for (; i < last; i++) + muladd1_by_16(x + i*MAT_BLK_SZ, a + i*MAT_BLK_SZ, b, n); + } +} + +#else static inline void muladd_all_by_16(long first, long last, double *x, const double *a, const double *b, long n) @@ -1328,6 +1911,8 @@ void muladd_all_by_16(long first, long l #endif } +#endif + static inline void muladd_all_by_32_width(long first, long last, double *x, const double *a, const double *b, long n, long width) { @@ -1343,6 +1928,74 @@ void muladd_all_by_32_width(long first, // this assumes n is a multiple of 16 +#ifdef NTL_LOADTIME_CPU +AVX_FUNC(void,muladd_interval) +(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n) +{ + __m256d xvec0, xvec1, xvec2, xvec3; + __m256d yvec0, yvec1, yvec2, yvec3; + + __m256d cvec = _mm256_broadcast_sd(&c); + + for (long i = 0; i < n; i += 16, x += 16, y += 16) { + xvec0 = _mm256_load_pd(x+0*4); + xvec1 = _mm256_load_pd(x+1*4); + xvec2 = 
_mm256_load_pd(x+2*4); + xvec3 = _mm256_load_pd(x+3*4); + + yvec0 = _mm256_load_pd(y+0*4); + yvec1 = _mm256_load_pd(y+1*4); + yvec2 = _mm256_load_pd(y+2*4); + yvec3 = _mm256_load_pd(y+3*4); + + AVX_MUL_ADD(xvec0, yvec0, cvec); + AVX_MUL_ADD(xvec1, yvec1, cvec); + AVX_MUL_ADD(xvec2, yvec2, cvec); + AVX_MUL_ADD(xvec3, yvec3, cvec); + + _mm256_store_pd(x + 0*4, xvec0); + _mm256_store_pd(x + 1*4, xvec1); + _mm256_store_pd(x + 2*4, xvec2); + _mm256_store_pd(x + 3*4, xvec3); + } +} + +FMA_FUNC(void,muladd_interval) +(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n) +{ + __m256d xvec0, xvec1, xvec2, xvec3; + __m256d yvec0, yvec1, yvec2, yvec3; + + __m256d cvec = _mm256_broadcast_sd(&c); + + for (long i = 0; i < n; i += 16, x += 16, y += 16) { + xvec0 = _mm256_load_pd(x+0*4); + xvec1 = _mm256_load_pd(x+1*4); + xvec2 = _mm256_load_pd(x+2*4); + xvec3 = _mm256_load_pd(x+3*4); + + yvec0 = _mm256_load_pd(y+0*4); + yvec1 = _mm256_load_pd(y+1*4); + yvec2 = _mm256_load_pd(y+2*4); + yvec3 = _mm256_load_pd(y+3*4); + + FMA_MUL_ADD(xvec0, yvec0, cvec); + FMA_MUL_ADD(xvec1, yvec1, cvec); + FMA_MUL_ADD(xvec2, yvec2, cvec); + FMA_MUL_ADD(xvec3, yvec3, cvec); + + _mm256_store_pd(x + 0*4, xvec0); + _mm256_store_pd(x + 1*4, xvec1); + _mm256_store_pd(x + 2*4, xvec2); + _mm256_store_pd(x + 3*4, xvec3); + } +} + +FMA_RESOLVER(static,void,muladd_interval, + (double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)); + +#else + static inline void muladd_interval(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n) { @@ -1374,6 +2027,106 @@ void muladd_interval(double * NTL_RESTRI } } +#endif + +#ifdef NTL_LOADTIME_CPU +AVX_FUNC(void,muladd_interval1) +(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n) +{ + + __m256d xvec0, xvec1, xvec2, xvec3; + __m256d yvec0, yvec1, yvec2, yvec3; + __m256d cvec; + + if (n >= 4) + cvec = _mm256_broadcast_sd(&c); + + long i=0; + for (; i <= n-16; i += 16, x += 16, y += 16) { + xvec0 = _mm256_load_pd(x+0*4); + xvec1 = _mm256_load_pd(x+1*4); + xvec2 = _mm256_load_pd(x+2*4); + xvec3 = _mm256_load_pd(x+3*4); + + yvec0 = _mm256_load_pd(y+0*4); + yvec1 = _mm256_load_pd(y+1*4); + yvec2 = _mm256_load_pd(y+2*4); + yvec3 = _mm256_load_pd(y+3*4); + + AVX_MUL_ADD(xvec0, yvec0, cvec); + AVX_MUL_ADD(xvec1, yvec1, cvec); + AVX_MUL_ADD(xvec2, yvec2, cvec); + AVX_MUL_ADD(xvec3, yvec3, cvec); + + _mm256_store_pd(x + 0*4, xvec0); + _mm256_store_pd(x + 1*4, xvec1); + _mm256_store_pd(x + 2*4, xvec2); + _mm256_store_pd(x + 3*4, xvec3); + } + + for (; i <= n-4; i += 4, x += 4, y += 4) { + xvec0 = _mm256_load_pd(x+0*4); + yvec0 = _mm256_load_pd(y+0*4); + AVX_MUL_ADD(xvec0, yvec0, cvec); + _mm256_store_pd(x + 0*4, xvec0); + } + + for (; i < n; i++, x++, y++) { + *x += (*y)*c; + } +} + +FMA_FUNC(void,muladd_interval1) +(double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n) +{ + + __m256d xvec0, xvec1, xvec2, xvec3; + __m256d yvec0, yvec1, yvec2, yvec3; + __m256d cvec; + + if (n >= 4) + cvec = _mm256_broadcast_sd(&c); + + long i=0; + for (; i <= n-16; i += 16, x += 16, y += 16) { + xvec0 = _mm256_load_pd(x+0*4); + xvec1 = _mm256_load_pd(x+1*4); + xvec2 = _mm256_load_pd(x+2*4); + xvec3 = _mm256_load_pd(x+3*4); + + yvec0 = _mm256_load_pd(y+0*4); + yvec1 = _mm256_load_pd(y+1*4); + yvec2 = _mm256_load_pd(y+2*4); + yvec3 = _mm256_load_pd(y+3*4); + + FMA_MUL_ADD(xvec0, yvec0, cvec); + FMA_MUL_ADD(xvec1, yvec1, cvec); + FMA_MUL_ADD(xvec2, yvec2, cvec); + FMA_MUL_ADD(xvec3, yvec3, cvec); + + _mm256_store_pd(x + 0*4, xvec0); + _mm256_store_pd(x + 
1*4, xvec1); + _mm256_store_pd(x + 2*4, xvec2); + _mm256_store_pd(x + 3*4, xvec3); + } + + for (; i <= n-4; i += 4, x += 4, y += 4) { + xvec0 = _mm256_load_pd(x+0*4); + yvec0 = _mm256_load_pd(y+0*4); + FMA_MUL_ADD(xvec0, yvec0, cvec); + _mm256_store_pd(x + 0*4, xvec0); + } + + for (; i < n; i++, x++, y++) { + *x += (*y)*c; + } +} + +FMA_RESOLVER(static,void,muladd_interval1, + (double * NTL_RESTRICT x, double * NTL_RESTRICT y, double c, long n)); + +#else + // this one is more general: does not assume that n is a // multiple of 16 static inline @@ -1422,6 +2175,7 @@ void muladd_interval1(double * NTL_RESTR } } +#endif #endif @@ -3009,10 +3763,10 @@ void alt_mul_LL(const mat_window_zz_p& X } -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) -static -void blk_mul_DD(const mat_window_zz_p& X, +static void __attribute__((target("avx,pclmul"))) +blk_mul_DD(const mat_window_zz_p& X, const const_mat_window_zz_p& A, const const_mat_window_zz_p& B) { long n = A.NumRows(); @@ -3351,12 +4105,13 @@ void mul_base (const mat_window_zz_p& X, long p = zz_p::modulus(); long V = MAT_BLK_SZ*4; -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) // experimentally, blk_mul_DD beats all the alternatives // if each dimension is at least 16 - if (n >= 16 && l >= 16 && m >= 16 && + if (AVX_ACTIVE && + n >= 16 && l >= 16 && m >= 16 && p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) @@ -3451,7 +4206,8 @@ void mul_strassen(const mat_window_zz_p& // this code determines if mul_base triggers blk_mul_DD, // in which case a higher crossover is used -#if (defined(NTL_HAVE_LL_TYPE) && defined(NTL_HAVE_AVX)) +#if (defined(NTL_HAVE_LL_TYPE) && (defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU))) + if (AVX_ACTIVE) { long V = MAT_BLK_SZ*4; long p = zz_p::modulus(); @@ -3950,10 +4706,10 @@ void alt_inv_L(zz_p& d, mat_zz_p& X, con -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) -static -void alt_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax) +static void __attribute__((target("avx,pclmul"))) +alt_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax) { long n = A.NumRows(); @@ -4118,10 +4874,10 @@ void alt_inv_DD(zz_p& d, mat_zz_p& X, co -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) -static -void blk_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax) +static void __attribute__((target("avx,pclmul"))) +blk_inv_DD(zz_p& d, mat_zz_p& X, const mat_zz_p& A, bool relax) { long n = A.NumRows(); @@ -4879,8 +5635,9 @@ void relaxed_inv(zz_p& d, mat_zz_p& X, c else if (n/MAT_BLK_SZ < 4) { long V = 64; -#ifdef NTL_HAVE_AVX - if (p-1 <= MAX_DBL_INT && +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) + if (AVX_ACTIVE && + p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) { @@ -4905,8 +5662,9 @@ void relaxed_inv(zz_p& d, mat_zz_p& X, c else { long V = 4*MAT_BLK_SZ; -#ifdef NTL_HAVE_AVX - if (p-1 <= MAX_DBL_INT && +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) + if (AVX_ACTIVE && + p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) { @@ -5312,10 +6070,10 @@ void alt_tri_L(zz_p& d, const mat_zz_p& -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) -static -void alt_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp, +static void __attribute__((target("avx,pclmul"))) +alt_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp, vec_zz_p 
*xp, bool trans, bool relax) { long n = A.NumRows(); @@ -5502,10 +6260,10 @@ void alt_tri_DD(zz_p& d, const mat_zz_p& -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) -static -void blk_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp, +static void __attribute__((target("avx,pclmul"))) +blk_tri_DD(zz_p& d, const mat_zz_p& A, const vec_zz_p *bp, vec_zz_p *xp, bool trans, bool relax) { long n = A.NumRows(); @@ -6316,8 +7074,9 @@ void tri(zz_p& d, const mat_zz_p& A, con else if (n/MAT_BLK_SZ < 4) { long V = 64; -#ifdef NTL_HAVE_AVX - if (p-1 <= MAX_DBL_INT && +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) + if (AVX_ACTIVE && + p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) { @@ -6342,8 +7101,9 @@ void tri(zz_p& d, const mat_zz_p& A, con else { long V = 4*MAT_BLK_SZ; -#ifdef NTL_HAVE_AVX - if (p-1 <= MAX_DBL_INT && +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) + if (AVX_ACTIVE && + p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) { @@ -6589,7 +7349,7 @@ long elim_basic(const mat_zz_p& A, mat_z #ifdef NTL_HAVE_LL_TYPE -#ifdef NTL_HAVE_AVX +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) static inline @@ -8057,8 +8817,9 @@ long elim(const mat_zz_p& A, mat_zz_p *i else { long V = 4*MAT_BLK_SZ; -#ifdef NTL_HAVE_AVX - if (p-1 <= MAX_DBL_INT && +#if defined(NTL_HAVE_AVX) || defined(NTL_LOADTIME_CPU) + if (AVX_ACTIVE && + p-1 <= MAX_DBL_INT && V <= (MAX_DBL_INT-(p-1))/(p-1) && V*(p-1) <= (MAX_DBL_INT-(p-1))/(p-1)) { --- src/QuickTest.cpp.orig 2021-06-20 15:05:49.000000000 -0600 +++ src/QuickTest.cpp 2021-06-23 19:59:29.916142147 -0600 @@ -326,6 +326,9 @@ cerr << "Performance Options:\n"; cerr << "NTL_GF2X_NOINLINE\n"; #endif +#ifdef NTL_LOADTIME_CPU + cerr << "NTL_LOADTIME_CPU\n"; +#endif cerr << "\n\n"; --- src/WizardAux.orig 2021-06-20 15:05:49.000000000 -0600 +++ src/WizardAux 2021-06-23 19:59:29.916142147 -0600 @@ -89,6 +89,7 @@ system("$ARGV[0] InitSettings"); 'NTL_GF2X_NOINLINE' => 0, 'NTL_FFT_BIGTAB' => 0, 'NTL_FFT_LAZYMUL' => 0, +'NTL_LOADTIME_CPU' => 0, 'WIZARD_HACK' => '#define NTL_WIZARD_HACK', --- src/ZZ.cpp.orig 2021-06-20 15:05:48.000000000 -0600 +++ src/ZZ.cpp 2021-06-23 19:59:29.918142149 -0600 @@ -14,6 +14,13 @@ #elif defined(NTL_HAVE_SSSE3) #include #include +#elif defined(NTL_LOADTIME_CPU) +#include +#include +#include + +static int have_avx2 = -1; +static int have_ssse3 = -1; #endif #if defined(NTL_HAVE_KMA) @@ -3268,6 +3275,590 @@ struct RandomStream_impl { }; +#elif defined(NTL_LOADTIME_CPU) + +// round selector, specified values: +// 8: low security - high speed +// 12: mid security - mid speed +// 20: high security - low speed +#ifndef CHACHA_RNDS +#define CHACHA_RNDS 20 +#endif + +typedef __m128i ssse3_ivec_t; +typedef __m256i avx2_ivec_t; + +#define SSSE3_DELTA _mm_set_epi32(0,0,0,1) +#define AVX2_DELTA _mm256_set_epi64x(0,2,0,2) + +#define SSSE3_START _mm_setzero_si128() +#define AVX2_START _mm256_set_epi64x(0,1,0,0) + +#define SSSE3_NONCE(nonce) _mm_set_epi64x(nonce,0) +#define AVX2_NONCE(nonce) _mm256_set_epi64x(nonce, 1, nonce, 0) + +#define SSSE3_STOREU_VEC(m,r) _mm_storeu_si128((__m128i*)(m), r) +#define AVX2_STOREU_VEC(m,r) _mm256_storeu_si256((__m256i*)(m), r) + +#define SSSE3_STORE_VEC(m,r) _mm_store_si128((__m128i*)(m), r) +#define AVX2_STORE_VEC(m,r) _mm256_store_si256((__m256i*)(m), r) + +#define SSSE3_LOAD_VEC(r,m) r = _mm_load_si128((const __m128i *)(m)) +#define AVX2_LOAD_VEC(r,m) r = _mm256_load_si256((const 
__m256i *)(m)) + +#define SSSE3_LOADU_VEC_128(r, m) r = _mm_loadu_si128((const __m128i*)(m)) +#define AVX2_LOADU_VEC_128(r, m) r = _mm256_broadcastsi128_si256(_mm_loadu_si128((const __m128i*)(m))) + +#define SSSE3_ADD_VEC_32(a,b) _mm_add_epi32(a, b) +#define AVX2_ADD_VEC_32(a,b) _mm256_add_epi32(a, b) + +#define SSSE3_ADD_VEC_64(a,b) _mm_add_epi64(a, b) +#define AVX2_ADD_VEC_64(a,b) _mm256_add_epi64(a, b) + +#define SSSE3_XOR_VEC(a,b) _mm_xor_si128(a, b) +#define AVX2_XOR_VEC(a,b) _mm256_xor_si256(a, b) + +#define SSSE3_ROR_VEC_V1(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(0,3,2,1)) +#define AVX2_ROR_VEC_V1(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(0,3,2,1)) + +#define SSSE3_ROR_VEC_V2(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(1,0,3,2)) +#define AVX2_ROR_VEC_V2(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(1,0,3,2)) + +#define SSSE3_ROR_VEC_V3(x) _mm_shuffle_epi32(x,_MM_SHUFFLE(2,1,0,3)) +#define AVX2_ROR_VEC_V3(x) _mm256_shuffle_epi32(x,_MM_SHUFFLE(2,1,0,3)) + +#define SSSE3_ROL_VEC_7(x) SSSE3_XOR_VEC(_mm_slli_epi32(x, 7), _mm_srli_epi32(x,25)) +#define AVX2_ROL_VEC_7(x) AVX2_XOR_VEC(_mm256_slli_epi32(x, 7), _mm256_srli_epi32(x,25)) + +#define SSSE3_ROL_VEC_12(x) SSSE3_XOR_VEC(_mm_slli_epi32(x,12), _mm_srli_epi32(x,20)) +#define AVX2_ROL_VEC_12(x) AVX2_XOR_VEC(_mm256_slli_epi32(x,12), _mm256_srli_epi32(x,20)) + +#define SSSE3_ROL_VEC_8(x) _mm_shuffle_epi8(x,_mm_set_epi8(14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3)) +#define AVX2_ROL_VEC_8(x) _mm256_shuffle_epi8(x,_mm256_set_epi8(14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3,14,13,12,15,10,9,8,11,6,5,4,7,2,1,0,3)) + +#define SSSE3_ROL_VEC_16(x) _mm_shuffle_epi8(x,_mm_set_epi8(13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2)) +#define AVX2_ROL_VEC_16(x) _mm256_shuffle_epi8(x,_mm256_set_epi8(13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,5,4,7,6,1,0,3,2)) + +#define SSSE3_WRITEU_VEC(op, d, v0, v1, v2, v3) \ + SSSE3_STOREU_VEC(op + (d + 0*4), v0); \ + SSSE3_STOREU_VEC(op + (d + 4*4), v1); \ + SSSE3_STOREU_VEC(op + (d + 8*4), v2); \ + SSSE3_STOREU_VEC(op + (d +12*4), v3); +#define AVX2_WRITEU_VEC(op, d, v0, v1, v2, v3) \ + AVX2_STOREU_VEC(op + (d + 0*4), _mm256_permute2x128_si256(v0, v1, 0x20)); \ + AVX2_STOREU_VEC(op + (d + 8*4), _mm256_permute2x128_si256(v2, v3, 0x20)); \ + AVX2_STOREU_VEC(op + (d +16*4), _mm256_permute2x128_si256(v0, v1, 0x31)); \ + AVX2_STOREU_VEC(op + (d +24*4), _mm256_permute2x128_si256(v2, v3, 0x31)); + +#define SSSE3_WRITE_VEC(op, d, v0, v1, v2, v3) \ + SSSE3_STORE_VEC(op + (d + 0*4), v0); \ + SSSE3_STORE_VEC(op + (d + 4*4), v1); \ + SSSE3_STORE_VEC(op + (d + 8*4), v2); \ + SSSE3_STORE_VEC(op + (d +12*4), v3); +#define AVX2_WRITE_VEC(op, d, v0, v1, v2, v3) \ + AVX2_STORE_VEC(op + (d + 0*4), _mm256_permute2x128_si256(v0, v1, 0x20)); \ + AVX2_STORE_VEC(op + (d + 8*4), _mm256_permute2x128_si256(v2, v3, 0x20)); \ + AVX2_STORE_VEC(op + (d +16*4), _mm256_permute2x128_si256(v0, v1, 0x31)); \ + AVX2_STORE_VEC(op + (d +24*4), _mm256_permute2x128_si256(v2, v3, 0x31)); + +#define SSSE3_SZ_VEC (16) +#define AVX2_SZ_VEC (32) + +#define SSSE3_RANSTREAM_NCHUNKS (4) +// leads to a BUFSZ of 512 + +#define AVX2_RANSTREAM_NCHUNKS (2) +// leads to a BUFSZ of 512 + +#define SSSE3_DQROUND_VECTORS_VEC(a,b,c,d) \ + a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_16(d); \ + c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_12(b); \ + a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_8(d); \ + c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_7(b); \ + b = SSSE3_ROR_VEC_V1(b); c = SSSE3_ROR_VEC_V2(c); d = 
SSSE3_ROR_VEC_V3(d); \ + a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_16(d); \ + c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_12(b); \ + a = SSSE3_ADD_VEC_32(a,b); d = SSSE3_XOR_VEC(d,a); d = SSSE3_ROL_VEC_8(d); \ + c = SSSE3_ADD_VEC_32(c,d); b = SSSE3_XOR_VEC(b,c); b = SSSE3_ROL_VEC_7(b); \ + b = SSSE3_ROR_VEC_V3(b); c = SSSE3_ROR_VEC_V2(c); d = SSSE3_ROR_VEC_V1(d); + +#define AVX2_DQROUND_VECTORS_VEC(a,b,c,d) \ + a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_16(d); \ + c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_12(b); \ + a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_8(d); \ + c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_7(b); \ + b = AVX2_ROR_VEC_V1(b); c = AVX2_ROR_VEC_V2(c); d = AVX2_ROR_VEC_V3(d); \ + a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_16(d); \ + c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_12(b); \ + a = AVX2_ADD_VEC_32(a,b); d = AVX2_XOR_VEC(d,a); d = AVX2_ROL_VEC_8(d); \ + c = AVX2_ADD_VEC_32(c,d); b = AVX2_XOR_VEC(b,c); b = AVX2_ROL_VEC_7(b); \ + b = AVX2_ROR_VEC_V3(b); c = AVX2_ROR_VEC_V2(c); d = AVX2_ROR_VEC_V1(d); + +#define SSSE3_RANSTREAM_STATESZ (4*SSSE3_SZ_VEC) +#define AVX2_RANSTREAM_STATESZ (4*AVX2_SZ_VEC) + +#define SSSE3_RANSTREAM_CHUNKSZ (2*SSSE3_RANSTREAM_STATESZ) +#define AVX2_RANSTREAM_CHUNKSZ (2*AVX2_RANSTREAM_STATESZ) + +#define SSSE3_RANSTREAM_BUFSZ (SSSE3_RANSTREAM_NCHUNKS*SSSE3_RANSTREAM_CHUNKSZ) +#define AVX2_RANSTREAM_BUFSZ (AVX2_RANSTREAM_NCHUNKS*AVX2_RANSTREAM_CHUNKSZ) + +static void allocate_space(AlignedArray &state_store, + AlignedArray &buf_store) +{ + if (have_avx2) { + state_store.SetLength(AVX2_RANSTREAM_STATESZ); + buf_store.SetLength(AVX2_RANSTREAM_BUFSZ); + } else { + state_store.SetLength(SSSE3_RANSTREAM_STATESZ); + buf_store.SetLength(SSSE3_RANSTREAM_BUFSZ); + } +}; + +BASE_FUNC(void, randomstream_impl_init) +(_ntl_uint32 *state, + AlignedArray &state_store __attribute__((unused)), + AlignedArray &buf_store __attribute__((unused)), + const unsigned char *key) +{ + salsa20_init(state, key); +} + +SSSE3_FUNC(void, randomstream_impl_init) +(_ntl_uint32 *state_ignored __attribute__((unused)), + AlignedArray &state_store, + AlignedArray &buf_store, + const unsigned char *key) +{ + allocate_space(state_store, buf_store); + + unsigned char *state = state_store.elts(); + + unsigned int chacha_const[] = { + 0x61707865,0x3320646E,0x79622D32,0x6B206574 + }; + + ssse3_ivec_t d0, d1, d2, d3; + SSSE3_LOADU_VEC_128(d0, chacha_const); + SSSE3_LOADU_VEC_128(d1, key); + SSSE3_LOADU_VEC_128(d2, key+16); + + d3 = SSSE3_START; + + SSSE3_STORE_VEC(state + 0*SSSE3_SZ_VEC, d0); + SSSE3_STORE_VEC(state + 1*SSSE3_SZ_VEC, d1); + SSSE3_STORE_VEC(state + 2*SSSE3_SZ_VEC, d2); + SSSE3_STORE_VEC(state + 3*SSSE3_SZ_VEC, d3); +} + +AVX2_FUNC(void, randomstream_impl_init) +(_ntl_uint32 *state_ignored __attribute__((unused)), + AlignedArray &state_store, + AlignedArray &buf_store, + const unsigned char *key) +{ + allocate_space(state_store, buf_store); + + unsigned char *state = state_store.elts(); + + unsigned int chacha_const[] = { + 0x61707865,0x3320646E,0x79622D32,0x6B206574 + }; + + avx2_ivec_t d0, d1, d2, d3; + AVX2_LOADU_VEC_128(d0, chacha_const); + AVX2_LOADU_VEC_128(d1, key); + AVX2_LOADU_VEC_128(d2, key+16); + + d3 = AVX2_START; + + AVX2_STORE_VEC(state + 0*AVX2_SZ_VEC, d0); + AVX2_STORE_VEC(state + 1*AVX2_SZ_VEC, d1); + AVX2_STORE_VEC(state + 2*AVX2_SZ_VEC, d2); + AVX2_STORE_VEC(state + 
3*AVX2_SZ_VEC, d3); +} + +SSSE3_RESOLVER(static, void, randomstream_impl_init, + (_ntl_uint32 *state, AlignedArray &state_store, + AlignedArray &buf_store, const unsigned char *key)); + +BASE_FUNC(long, randomstream_get_bytes) +(_ntl_uint32 *state, + unsigned char *buf, + AlignedArray &state_store __attribute__((unused)), + AlignedArray &buf_store __attribute__((unused)), + long &chunk_count __attribute__((unused)), + unsigned char *NTL_RESTRICT res, + long n, + long pos) +{ + if (n < 0) LogicError("RandomStream::get: bad args"); + + long i, j; + + if (n <= 64-pos) { + for (i = 0; i < n; i++) res[i] = buf[pos+i]; + pos += n; + return pos; + } + + // read remainder of buffer + for (i = 0; i < 64-pos; i++) res[i] = buf[pos+i]; + n -= 64-pos; + res += 64-pos; + pos = 64; + + _ntl_uint32 wdata[16]; + + // read 64-byte chunks + for (i = 0; i <= n-64; i += 64) { + salsa20_apply(state, wdata); + for (j = 0; j < 16; j++) + FROMLE(res + i + 4*j, wdata[j]); + } + + if (i < n) { + salsa20_apply(state, wdata); + + for (j = 0; j < 16; j++) + FROMLE(buf + 4*j, wdata[j]); + + pos = n-i; + for (j = 0; j < pos; j++) + res[i+j] = buf[j]; + } + + return pos; +} + +SSSE3_FUNC(long, randomstream_get_bytes) +(_ntl_uint32 *state_ignored __attribute__((unused)), + unsigned char *buf_ignored __attribute__((unused)), + AlignedArray &state_store, + AlignedArray &buf_store, + long &chunk_count, + unsigned char *NTL_RESTRICT res, + long n, + long pos) +{ + if (n < 0) LogicError("RandomStream::get: bad args"); + if (n == 0) return pos; + + unsigned char *NTL_RESTRICT buf = buf_store.elts(); + + if (n <= SSSE3_RANSTREAM_BUFSZ-pos) { + std::memcpy(&res[0], &buf[pos], n); + pos += n; + return pos; + } + + unsigned char *NTL_RESTRICT state = state_store.elts(); + + ssse3_ivec_t d0, d1, d2, d3; + SSSE3_LOAD_VEC(d0, state + 0*SSSE3_SZ_VEC); + SSSE3_LOAD_VEC(d1, state + 1*SSSE3_SZ_VEC); + SSSE3_LOAD_VEC(d2, state + 2*SSSE3_SZ_VEC); + SSSE3_LOAD_VEC(d3, state + 3*SSSE3_SZ_VEC); + + // read remainder of buffer + std::memcpy(&res[0], &buf[pos], SSSE3_RANSTREAM_BUFSZ-pos); + n -= SSSE3_RANSTREAM_BUFSZ-pos; + res += SSSE3_RANSTREAM_BUFSZ-pos; + pos = SSSE3_RANSTREAM_BUFSZ; + + long i = 0; + for (; i <= n-SSSE3_RANSTREAM_BUFSZ; i += SSSE3_RANSTREAM_BUFSZ) { + chunk_count |= SSSE3_RANSTREAM_NCHUNKS; // disable small buffer strategy + + for (long j = 0; j < SSSE3_RANSTREAM_NCHUNKS; j++) { + ssse3_ivec_t v0=d0, v1=d1, v2=d2, v3=d3; + ssse3_ivec_t v4=d0, v5=d1, v6=d2, v7=SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + + for (long k = 0; k < CHACHA_RNDS/2; k++) { + SSSE3_DQROUND_VECTORS_VEC(v0,v1,v2,v3) + SSSE3_DQROUND_VECTORS_VEC(v4,v5,v6,v7) + } + + SSSE3_WRITEU_VEC(res+i+j*(8*SSSE3_SZ_VEC), 0, SSSE3_ADD_VEC_32(v0,d0), SSSE3_ADD_VEC_32(v1,d1), SSSE3_ADD_VEC_32(v2,d2), SSSE3_ADD_VEC_32(v3,d3)) + d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + SSSE3_WRITEU_VEC(res+i+j*(8*SSSE3_SZ_VEC), 4*SSSE3_SZ_VEC, SSSE3_ADD_VEC_32(v4,d0), SSSE3_ADD_VEC_32(v5,d1), SSSE3_ADD_VEC_32(v6,d2), SSSE3_ADD_VEC_32(v7,d3)) + d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + } + + } + + if (i < n) { + + long nchunks; + + if (chunk_count < SSSE3_RANSTREAM_NCHUNKS) { + nchunks = long(cast_unsigned((n-i)+SSSE3_RANSTREAM_CHUNKSZ-1)/SSSE3_RANSTREAM_CHUNKSZ); + chunk_count += nchunks; + } + else + nchunks = SSSE3_RANSTREAM_NCHUNKS; + + long pos_offset = SSSE3_RANSTREAM_BUFSZ - nchunks*SSSE3_RANSTREAM_CHUNKSZ; + buf += pos_offset; + + for (long j = 0; j < nchunks; j++) { + ssse3_ivec_t v0=d0, v1=d1, v2=d2, v3=d3; + ssse3_ivec_t v4=d0, v5=d1, v6=d2, v7=SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + + 
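+        // The loop below applies CHACHA_RNDS/2 ChaCha "double rounds" (a
+        // column round followed by a diagonal round) to two states at once:
+        // v0..v3 and v4..v7 share the a/b/c rows and differ only in the d
+        // row (v7 starts from d3 advanced by SSSE3_DELTA).  The WRITE step
+        // that follows adds the original input words back in (the usual
+        // ChaCha feed-forward), so each pass emits two key-stream blocks.
+        // For ChaCha20, CHACHA_RNDS is 20, i.e. ten double rounds per block.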
for (long k = 0; k < CHACHA_RNDS/2; k++) { + SSSE3_DQROUND_VECTORS_VEC(v0,v1,v2,v3) + SSSE3_DQROUND_VECTORS_VEC(v4,v5,v6,v7) + } + + SSSE3_WRITE_VEC(buf+j*(8*SSSE3_SZ_VEC), 0, SSSE3_ADD_VEC_32(v0,d0), SSSE3_ADD_VEC_32(v1,d1), SSSE3_ADD_VEC_32(v2,d2), SSSE3_ADD_VEC_32(v3,d3)) + d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + SSSE3_WRITE_VEC(buf+j*(8*SSSE3_SZ_VEC), 4*SSSE3_SZ_VEC, SSSE3_ADD_VEC_32(v4,d0), SSSE3_ADD_VEC_32(v5,d1), SSSE3_ADD_VEC_32(v6,d2), SSSE3_ADD_VEC_32(v7,d3)) + d3 = SSSE3_ADD_VEC_64(d3, SSSE3_DELTA); + } + + pos = n-i+pos_offset; + std::memcpy(&res[i], &buf[0], n-i); + } + + SSSE3_STORE_VEC(state + 3*SSSE3_SZ_VEC, d3); + + return pos; +} + +AVX2_FUNC(long, randomstream_get_bytes) +(_ntl_uint32 *state_ignored __attribute__((unused)), + unsigned char *buf_ignored __attribute__((unused)), + AlignedArray &state_store, + AlignedArray &buf_store, + long &chunk_count, + unsigned char *NTL_RESTRICT res, + long n, + long pos) +{ + if (n < 0) LogicError("RandomStream::get: bad args"); + if (n == 0) return pos; + + unsigned char *NTL_RESTRICT buf = buf_store.elts(); + + if (n <= AVX2_RANSTREAM_BUFSZ-pos) { + std::memcpy(&res[0], &buf[pos], n); + pos += n; + return pos; + } + + unsigned char *NTL_RESTRICT state = state_store.elts(); + + avx2_ivec_t d0, d1, d2, d3; + AVX2_LOAD_VEC(d0, state + 0*AVX2_SZ_VEC); + AVX2_LOAD_VEC(d1, state + 1*AVX2_SZ_VEC); + AVX2_LOAD_VEC(d2, state + 2*AVX2_SZ_VEC); + AVX2_LOAD_VEC(d3, state + 3*AVX2_SZ_VEC); + + // read remainder of buffer + std::memcpy(&res[0], &buf[pos], AVX2_RANSTREAM_BUFSZ-pos); + n -= AVX2_RANSTREAM_BUFSZ-pos; + res += AVX2_RANSTREAM_BUFSZ-pos; + pos = AVX2_RANSTREAM_BUFSZ; + + long i = 0; + for (; i <= n-AVX2_RANSTREAM_BUFSZ; i += AVX2_RANSTREAM_BUFSZ) { + chunk_count |= AVX2_RANSTREAM_NCHUNKS; // disable small buffer strategy + + for (long j = 0; j < AVX2_RANSTREAM_NCHUNKS; j++) { + avx2_ivec_t v0=d0, v1=d1, v2=d2, v3=d3; + avx2_ivec_t v4=d0, v5=d1, v6=d2, v7=AVX2_ADD_VEC_64(d3, AVX2_DELTA); + + for (long k = 0; k < CHACHA_RNDS/2; k++) { + AVX2_DQROUND_VECTORS_VEC(v0,v1,v2,v3) + AVX2_DQROUND_VECTORS_VEC(v4,v5,v6,v7) + } + + AVX2_WRITEU_VEC(res+i+j*(8*AVX2_SZ_VEC), 0, AVX2_ADD_VEC_32(v0,d0), AVX2_ADD_VEC_32(v1,d1), AVX2_ADD_VEC_32(v2,d2), AVX2_ADD_VEC_32(v3,d3)) + d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA); + AVX2_WRITEU_VEC(res+i+j*(8*AVX2_SZ_VEC), 4*AVX2_SZ_VEC, AVX2_ADD_VEC_32(v4,d0), AVX2_ADD_VEC_32(v5,d1), AVX2_ADD_VEC_32(v6,d2), AVX2_ADD_VEC_32(v7,d3)) + d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA); + } + + } + + if (i < n) { + + long nchunks; + + if (chunk_count < AVX2_RANSTREAM_NCHUNKS) { + nchunks = long(cast_unsigned((n-i)+AVX2_RANSTREAM_CHUNKSZ-1)/AVX2_RANSTREAM_CHUNKSZ); + chunk_count += nchunks; + } + else + nchunks = AVX2_RANSTREAM_NCHUNKS; + + long pos_offset = AVX2_RANSTREAM_BUFSZ - nchunks*AVX2_RANSTREAM_CHUNKSZ; + buf += pos_offset; + + for (long j = 0; j < nchunks; j++) { + avx2_ivec_t v0=d0, v1=d1, v2=d2, v3=d3; + avx2_ivec_t v4=d0, v5=d1, v6=d2, v7=AVX2_ADD_VEC_64(d3, AVX2_DELTA); + + for (long k = 0; k < CHACHA_RNDS/2; k++) { + AVX2_DQROUND_VECTORS_VEC(v0,v1,v2,v3) + AVX2_DQROUND_VECTORS_VEC(v4,v5,v6,v7) + } + + AVX2_WRITE_VEC(buf+j*(8*AVX2_SZ_VEC), 0, AVX2_ADD_VEC_32(v0,d0), AVX2_ADD_VEC_32(v1,d1), AVX2_ADD_VEC_32(v2,d2), AVX2_ADD_VEC_32(v3,d3)) + d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA); + AVX2_WRITE_VEC(buf+j*(8*AVX2_SZ_VEC), 4*AVX2_SZ_VEC, AVX2_ADD_VEC_32(v4,d0), AVX2_ADD_VEC_32(v5,d1), AVX2_ADD_VEC_32(v6,d2), AVX2_ADD_VEC_32(v7,d3)) + d3 = AVX2_ADD_VEC_64(d3, AVX2_DELTA); + } + + pos = n-i+pos_offset; + std::memcpy(&res[i], 
&buf[0], n-i); + } + + AVX2_STORE_VEC(state + 3*AVX2_SZ_VEC, d3); + + return pos; +} + +SSSE3_RESOLVER(static, long, randomstream_get_bytes, + (_ntl_uint32 *state, unsigned char *buf, + AlignedArray &state_store, + AlignedArray &buf_store, + long &chunk_count, + unsigned char *NTL_RESTRICT res, + long n, + long pos)); + +BASE_FUNC(void, randomstream_set_nonce) +(_ntl_uint32 *state, + AlignedArray &state_store __attribute__((unused)), + long &chunk_count __attribute__((unused)), + unsigned long nonce) +{ + _ntl_uint32 nonce0, nonce1; + + nonce0 = nonce; + nonce0 = INT32MASK(nonce0); + + nonce1 = 0; + +#if (NTL_BITS_PER_LONG > 32) + nonce1 = nonce >> 32; + nonce1 = INT32MASK(nonce1); +#endif + + state[12] = 0; + state[13] = 0; + state[14] = nonce0; + state[15] = nonce1; +} + +SSSE3_FUNC(void, randomstream_set_nonce) +(_ntl_uint32 *state_ignored __attribute__((unused)), + AlignedArray &state_store, + long &chunk_count, + unsigned long nonce) +{ + unsigned char *state = state_store.elts(); + ssse3_ivec_t d3; + d3 = SSSE3_NONCE(nonce); + SSSE3_STORE_VEC(state + 3*SSSE3_SZ_VEC, d3); + chunk_count = 0; +} + +AVX2_FUNC(void, randomstream_set_nonce) +(_ntl_uint32 *state_ignored __attribute__((unused)), + AlignedArray &state_store, + long &chunk_count, + unsigned long nonce) +{ + unsigned char *state = state_store.elts(); + avx2_ivec_t d3; + d3 = AVX2_NONCE(nonce); + AVX2_STORE_VEC(state + 3*AVX2_SZ_VEC, d3); + chunk_count = 0; +} + +SSSE3_RESOLVER(, void, randomstream_set_nonce, + (_ntl_uint32 *state, + AlignedArray &state_store, + long &chunk_count, + unsigned long nonce)); + +struct RandomStream_impl { + AlignedArray state_store; + AlignedArray buf_store; + long chunk_count; + _ntl_uint32 state[16]; + unsigned char buf[64]; + + explicit + RandomStream_impl(const unsigned char *key) + { + randomstream_impl_init(state, state_store, buf_store, key); + chunk_count = 0; + } + + RandomStream_impl(const RandomStream_impl& other) + { + if (have_avx2 || have_ssse3) { + allocate_space(state_store, buf_store); + } + *this = other; + } + + RandomStream_impl& operator=(const RandomStream_impl& other) + { + if (have_avx2) { + std::memcpy(state_store.elts(), other.state_store.elts(), AVX2_RANSTREAM_STATESZ); + std::memcpy(buf_store.elts(), other.buf_store.elts(), AVX2_RANSTREAM_BUFSZ); + } else if (have_ssse3) { + std::memcpy(state_store.elts(), other.state_store.elts(), SSSE3_RANSTREAM_STATESZ); + std::memcpy(buf_store.elts(), other.buf_store.elts(), SSSE3_RANSTREAM_BUFSZ); + } + chunk_count = other.chunk_count; + return *this; + } + + const unsigned char * + get_buf() const + { + if (have_avx2 || have_ssse3) { + return buf_store.elts(); + } else { + return &buf[0]; + } + } + + long + get_buf_len() const + { + if (have_avx2) { + return AVX2_RANSTREAM_BUFSZ; + } else if (have_ssse3) { + return SSSE3_RANSTREAM_BUFSZ; + } else { + return 64; + } + } + + // bytes are generated in chunks of RANSTREAM_BUFSZ bytes, except that + // initially, we may generate a few chunks of RANSTREAM_CHUNKSZ + // bytes. This optimizes a bit for short bursts following a reset. + + long + get_bytes(unsigned char *NTL_RESTRICT res, + long n, long pos) + { + return randomstream_get_bytes(state, buf, state_store, buf_store, + chunk_count, res, n, pos); + } + + void + set_nonce(unsigned long nonce) + { + randomstream_set_nonce(state, state_store, chunk_count, nonce); + } +}; #else
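The SSSE3_RESOLVER invocations above bind each randomstream_* entry point to one of its _base, _ssse3, or _avx2 bodies exactly once, when the dynamic linker resolves the symbol, via GCC's indirect-function (ifunc) mechanism. The following is a minimal, self-contained sketch of that dispatch pattern, assuming GCC on an ELF x86_64 target; the names in it (add_one, resolve_add_one) are illustrative only, and it queries the CPU with __builtin_cpu_supports instead of the raw CPUID bits used in the patch.

    #include <cstdio>

    static int add_one_base(int x) { return x + 1; }
    static int __attribute__((target("avx2"))) add_one_avx2(int x) { return x + 1; }

    extern "C" {
    // The resolver runs when the symbol is bound, possibly before static
    // constructors, so initialize the CPU-model data explicitly.
    static int (*resolve_add_one(void))(int)
    {
        __builtin_cpu_init();
        return __builtin_cpu_supports("avx2") ? &add_one_avx2 : &add_one_base;
    }
    }

    int __attribute__((ifunc("resolve_add_one"))) add_one(int x);

    int main()
    {
        std::printf("%d\n", add_one(41));   // prints 42 on either code path
        return 0;
    }

Once the resolver has run, calls go straight to the selected body with no per-call branching, which is why the CPU detection can be confined to load time.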
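The "leads to a BUFSZ of 512" comments above follow directly from the size macros: a vector is 16 bytes under SSSE3 and 32 under AVX2, a state is four vectors (64 or 128 bytes), a chunk is two states (128 or 256 bytes), and with NCHUNKS of 4 and 2 respectively both code paths arrive at the same 512-byte buffer that get_buf_len reports. A standalone check of that arithmetic, with names local to this sketch rather than NTL's:

    #include <cstdio>

    int main()
    {
        const long ssse3_sz_vec = 16,  avx2_sz_vec = 32;    // bytes per vector
        const long ssse3_nchunks = 4,  avx2_nchunks = 2;

        const long ssse3_statesz = 4 * ssse3_sz_vec;        // 64
        const long avx2_statesz  = 4 * avx2_sz_vec;         // 128

        const long ssse3_chunksz = 2 * ssse3_statesz;       // 128
        const long avx2_chunksz  = 2 * avx2_statesz;        // 256

        std::printf("SSSE3 BUFSZ = %ld\n", ssse3_nchunks * ssse3_chunksz);  // 512
        std::printf("AVX2  BUFSZ = %ld\n", avx2_nchunks  * avx2_chunksz);   // 512
        return 0;
    }

The chunk_count bookkeeping in randomstream_get_bytes exists so that a freshly reset stream (chunk_count is 0 after set_nonce) does not immediately pay for a full 512-byte refill: a short request generates only as many CHUNKSZ-sized chunks as it needs and parks them at the end of the buffer (pos_offset), and once chunk_count reaches NCHUNKS whole buffers are produced from then on.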