فهرست منبع

Fixing lucene snapshot compilation due to new VectorValue interface (#108516)

* Fixing lucene snapshot compilation due to new VectorValue interface

* fixing test compilation
Benjamin Trent 1 سال پیش
والد
کامیت
92c48e5270

+ 23 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/ES814ScalarQuantizedVectorsWriter.java

@@ -39,6 +39,7 @@ import org.apache.lucene.index.Sorter;
 import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
@@ -457,6 +458,8 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
                         docsWithField.cardinality(),
                         mergedQuantizationState, // TODO: bits
                         false, // TODO compress
+                        fieldInfo.getVectorSimilarityFunction(),
+                        vectorsScorer,
                         quantizationDataInput
                     )
                 );
@@ -649,6 +652,11 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
             curDoc = target;
             return docID();
         }
+
+        @Override
+        public VectorScorer scorer(float[] floats) throws IOException {
+            throw new UnsupportedOperationException();
+        }
     }
 
     private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
@@ -770,6 +778,11 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         public float getScoreCorrectionConstant() throws IOException {
             return current.values.getScoreCorrectionConstant();
         }
+
+        @Override
+        public VectorScorer vectorScorer(float[] floats) throws IOException {
+            throw new UnsupportedOperationException();
+        }
     }
 
     private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
@@ -836,6 +849,11 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
             return doc;
         }
 
+        @Override
+        public VectorScorer vectorScorer(float[] floats) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
         private void quantize() throws IOException {
             if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
                 System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
@@ -932,5 +950,10 @@ public final class ES814ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
         public int advance(int target) throws IOException {
             return in.advance(target);
         }
+
+        @Override
+        public VectorScorer vectorScorer(float[] floats) throws IOException {
+            throw new UnsupportedOperationException();
+        }
     }
 }

+ 6 - 0
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenormalizedCosineFloatVectorValues.java

@@ -10,6 +10,7 @@ package org.elasticsearch.index.mapper.vectors;
 
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.search.VectorScorer;
 
 import java.io.IOException;
 
@@ -63,6 +64,11 @@ public class DenormalizedCosineFloatVectorValues extends FloatVectorValues {
         return in.advance(target);
     }
 
+    @Override
+    public VectorScorer scorer(float[] floats) throws IOException {
+        return in.scorer(floats);
+    }
+
     public float magnitude() {
         return magnitude;
     }

+ 83 - 0
server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

@@ -23,6 +23,7 @@ import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.VectorScorer;
 import org.apache.lucene.search.suggest.document.CompletionTerms;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
@@ -481,6 +482,27 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
             return in.vectorValue();
         }
 
+        @Override
+        public VectorScorer scorer(byte[] bytes) throws IOException {
+            VectorScorer scorer = in.scorer(bytes);
+            if (scorer == null) {
+                return null;
+            }
+            return new VectorScorer() {
+                private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation);
+
+                @Override
+                public float score() throws IOException {
+                    return scorer.score();
+                }
+
+                @Override
+                public DocIdSetIterator iterator() {
+                    return iterator;
+                }
+            };
+        }
+
         @Override
         public int docID() {
             return in.docID();
@@ -531,11 +553,72 @@ class ExitableDirectoryReader extends FilterDirectoryReader {
             return nextDoc;
         }
 
+        @Override
+        public VectorScorer scorer(float[] target) throws IOException {
+            VectorScorer scorer = in.scorer(target);
+            if (scorer == null) {
+                return null;
+            }
+            return new VectorScorer() {
+                private final DocIdSetIterator iterator = new ExitableDocSetIterator(scorer.iterator(), queryCancellation);
+
+                @Override
+                public float score() throws IOException {
+                    return scorer.score();
+                }
+
+                @Override
+                public DocIdSetIterator iterator() {
+                    return iterator;
+                }
+            };
+        }
+
         private void checkAndThrowWithSampling() {
             if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
                 this.queryCancellation.checkCancelled();
             }
         }
+    }
 
+    private static class ExitableDocSetIterator extends DocIdSetIterator {
+        private int calls;
+        private final DocIdSetIterator in;
+        private final QueryCancellation queryCancellation;
+
+        private ExitableDocSetIterator(DocIdSetIterator in, QueryCancellation queryCancellation) {
+            this.in = in;
+            this.queryCancellation = queryCancellation;
+        }
+
+        @Override
+        public int docID() {
+            return in.docID();
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            final int advance = in.advance(target);
+            checkAndThrowWithSampling();
+            return advance;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            final int nextDoc = in.nextDoc();
+            checkAndThrowWithSampling();
+            return nextDoc;
+        }
+
+        @Override
+        public long cost() {
+            return in.cost();
+        }
+
+        private void checkAndThrowWithSampling() {
+            if ((calls++ & ExitableIntersectVisitor.MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
+                this.queryCancellation.checkCancelled();
+            }
+        }
     }
 }

+ 11 - 0
server/src/test/java/org/elasticsearch/index/mapper/vectors/KnnDenseVectorScriptDocValuesTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.index.mapper.vectors;
 
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.search.VectorScorer;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.DenseVector;
@@ -230,6 +231,11 @@ public class KnnDenseVectorScriptDocValuesTests extends ESTestCase {
                 }
                 return index = target;
             }
+
+            @Override
+            public VectorScorer scorer(byte[] floats) throws IOException {
+                throw new UnsupportedOperationException();
+            }
         };
     }
 
@@ -270,6 +276,11 @@ public class KnnDenseVectorScriptDocValuesTests extends ESTestCase {
                 }
                 return index = target;
             }
+
+            @Override
+            public VectorScorer scorer(float[] floats) throws IOException {
+                throw new UnsupportedOperationException();
+            }
         };
     }
 }