Przeglądaj źródła

Replace use of reflection with MemorySegmentAccessInput (#109061)

This commit replaces the use of reflection with the newly added MemorySegmentAccessInput.
Chris Hegarty 1 rok temu
rodzic
commit
6b62c5129d

+ 0 - 1
distribution/tools/server-cli/src/main/java/org/elasticsearch/server/cli/SystemJvmOptions.java

@@ -74,7 +74,6 @@ final class SystemJvmOptions {
                  * while we explore alternatives. See org.elasticsearch.xpack.searchablesnapshots.preallocate.Preallocate.
                  */
                 "--add-opens=java.base/java.io=org.elasticsearch.preallocate",
-                "--add-opens=org.apache.lucene.core/org.apache.lucene.store=org.elasticsearch.vec",
                 maybeEnableNativeAccess(),
                 maybeOverrideDockerCgroup(distroType),
                 maybeSetActiveProcessorCount(nodeSettings),

+ 9 - 7
libs/vec/src/main21/java/org/elasticsearch/vec/VectorScorerFactoryImpl.java

@@ -8,11 +8,12 @@
 
 package org.elasticsearch.vec;
 
+import org.apache.lucene.store.FilterIndexInput;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
 import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
 import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
 import org.elasticsearch.nativeaccess.NativeAccess;
-import org.elasticsearch.vec.internal.IndexInputUtils;
 import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.DotProductSupplier;
 import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier;
 import org.elasticsearch.vec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier;
@@ -36,15 +37,16 @@ class VectorScorerFactoryImpl implements VectorScorerFactory {
         RandomAccessQuantizedByteVectorValues values,
         float scoreCorrectionConstant
     ) {
-        input = IndexInputUtils.unwrapAndCheckInputOrNull(input);
-        if (input == null) {
-            return Optional.empty(); // the input type is not MemorySegment based
+        input = FilterIndexInput.unwrapOnlyTest(input);
+        if (input instanceof MemorySegmentAccessInput == false) {
+            return Optional.empty();
         }
+        MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input;
         checkInvariants(values.size(), values.dimension(), input);
         return switch (similarityType) {
-            case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(input, values, scoreCorrectionConstant));
-            case EUCLIDEAN -> Optional.of(new EuclideanSupplier(input, values, scoreCorrectionConstant));
-            case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(input, values, scoreCorrectionConstant));
+            case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values, scoreCorrectionConstant));
+            case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, scoreCorrectionConstant));
+            case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values, scoreCorrectionConstant));
         };
     }
 

+ 0 - 90
libs/vec/src/main21/java/org/elasticsearch/vec/internal/IndexInputUtils.java

