Browse Source

Fix NPE in flat_bbq scorer when all vectors are missing (#129548)

It is possible to get all the way down to the knn format reader and
there be no vectors in the index. 

This execution path is possible if utilizing nested queries (which
bypasses the higher level checks in
`KnnFloatVectorQuery#approximateSearch`).

bbq_flat should check for the existence of vectors before attempting to
create the scorer.
Benjamin Trent 4 months ago
parent
commit
80667d0c8a

+ 5 - 0
docs/changelog/129548.yaml

@@ -0,0 +1,5 @@
+pr: 129548
+summary: Fix NPE in `flat_bbq` scorer when all vectors are missing
+area: Vector Search
+type: bug
+issues: []

+ 3 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java

@@ -59,6 +59,9 @@ class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
         float[] target
     ) throws IOException {
         if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+            assert binarizedVectors.getQuantizer() != null
+                : "BinarizedByteVectorValues must have a quantizer for ES816BinaryFlatVectorsScorer";
+            assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer";
             BinaryQuantizer quantizer = binarizedVectors.getQuantizer();
             float[] centroid = binarizedVectors.getCentroid();
             // FIXME: precompute this once?

+ 1 - 1
server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsReader.java

@@ -160,7 +160,7 @@ public class ES816BinaryQuantizedVectorsReader extends FlatVectorsReader impleme
     @Override
     public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
         FieldEntry fi = fields.get(field);
-        if (fi == null) {
+        if (fi == null || fi.size() == 0) {
             return null;
         }
         return vectorScorer.getRandomVectorScorer(

+ 3 - 0
server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java

@@ -66,6 +66,9 @@ public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer {
         float[] target
     ) throws IOException {
         if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+            assert binarizedVectors.getQuantizer() != null
+                : "BinarizedByteVectorValues must have a quantizer for ES816BinaryFlatVectorsScorer";
+            assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer";
             OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
             float[] centroid = binarizedVectors.getCentroid();
             // We make a copy as the quantization process mutates the input

+ 1 - 1
server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java

@@ -160,7 +160,7 @@ public class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader impleme
     @Override
     public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
         FieldEntry fi = fields.get(field);
-        if (fi == null) {
+        if (fi == null || fi.size() == 0) {
             return null;
         }
         return vectorScorer.getRandomVectorScorer(

+ 65 - 0
server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java

@@ -25,6 +25,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
 import org.apache.lucene.document.KnnFloatVectorField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
@@ -34,12 +35,21 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.KnnVectorValues;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.search.join.CheckJoinIndex;
+import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
+import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.tests.util.TestUtil;
@@ -48,6 +58,9 @@ import org.elasticsearch.index.codec.vectors.BQVectorUtils;
 import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Locale;
 
 import static java.lang.String.format;
@@ -70,6 +83,58 @@ public class ES816BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
         return codec;
     }
 
+    static String encodeInts(int[] i) {
+        return Arrays.toString(i);
+    }
+
+    static BitSetProducer parentFilter(IndexReader r) throws IOException {
+        // Create a filter that defines "parent" documents in the index
+        BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+        CheckJoinIndex.check(r, parentsFilter);
+        return parentsFilter;
+    }
+
+    Document makeParent(int[] children) {
+        Document parent = new Document();
+        parent.add(newStringField("docType", "_parent", Field.Store.NO));
+        parent.add(newStringField("id", encodeInts(children), Field.Store.YES));
+        return parent;
+    }
+
+    public void testEmptyDiversifiedChildSearch() throws Exception {
+        String fieldName = "field";
+        int dims = random().nextInt(4, 65);
+        float[] vector = randomVector(dims);
+        VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+        try (Directory d = newDirectory()) {
+            IndexWriterConfig iwc = newIndexWriterConfig().setCodec(codec);
+            iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy()));
+            try (IndexWriter w = new IndexWriter(d, iwc)) {
+                List<Document> toAdd = new ArrayList<>();
+                for (int j = 1; j <= 5; j++) {
+                    Document doc = new Document();
+                    doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction));
+                    doc.add(newStringField("id", Integer.toString(j), Field.Store.YES));
+                    toAdd.add(doc);
+                }
+                toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 }));
+                w.addDocuments(toAdd);
+                w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 })));
+                w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 }))));
+                w.flush();
+                w.commit();
+                w.forceMerge(1);
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    IndexSearcher searcher = new IndexSearcher(reader);
+                    BitSetProducer parentFilter = parentFilter(searcher.getIndexReader());
+                    Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parentFilter);
+                    assertTrue(searcher.search(query, 1).scoreDocs.length == 0);
+                }
+            }
+
+        }
+    }
+
     public void testSearch() throws Exception {
         String fieldName = "field";
         int numVectors = random().nextInt(99, 500);

+ 65 - 0
server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java

@@ -25,6 +25,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
 import org.apache.lucene.document.KnnFloatVectorField;
 import org.apache.lucene.index.CodecReader;
 import org.apache.lucene.index.DirectoryReader;
@@ -34,13 +35,22 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.KnnVectorValues;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
+import org.apache.lucene.index.Term;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.misc.store.DirectIODirectory;
+import org.apache.lucene.search.FieldExistsQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.search.join.CheckJoinIndex;
+import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
+import org.apache.lucene.search.join.QueryBitSetProducer;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.store.IOContext;
@@ -64,6 +74,9 @@ import org.elasticsearch.test.IndexSettingsModule;
 import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Locale;
 import java.util.OptionalLong;
 
@@ -87,6 +100,58 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
         return codec;
     }
 
+    static String encodeInts(int[] i) {
+        return Arrays.toString(i);
+    }
+
+    static BitSetProducer parentFilter(IndexReader r) throws IOException {
+        // Create a filter that defines "parent" documents in the index
+        BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
+        CheckJoinIndex.check(r, parentsFilter);
+        return parentsFilter;
+    }
+
+    Document makeParent(int[] children) {
+        Document parent = new Document();
+        parent.add(newStringField("docType", "_parent", Field.Store.NO));
+        parent.add(newStringField("id", encodeInts(children), Field.Store.YES));
+        return parent;
+    }
+
+    public void testEmptyDiversifiedChildSearch() throws Exception {
+        String fieldName = "field";
+        int dims = random().nextInt(4, 65);
+        float[] vector = randomVector(dims);
+        VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+        try (Directory d = newDirectory()) {
+            IndexWriterConfig iwc = newIndexWriterConfig().setCodec(codec);
+            iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy()));
+            try (IndexWriter w = new IndexWriter(d, iwc)) {
+                List<Document> toAdd = new ArrayList<>();
+                for (int j = 1; j <= 5; j++) {
+                    Document doc = new Document();
+                    doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction));
+                    doc.add(newStringField("id", Integer.toString(j), Field.Store.YES));
+                    toAdd.add(doc);
+                }
+                toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 }));
+                w.addDocuments(toAdd);
+                w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 })));
+                w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 }))));
+                w.flush();
+                w.commit();
+                w.forceMerge(1);
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    IndexSearcher searcher = new IndexSearcher(reader);
+                    BitSetProducer parentFilter = parentFilter(searcher.getIndexReader());
+                    Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parentFilter);
+                    assertTrue(searcher.search(query, 1).scoreDocs.length == 0);
+                }
+            }
+
+        }
+    }
+
     public void testSearch() throws Exception {
         String fieldName = "field";
         int numVectors = random().nextInt(99, 500);