@@ -0,0 +1,150 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+#include <stddef.h>
+#include <stdint.h>
+#include "vec.h"
+
+#include <emmintrin.h>
+#include <immintrin.h>
+
+#ifndef DOT7U_STRIDE_BYTES_LEN
+#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2
+#endif
+
+#ifndef SQR7U_STRIDE_BYTES_LEN
+#define SQR7U_STRIDE_BYTES_LEN 32 // Must be a power of 2
+#endif
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#elif defined(__GNUC__) || defined(__clang__)
+#include <x86intrin.h>
+#endif
+
+// Multi-platform CPUID "intrinsic"; it takes as input a "functionNumber" (or "leaf", the eax register). "Subleaf"
+// is always 0. Output is stored in the passed output parameter: output[0] = eax, output[1] = ebx, output[2] = ecx,
+// output[3] = edx
+static inline void cpuid(int output[4], int functionNumber) {
+#if defined(__GNUC__) || defined(__clang__)
+    // use inline assembly, Gnu/AT&T syntax
+    int a, b, c, d;
+    __asm("cpuid" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "a"(functionNumber), "c"(0) : );
+    output[0] = a;
+    output[1] = b;
+    output[2] = c;
+    output[3] = d;
+
+#elif defined(_MSC_VER)
+    __cpuidex(output, functionNumber, 0);
+#else
+    #error Unsupported compiler
+#endif
+}
+
+// Utility function to horizontally add 8 32-bit integers
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+EXPORT int vec_caps() {
+    int cpuInfo[4] = {-1};
+    // Calling cpuid with 0x0 as the functionNumber argument
+    // gets the number of the highest valid function ID.
+    cpuid(cpuInfo, 0);
+    int functionIds = cpuInfo[0];
+    if (functionIds >= 7) {
+        cpuid(cpuInfo, 7);
+        int ebx = cpuInfo[1];
+        // The AVX2 feature flag is bit 5 of EBX (CPUID leaf 7, sub-leaf 0).
+        // We assume that all processors that have AVX2 also have FMA3.
+        return (ebx & (1 << 5)) != 0;
+    }
+    return 0;
+}
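+
+// Illustrative usage (a sketch, not part of this file's API): callers would
+// typically probe vec_caps() once, cache the result, and only invoke the AVX2
+// kernels below when it returns non-zero, e.g.
+//
+//   static int has_avx2 = -1;
+//   if (has_avx2 < 0) has_avx2 = vec_caps();
+//   int32_t d = has_avx2 ? dot7u(a, b, dims) : dot7u_scalar(a, b, dims);
+//
+// where dot7u_scalar is a hypothetical scalar fallback (see the sketch after
+// dot7u below).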
+
+static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) {
+    const __m256i ones = _mm256_set1_epi16(1);
+
+    // Init accumulator(s) with 0
+    __m256i acc1 = _mm256_setzero_si256();
+
+#pragma GCC unroll 4
+    for (size_t i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) {
+        // Load packed 8-bit integers
+        __m256i va1 = _mm256_loadu_si256((const __m256i *)(a + i));
+        __m256i vb1 = _mm256_loadu_si256((const __m256i *)(b + i));
+
+        // Perform multiplication and create 16-bit values
+        // Vertically multiply each unsigned 8-bit integer from va with the corresponding
+        // 8-bit integer from vb, producing intermediate signed 16-bit integers.
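+        // Note: maddubs saturates, but with 7-bit inputs (values in 0..127) each
+        // pair sum is at most 2 * 127 * 127 = 32258 < INT16_MAX, so no saturation occurs.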
+        const __m256i vab = _mm256_maddubs_epi16(va1, vb1);
+        // Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results.
+        acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, vab), acc1);
+    }
+
+    // reduce (horizontally add all)
+    return hsum_i32_8(acc1);
+}
+
+EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
+    int32_t res = 0;
+    size_t i = 0;
+    if (dims > DOT7U_STRIDE_BYTES_LEN) {
+        // Process the largest stride-aligned prefix with SIMD; the scalar
+        // tail loop below handles the remainder.
+        i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
+        res = dot7u_inner(a, b, i);
+    }
+    for (; i < dims; i++) {
+        res += a[i] * b[i];
+    }
+    return res;
+}
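+
+// For reference/testing (a minimal sketch, not part of the exported API), the
+// scalar equivalent that dot7u must agree with for any dims:
+//
+//   static int32_t dot7u_scalar(const int8_t* a, const int8_t* b, size_t dims) {
+//       int32_t res = 0;
+//       for (size_t i = 0; i < dims; i++) {
+//           res += a[i] * b[i];
+//       }
+//       return res;
+//   }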
+
+static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) {
+    // Init accumulator(s) with 0
+    __m256i acc1 = _mm256_setzero_si256();
+
+    const __m256i ones = _mm256_set1_epi16(1);
+
+#pragma GCC unroll 4
+    for (size_t i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) {
+        // Load packed 8-bit integers
+        __m256i va1 = _mm256_loadu_si256((const __m256i *)(a + i));
+        __m256i vb1 = _mm256_loadu_si256((const __m256i *)(b + i));
+
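+        // Square the per-lane differences: sign(d, d) yields |d| (lanes where d is
+        // negative are negated), and with 7-bit inputs both maddubs operands are in
+        // 0..127, so each 16-bit pair sum is at most 2 * 127 * 127 = 32258 and the
+        // saturating multiply-add cannot overflow.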
+        const __m256i dist1 = _mm256_sub_epi8(va1, vb1);
+        const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1);
+        const __m256i sqr1 = _mm256_maddubs_epi16(abs_dist1, abs_dist1);
+
+        acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, sqr1), acc1);
+    }
+
+    // reduce (accumulate all)
+    return hsum_i32_8(acc1);
+}
+
+EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) {
+    int32_t res = 0;
+    size_t i = 0;
+    if (dims > SQR7U_STRIDE_BYTES_LEN) {
+        // Process the largest stride-aligned prefix with SIMD; the scalar
+        // tail loop below handles the remainder.
+        i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1);
+        res = sqr7u_inner(a, b, i);
+    }
+    for (; i < dims; i++) {
+        int32_t dist = a[i] - b[i];
+        res += dist * dist;
+    }
+    return res;
+}
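+
+// Likewise, a minimal scalar reference for sqr7u (illustrative only):
+//
+//   static int32_t sqr7u_scalar(const int8_t* a, const int8_t* b, size_t dims) {
+//       int32_t res = 0;
+//       for (size_t i = 0; i < dims; i++) {
+//           int32_t dist = a[i] - b[i];
+//           res += dist * dist;
+//       }
+//       return res;
+//   }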
+