@@ -1,90 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-
-package org.elasticsearch.vec.internal;
-
-import org.apache.lucene.store.FilterIndexInput;
-import org.apache.lucene.store.IndexInput;
-
-import java.lang.foreign.MemorySegment;
-import java.lang.invoke.MethodHandles;
-import java.lang.invoke.VarHandle;
-import java.security.AccessController;
-import java.security.PrivilegedAction;
-import java.security.PrivilegedActionException;
-import java.security.PrivilegedExceptionAction;
-
-public final class IndexInputUtils {
-
-    static final Class<?> MSINDEX_CLS, MS_MSINDEX_CLS;
-    static final VarHandle SEGMENTS_ARRAY, CHUNK_SIZE_POWER, CHUNK_SIZE_MASK, MULTI_OFFSET;
-
-    static {
-        try {
-            MSINDEX_CLS = Class.forName("org.apache.lucene.store.MemorySegmentIndexInput");
-            MS_MSINDEX_CLS = Class.forName("org.apache.lucene.store.MemorySegmentIndexInput$MultiSegmentImpl");
-            var lookup = privilegedPrivateLookupIn(MSINDEX_CLS, MethodHandles.lookup());
-            SEGMENTS_ARRAY = privilegedFindVarHandle(lookup, MSINDEX_CLS, "segments", MemorySegment[].class);
-            CHUNK_SIZE_POWER = privilegedFindVarHandle(lookup, MSINDEX_CLS, "chunkSizePower", int.class);
-            CHUNK_SIZE_MASK = privilegedFindVarHandle(lookup, MSINDEX_CLS, "chunkSizeMask", long.class);
-            MULTI_OFFSET = privilegedFindVarHandle(lookup, MS_MSINDEX_CLS, "offset", long.class);
-        } catch (ClassNotFoundException e) {
-            throw new AssertionError(e);
-        } catch (IllegalAccessException e) {
-            throw new AssertionError("should not happen, check opens", e);
-        } catch (PrivilegedActionException e) {
-            throw new AssertionError("should not happen", e);
-        }
-    }
-
-    @SuppressWarnings("removal")
-    static VarHandle privilegedFindVarHandle(MethodHandles.Lookup lookup, Class<?> cls, String name, Class<?> type)
-        throws PrivilegedActionException {
-        PrivilegedExceptionAction<VarHandle> pa = () -> lookup.findVarHandle(cls, name, type);
-        return AccessController.doPrivileged(pa);
-    }
-
-    private IndexInputUtils() {}
-
-    /** Unwraps and returns the input if it's a MemorySegment backed input. Otherwise, null. */
-    public static IndexInput unwrapAndCheckInputOrNull(IndexInput input) {
-        input = FilterIndexInput.unwrap(input);
-        if (MSINDEX_CLS.isAssignableFrom(input.getClass())) {
-            return input;
-        }
-        return null;
-    }
-
-    static MemorySegment[] segmentArray(IndexInput input) {
-        return (MemorySegment[]) SEGMENTS_ARRAY.get(input);
-    }
-
-    static long chunkSizeMask(IndexInput input) {
-        return (long) CHUNK_SIZE_MASK.get(input);
-    }
-
-    static int chunkSizePower(IndexInput input) {
-        return (int) CHUNK_SIZE_POWER.get(input);
-    }
-
-    static long offset(IndexInput input) {
-        return (long) MULTI_OFFSET.get(input);
-    }
-
-    @SuppressWarnings("removal")
-    static MethodHandles.Lookup privilegedPrivateLookupIn(Class<?> cls, MethodHandles.Lookup lookup) throws IllegalAccessException {
-        PrivilegedAction<MethodHandles.Lookup> pa = () -> {
-            try {
-                return MethodHandles.privateLookupIn(cls, lookup);
-            } catch (IllegalAccessException e) {
-                throw new AssertionError("should not happen, check opens", e);
-            }
-        };
-        return AccessController.doPrivileged(pa);
-    }
-}

+ 26 - 54
libs/vec/src/main21/java/org/elasticsearch/vec/internal/Int7SQVectorScorerSupplier.java

@@ -8,7 +8,7 @@
 
 package org.elasticsearch.vec.internal;
 
-import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
 import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
 import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
@@ -29,18 +29,12 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
     final int dims;
     final int maxOrd;
     final float scoreCorrectionConstant;
-    final IndexInput input;
+    final MemorySegmentAccessInput input;
     final RandomAccessQuantizedByteVectorValues values; // to support ordToDoc/getAcceptOrds
     final ScalarQuantizedVectorSimilarity fallbackScorer;
 
