Просмотр исходного кода

JDKVectorLibrary: update low-level bounds checks and add benchmark (#130216)

This commit updates the low-level bounds checks in JdkVectorLibrary and adds a benchmark, so that we can more easily benchmark the low-level operations.

Note: I added the mr-jar gradle plugin to the benchmarks so that we can compile with preview features in Java 21, namely MemorySegment.
Chris Hegarty 3 месяцев назад
Родитель
Сommit
4d3b699067

+ 2 - 0
benchmarks/build.gradle

@@ -13,6 +13,7 @@ import org.elasticsearch.gradle.OS
 apply plugin: org.elasticsearch.gradle.internal.ElasticsearchJavaBasePlugin
 apply plugin: 'java-library'
 apply plugin: 'application'
+apply plugin: 'elasticsearch.mrjar'
 
 var os = org.gradle.internal.os.OperatingSystem.current()
 
@@ -46,6 +47,7 @@ dependencies {
   api(project(':x-pack:plugin:core'))
   api(project(':x-pack:plugin:esql'))
   api(project(':x-pack:plugin:esql:compute'))
+  implementation project(path: ':libs:native')
   implementation project(path: ':libs:simdvec')
   expression(project(path: ':modules:lang-expression', configuration: 'zip'))
   painless(project(path: ':modules:lang-painless', configuration: 'zip'))

+ 129 - 0
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java

@@ -0,0 +1,129 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.common.logging.NodeNamePatternConverter;
+import org.elasticsearch.nativeaccess.NativeAccess;
+import org.elasticsearch.nativeaccess.VectorSimilarityFunctions;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JMH benchmark for the dot product of two unsigned int7 byte vectors.
+ *
+ * <p>Compares Lucene's {@code VectorUtil.dotProduct} over byte arrays with the native
+ * {@code dotProductHandle7u} method handle, invoked once with heap-backed memory segments
+ * and once with segments allocated from an {@link Arena}.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3, time = 1)
+@Measurement(iterations = 5, time = 1)
+public class JDKVectorInt7uBenchmark {
+
+    static {
+        NodeNamePatternConverter.setGlobalNodeName("foo");
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    byte[] byteArrayA;
+    byte[] byteArrayB;
+    MemorySegment heapSegA, heapSegB;
+    MemorySegment nativeSegA, nativeSegB;
+
+    Arena arena;
+
+    @Param({ "1", "128", "207", "256", "300", "512", "702", "1024" })
+    public int size;
+
+    /**
+     * Creates fresh random input vectors of {@link #size} bytes and mirrors them into
+     * heap-backed and arena-allocated memory segments. Paired with {@link #teardown()},
+     * which runs at the same level and releases the arena opened here.
+     */
+    @Setup(Level.Iteration)
+    public void init() {
+        byteArrayA = new byte[size];
+        byteArrayB = new byte[size];
+        // A single call fills the whole array; no per-element loop is needed around it.
+        randomInt7BytesBetween(byteArrayA);
+        randomInt7BytesBetween(byteArrayB);
+        heapSegA = MemorySegment.ofArray(byteArrayA);
+        heapSegB = MemorySegment.ofArray(byteArrayB);
+
+        arena = Arena.ofConfined();
+        nativeSegA = arena.allocate((long) byteArrayA.length);
+        MemorySegment.copy(heapSegA, 0L, nativeSegA, 0L, byteArrayA.length);
+        nativeSegB = arena.allocate((long) byteArrayB.length);
+        MemorySegment.copy(heapSegB, 0L, nativeSegB, 0L, byteArrayB.length);
+    }
+
+    /**
+     * Closes the arena opened by {@link #init()}. Runs per iteration, matching the setup
+     * level, so that every arena that gets opened is also closed.
+     */
+    @TearDown(Level.Iteration)
+    public void teardown() {
+        arena.close();
+    }
+
+    @Benchmark
+    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public int dotProductLucene() {
+        return VectorUtil.dotProduct(byteArrayA, byteArrayB);
+    }
+
+    @Benchmark
+    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public int dotProductNativeWithNativeSeg() {
+        return dotProduct7u(nativeSegA, nativeSegB, size);
+    }
+
+    @Benchmark
+    @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+    public int dotProductNativeWithHeapSeg() {
+        return dotProduct7u(heapSegA, heapSegB, size);
+    }
+
+    static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions();
+
+    static VectorSimilarityFunctions vectorSimilarityFunctions() {
+        return NativeAccess.instance().getVectorSimilarityFunctions().get();
+    }
+
+    /**
+     * Invokes the native 7-bit unsigned dot product through its method handle.
+     * Errors and runtime exceptions propagate unchanged; any other throwable is wrapped
+     * in a {@link RuntimeException}.
+     */
+    int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
+        try {
+            return (int) vectorSimilarityFunctions.dotProductHandle7u().invokeExact(a, b, length);
+        } catch (Throwable e) {
+            if (e instanceof Error err) {
+                throw err;
+            } else if (e instanceof RuntimeException re) {
+                throw re;
+            } else {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    // Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
+    static final byte MIN_INT7_VALUE = 0;
+    static final byte MAX_INT7_VALUE = 127;
+
+    /** Fills every element of {@code bytes} with a random value in the unsigned int7 range. */
+    static void randomInt7BytesBetween(byte[] bytes) {
+        var random = ThreadLocalRandom.current();
+        for (int i = 0, len = bytes.length; i < len;) {
+            bytes[i++] = (byte) random.nextInt(MIN_INT7_VALUE, MAX_INT7_VALUE + 1);
+        }
+    }
+}

+ 62 - 0
benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmarkTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.benchmark.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.test.ESTestCase;
+import org.openjdk.jmh.annotations.Param;
+
+import java.util.Arrays;
+
+/**
+ * Checks that every dot product variant exposed by {@link JDKVectorInt7uBenchmark} returns
+ * the same value as a plain scalar loop, for each vector size declared on the benchmark's
+ * {@code size} parameter.
+ */
+public class JDKVectorInt7uBenchmarkTests extends ESTestCase {
+
+    final int size;
+
+    public JDKVectorInt7uBenchmarkTests(int size) {
+        this.size = size;
+    }
+
+    public void testDotProduct() {
+        for (int i = 0; i < 100; i++) {
+            var bench = new JDKVectorInt7uBenchmark();
+            bench.size = size;
+            bench.init();
+            try {
+                // All implementations return an int, so compare exactly rather than via float and a tolerance.
+                int expected = dotProductScalar(bench.byteArrayA, bench.byteArrayB);
+                assertEquals(expected, bench.dotProductLucene());
+                assertEquals(expected, bench.dotProductNativeWithHeapSeg());
+                assertEquals(expected, bench.dotProductNativeWithNativeSeg());
+            } finally {
+                bench.teardown();
+            }
+        }
+    }
+
+    /** Builds one test instance per value of the {@code @Param} annotation on the benchmark's {@code size} field. */
+    @ParametersFactory
+    public static Iterable<Object[]> parametersFactory() {
+        try {
+            var params = JDKVectorInt7uBenchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value();
+            return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
+        } catch (NoSuchFieldException e) {
+            throw new AssertionError(e);
+        }
+    }
+
+    /** Computes the dot product of the given vectors a and b. */
+    static int dotProductScalar(byte[] a, byte[] b) {
+        int res = 0;
+        for (int i = 0; i < a.length; i++) {
+            res += a[i] * b[i];
+        }
+        return res;
+    }
+}

+ 9 - 12
libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

@@ -20,6 +20,7 @@ import java.lang.foreign.MemorySegment;
 import java.lang.invoke.MethodHandle;
 import java.lang.invoke.MethodHandles;
 import java.lang.invoke.MethodType;
+import java.util.Objects;
 
 import static java.lang.foreign.ValueLayout.ADDRESS;
 import static java.lang.foreign.ValueLayout.JAVA_INT;
@@ -99,13 +100,8 @@ public final class JdkVectorLibrary implements VectorLibrary {
          * @param length the vector dimensions
          */
         static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
-            assert length >= 0;
-            if (a.byteSize() != b.byteSize()) {
-                throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
-            }
-            if (length > a.byteSize()) {
-                throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
-            }
+            checkByteSize(a, b);
+            // Pass the long byte size as-is; narrowing it to int would wrap for segments larger than Integer.MAX_VALUE bytes.
+            Objects.checkFromIndexSize(0, length, a.byteSize());
             return dot7u(a, b, length);
         }
 
@@ -119,14 +115,15 @@ public final class JdkVectorLibrary implements VectorLibrary {
          * @param length the vector dimensions
          */
         static int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
-            assert length >= 0;
+            checkByteSize(a, b);
+            // Pass the long byte size as-is; narrowing it to int would wrap for segments larger than Integer.MAX_VALUE bytes.
+            Objects.checkFromIndexSize(0, length, a.byteSize());
+            return sqr7u(a, b, length);
+        }
+
+        /**
+         * Verifies that both segments have the same size in bytes.
+         *
+         * @throws IllegalArgumentException if the byte sizes differ
+         */
+        static void checkByteSize(MemorySegment a, MemorySegment b) {
             if (a.byteSize() != b.byteSize()) {
                 throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize());
             }
-            if (length > a.byteSize()) {
-                throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
-            }
-            return sqr7u(a, b, length);
         }
 
         private static int dot7u(MemorySegment a, MemorySegment b, int length) {

+ 21 - 4
libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java

@@ -28,6 +28,7 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
     static final byte MAX_INT7_VALUE = 127;
 
     static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
+    static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
 
     static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
 
@@ -35,8 +36,11 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
 
     static Arena arena;
 
+    final double delta;
+
     public JDKVectorLibraryTests(int size) {
         this.size = size;
+        this.delta = 1e-5 * size; // scale the delta with the size
     }
 
     @BeforeClass
@@ -103,11 +107,24 @@ public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
     public void testIllegalDims() {
         assumeTrue(notSupportedMsg(), supported());
         var segment = arena.allocate((long) size * 3);
-        var e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
-        assertThat(e.getMessage(), containsString("dimensions differ"));
 
-        e = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
-        assertThat(e.getMessage(), containsString("greater than vector dimensions"));
+        var e1 = expectThrows(IAE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
+        assertThat(e1.getMessage(), containsString("dimensions differ"));
+
+        var e2 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
+        assertThat(e2.getMessage(), containsString("out of bounds for length"));
+
+        var e3 = expectThrows(IOOBE, () -> dotProduct7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
+        assertThat(e3.getMessage(), containsString("out of bounds for length"));
+
+        var e4 = expectThrows(IAE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size));
+        assertThat(e4.getMessage(), containsString("dimensions differ"));
+
+        var e5 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1));
+        assertThat(e5.getMessage(), containsString("out of bounds for length"));
+
+        var e6 = expectThrows(IOOBE, () -> squareDistance7u(segment.asSlice(0L, size), segment.asSlice(size, size), -1));
+        assertThat(e6.getMessage(), containsString("out of bounds for length"));
     }
 
     int dotProduct7u(MemorySegment a, MemorySegment b, int length) {

+ 2 - 2
libs/simdvec/native/build.gradle

@@ -23,10 +23,10 @@ var os = org.gradle.internal.os.OperatingSystem.current()
 //  1. Temporarily comment out the download in libs/native/library/build.gradle
 //       libs "org.elasticsearch:vec:${vecVersion}@zip"
 //  2. Copy your locally built libvec binary, e.g.
-//       cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
+//       cp libs/simdvec/native/build/libs/vec/shared/aarch64/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
 //
 // Look at the disassembly:
-//  objdump --disassemble-symbols=_dot8s build/libs/vec/shared/libvec.dylib
+//  objdump --disassemble-symbols=_dot7u build/libs/vec/shared/aarch64/libvec.dylib
 // Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
 //
 // gcc -shared -fpic -o libvec.so -I src/vec/headers/ src/vec/c/vec.c -O3