-    final MemorySegment segment;
-    final MemorySegment[] segments;
-    final long offset;
-    final int chunkSizePower;
-    final long chunkSizeMask;
-
     protected Int7SQVectorScorerSupplier(
-        IndexInput input,
+        MemorySegmentAccessInput input,
         RandomAccessQuantizedByteVectorValues values,
         float scoreCorrectionConstant,
         ScalarQuantizedVectorSimilarity fallbackScorer
@@ -51,17 +45,6 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
         this.maxOrd = values.size();
         this.scoreCorrectionConstant = scoreCorrectionConstant;
         this.fallbackScorer = fallbackScorer;
-
-        this.segments = IndexInputUtils.segmentArray(input);
-        if (segments.length == 1) {
-            segment = segments[0];
-            offset = 0L;
-        } else {
-            segment = null;
-            offset = IndexInputUtils.offset(input);
-        }
-        this.chunkSizePower = IndexInputUtils.chunkSizePower(input);
-        this.chunkSizeMask = IndexInputUtils.chunkSizeMask(input);
     }
 
     protected final void checkOrdinal(int ord) {
@@ -78,19 +61,17 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
         long firstByteOffset = (long) firstOrd * (length + Float.BYTES);
         long secondByteOffset = (long) secondOrd * (length + Float.BYTES);
 
-        MemorySegment firstSeg = segmentSlice(firstByteOffset, length);
+        MemorySegment firstSeg = input.segmentSliceOrNull(firstByteOffset, length);
         if (firstSeg == null) {
             return fallbackScore(firstByteOffset, secondByteOffset);
         }
-        input.seek(firstByteOffset + length);
-        float firstOffset = Float.intBitsToFloat(input.readInt());
+        float firstOffset = Float.intBitsToFloat(input.readInt(firstByteOffset + length));
 
-        MemorySegment secondSeg = segmentSlice(secondByteOffset, length);
+        MemorySegment secondSeg = input.segmentSliceOrNull(secondByteOffset, length);
         if (secondSeg == null) {
             return fallbackScore(firstByteOffset, secondByteOffset);
         }
-        input.seek(secondByteOffset + length);
-        float secondOffset = Float.intBitsToFloat(input.readInt());
+        float secondOffset = Float.intBitsToFloat(input.readInt(secondByteOffset + length));
 
         return scoreFromSegments(firstSeg, firstOffset, secondSeg, secondOffset);
     }
@@ -98,15 +79,13 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
     abstract float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float bOffset);
 
     protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException {
-        input.seek(firstByteOffset);
         byte[] a = new byte[dims];
-        input.readBytes(a, 0, a.length);
-        float aOffsetValue = Float.intBitsToFloat(input.readInt());
+        input.readBytes(firstByteOffset, a, 0, a.length);
+        float aOffsetValue = Float.intBitsToFloat(input.readInt(firstByteOffset + dims));
 
-        input.seek(secondByteOffset);
         byte[] b = new byte[dims];
-        input.readBytes(b, 0, a.length);
-        float bOffsetValue = Float.intBitsToFloat(input.readInt());
+        input.readBytes(secondByteOffset, b, 0, a.length);
+        float bOffsetValue = Float.intBitsToFloat(input.readInt(secondByteOffset + dims));
 
         return fallbackScorer.score(a, aOffsetValue, b, bOffsetValue);
     }
@@ -122,28 +101,13 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
         };
     }
 
-    protected final MemorySegment segmentSlice(long pos, int length) {
-        if (segment != null) {
-            // single
-            if (checkIndex(pos, segment.byteSize() + 1)) {
-                return segment.asSlice(pos, length);
-            }
-        } else {
-            // multi
-            pos = pos + this.offset;
-            final int si = (int) (pos >> chunkSizePower);
-            final MemorySegment seg = segments[si];
-            long offset = pos & chunkSizeMask;
-            if (checkIndex(offset + length, seg.byteSize() + 1)) {
-                return seg.asSlice(offset, length);
-            }
-        }
-        return null;
-    }
-
     public static final class EuclideanSupplier extends Int7SQVectorScorerSupplier {
 
-        public EuclideanSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
+        public EuclideanSupplier(
+            MemorySegmentAccessInput input,
+            RandomAccessQuantizedByteVectorValues values,
+            float scoreCorrectionConstant
+        ) {
             super(input, values, scoreCorrectionConstant, fromVectorSimilarity(EUCLIDEAN, scoreCorrectionConstant, BITS));
         }
 
@@ -162,7 +126,11 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
 
     public static final class DotProductSupplier extends Int7SQVectorScorerSupplier {
 
-        public DotProductSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
+        public DotProductSupplier(
+            MemorySegmentAccessInput input,
+            RandomAccessQuantizedByteVectorValues values,
+            float scoreCorrectionConstant
+        ) {
             super(input, values, scoreCorrectionConstant, fromVectorSimilarity(DOT_PRODUCT, scoreCorrectionConstant, BITS));
         }
 
@@ -182,7 +150,11 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS
 
     public static final class MaxInnerProductSupplier extends Int7SQVectorScorerSupplier {
 
-        public MaxInnerProductSupplier(IndexInput input, RandomAccessQuantizedByteVectorValues values, float scoreCorrectionConstant) {
+        public MaxInnerProductSupplier(
+            MemorySegmentAccessInput input,
+            RandomAccessQuantizedByteVectorValues values,
+            float scoreCorrectionConstant
+        ) {
             super(input, values, scoreCorrectionConstant, fromVectorSimilarity(MAXIMUM_INNER_PRODUCT, scoreCorrectionConstant, BITS));
         }
 

+ 0 - 150
libs/vec/src/test21/java/org/elasticsearch/vec/internal/IndexInputUtilsTests.java

@@ -1,150 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0 and the Server Side Public License, v 1; you may not use this file except
- * in compliance with, at your election, the Elastic License 2.0 or the Server
- * Side Public License, v 1.
- */
-
-package org.elasticsearch.vec.internal;
-
-import org.apache.lucene.store.Directory;
-import org.apache.lucene.store.IOContext;
-import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.store.MMapDirectory;
-import org.elasticsearch.test.ESTestCase;
-
-import java.io.IOException;
-import java.lang.foreign.MemorySegment;
-import java.lang.foreign.ValueLayout;
-import java.util.Arrays;
-import java.util.stream.IntStream;
-
-import static org.hamcrest.core.IsEqual.equalTo;
-
-public class IndexInputUtilsTests extends ESTestCase {
-
-    public void testSingleSegment() throws IOException {
-        try (Directory dir = new MMapDirectory(createTempDir(getTestName()))) {
-            for (int times = 0; times < TIMES; times++) {
-                String fileName = "testSingleSegment" + times;
-                int size = randomIntBetween(10, 127);
-                try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
-                    byte[] ba = new byte[size];
-                    IntStream.range(0, size).forEach(i -> ba[i] = (byte) i);
-                    out.writeBytes(ba, 0, ba.length);
-                }
-                try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
-                    var input = IndexInputUtils.unwrapAndCheckInputOrNull(in);
-                    assertNotNull(input);
-                    {
-                        var segArray = IndexInputUtils.segmentArray(input);
-                        assertThat(segArray.length, equalTo(1));
-                        assertThat(segArray[0].byteSize(), equalTo((long) size));
-
-                        // Out of Bounds - cannot retrieve the segment
-                        assertNull(segmentSlice(input, 0, size + 1));
-                        assertNull(segmentSlice(input, size - 1, 2));
-
-                        var fullSeg = segmentSlice(input, 0, size);
-                        assertNotNull(fullSeg);
-                        for (int i = 0; i < size; i++) {
-                            assertThat(fullSeg.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) i));
-                        }
-
-                        var partialSeg = segmentSlice(input, 1, size - 1);
-                        assertNotNull(partialSeg);
-                        for (int i = 0; i < size - 2; i++) {
-                            assertThat(partialSeg.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) (i + 1)));
-                        }
-                    }
-                    // IndexInput::slice
-                    {
-                        var slice = input.slice("partial slice", 1, size - 2);
-                        var sliceSgArray = IndexInputUtils.segmentArray(slice);
-                        assertThat(sliceSgArray.length, equalTo(1));
-                        assertThat(sliceSgArray[0].byteSize(), equalTo((long) size - 2));
-
-                        var fullSeg = segmentSlice(slice, 0, size - 2);
-                        assertNotNull(fullSeg);
-                        for (int i = 0; i < size - 2; i++) {
-                            assertThat(fullSeg.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) (i + 1)));
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    public void testMultiSegment() throws IOException {
-        try (Directory dir = new MMapDirectory(createTempDir(getTestName()), 32L)) {
-            for (int times = 0; times < TIMES; times++) {
-                String fileName = "testMultiSegment" + times;
-                int size = randomIntBetween(65, 1511);
-                int expectedNumSegs = size / 32 + 1;
-                try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
-                    byte[] ba = new byte[size];
-                    IntStream.range(0, size).forEach(i -> ba[i] = (byte) i);
-                    out.writeBytes(ba, 0, ba.length);
-                }
-                try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
-                    var input = IndexInputUtils.unwrapAndCheckInputOrNull(in);
-                    assertNotNull(input);
-
-                    var fullSegArray = IndexInputUtils.segmentArray(input);
-                    assertThat(fullSegArray.length, equalTo(expectedNumSegs));
-                    assertThat(Arrays.stream(fullSegArray).mapToLong(MemorySegment::byteSize).sum(), equalTo((long) size));
-                    assertThat(IndexInputUtils.offset(input), equalTo(0L));
-
-                    var partialSlice = input.slice("partial slice", 1, size - 1);
-                    assertThat(IndexInputUtils.offset(partialSlice), equalTo(1L));
-                    var msseg1 = segmentSlice(partialSlice, 0, 24);
-                    for (int i = 0; i < 24; i++) {
-                        assertThat(msseg1.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) (i + 1)));
-                    }
-
-                    var fullMSSlice = input.slice("start at full MemorySegment slice", 32, size - 32);
-                    var segArray2 = IndexInputUtils.segmentArray(fullMSSlice);
-                    assertThat(Arrays.stream(segArray2).mapToLong(MemorySegment::byteSize).sum(), equalTo((long) size - 32));
-                    assertThat(IndexInputUtils.offset(fullMSSlice), equalTo(0L));
-                    var msseg2 = segmentSlice(fullMSSlice, 0, 32);
-                    for (int i = 0; i < 32; i++) {
-                        assertThat(msseg2.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) (i + 32)));
-                    }
-
-                    // slice of a slice
-                    var sliceSlice = partialSlice.slice("slice of a slice", 1, partialSlice.length() - 1);
-                    var segSliceSliceArray = IndexInputUtils.segmentArray(sliceSlice);
-                    assertThat(Arrays.stream(segSliceSliceArray).mapToLong(MemorySegment::byteSize).sum(), equalTo((long) size));
-                    assertThat(IndexInputUtils.offset(sliceSlice), equalTo(2L));
-                    var msseg3 = segmentSlice(sliceSlice, 0, 28);
-                    for (int i = 0; i < 28; i++) {
-                        assertThat(msseg3.get(ValueLayout.JAVA_BYTE, i), equalTo((byte) (i + 2)));
-                    }
-
-                }
-            }
-        }
-    }
-
-    static MemorySegment segmentSlice(IndexInput input, long pos, int length) {
-        if (IndexInputUtils.MS_MSINDEX_CLS.isAssignableFrom(input.getClass())) {
-            pos += IndexInputUtils.offset(input);
-        }
-        final int si = (int) (pos >> IndexInputUtils.chunkSizePower(input));
-        final MemorySegment seg = IndexInputUtils.segmentArray(input)[si];
-        long offset = pos & IndexInputUtils.chunkSizeMask(input);
-        if (checkIndex(offset + length, seg.byteSize() + 1)) {
-            return seg.asSlice(offset, length);
-        }
-        return null;
-    }
-
-    static boolean checkIndex(long index, long length) {
-        return index >= 0 && index < length;
-    }
-
-    static final int TIMES = 100; // a loop iteration times
-
-}