|
@@ -25,6 +25,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
|
|
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
|
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
|
|
import org.apache.lucene.document.Document;
|
|
|
+import org.apache.lucene.document.Field;
|
|
|
import org.apache.lucene.document.KnnFloatVectorField;
|
|
|
import org.apache.lucene.index.CodecReader;
|
|
|
import org.apache.lucene.index.DirectoryReader;
|
|
@@ -34,13 +35,22 @@ import org.apache.lucene.index.IndexWriter;
|
|
|
import org.apache.lucene.index.IndexWriterConfig;
|
|
|
import org.apache.lucene.index.KnnVectorValues;
|
|
|
import org.apache.lucene.index.LeafReader;
|
|
|
+import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
|
|
|
+import org.apache.lucene.index.Term;
|
|
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
|
|
import org.apache.lucene.misc.store.DirectIODirectory;
|
|
|
+import org.apache.lucene.search.FieldExistsQuery;
|
|
|
import org.apache.lucene.search.IndexSearcher;
|
|
|
import org.apache.lucene.search.KnnFloatVectorQuery;
|
|
|
+import org.apache.lucene.search.MatchAllDocsQuery;
|
|
|
import org.apache.lucene.search.Query;
|
|
|
+import org.apache.lucene.search.TermQuery;
|
|
|
import org.apache.lucene.search.TopDocs;
|
|
|
import org.apache.lucene.search.TotalHits;
|
|
|
+import org.apache.lucene.search.join.BitSetProducer;
|
|
|
+import org.apache.lucene.search.join.CheckJoinIndex;
|
|
|
+import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
|
|
|
+import org.apache.lucene.search.join.QueryBitSetProducer;
|
|
|
import org.apache.lucene.store.Directory;
|
|
|
import org.apache.lucene.store.FSDirectory;
|
|
|
import org.apache.lucene.store.IOContext;
|
|
@@ -64,6 +74,9 @@ import org.elasticsearch.test.IndexSettingsModule;
|
|
|
import java.io.IOException;
|
|
|
import java.nio.file.Files;
|
|
|
import java.nio.file.Path;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
+import java.util.List;
|
|
|
import java.util.Locale;
|
|
|
import java.util.OptionalLong;
|
|
|
|
|
@@ -87,6 +100,58 @@ public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormat
|
|
|
return codec;
|
|
|
}
|
|
|
|
|
|
+ static String encodeInts(int[] i) {
|
|
|
+ return Arrays.toString(i);
|
|
|
+ }
|
|
|
+
|
|
|
+ static BitSetProducer parentFilter(IndexReader r) throws IOException {
|
|
|
+ // Create a filter that defines "parent" documents in the index
|
|
|
+ BitSetProducer parentsFilter = new QueryBitSetProducer(new TermQuery(new Term("docType", "_parent")));
|
|
|
+ CheckJoinIndex.check(r, parentsFilter);
|
|
|
+ return parentsFilter;
|
|
|
+ }
|
|
|
+
|
|
|
+ Document makeParent(int[] children) {
|
|
|
+ Document parent = new Document();
|
|
|
+ parent.add(newStringField("docType", "_parent", Field.Store.NO));
|
|
|
+ parent.add(newStringField("id", encodeInts(children), Field.Store.YES));
|
|
|
+ return parent;
|
|
|
+ }
|
|
|
+
|
|
|
+    /**
+     * Checks that a {@link DiversifyingChildrenFloatKnnVectorQuery} returns no hits when every
+     * document carrying the vector field has been deleted.
+     *
+     * <p>The test indexes one block of five vector-bearing documents followed by a parent, then a
+     * second block containing only a parent. It deletes everything that has the vector field plus
+     * the first parent, flushes, commits and force-merges to a single segment, and finally runs
+     * the query against the remaining index.
+     */
+    public void testEmptyDiversifiedChildSearch() throws Exception {
+        String fieldName = "field";
+        int dims = random().nextInt(4, 65);
+        float[] vector = randomVector(dims);
+        VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
+        try (Directory d = newDirectory()) {
+            IndexWriterConfig iwc = newIndexWriterConfig().setCodec(codec);
+            // Wrap the configured merge policy in a soft-deletes retention policy that matches all docs.
+            // NOTE(review): presumably this keeps the vector field present in the merged segment - confirm.
+            iwc.setMergePolicy(new SoftDeletesRetentionMergePolicy("soft_delete", MatchAllDocsQuery::new, iwc.getMergePolicy()));
+            try (IndexWriter w = new IndexWriter(d, iwc)) {
+                // First block: five documents sharing the same vector, followed by their parent.
+                List<Document> toAdd = new ArrayList<>();
+                for (int j = 1; j <= 5; j++) {
+                    Document doc = new Document();
+                    doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction));
+                    doc.add(newStringField("id", Integer.toString(j), Field.Store.YES));
+                    toAdd.add(doc);
+                }
+                toAdd.add(makeParent(new int[] { 1, 2, 3, 4, 5 }));
+                w.addDocuments(toAdd);
+                // Second block: a parent with no vector-bearing documents in front of it.
+                w.addDocuments(List.of(makeParent(new int[] { 6, 7, 8, 9, 10 })));
+                // Delete every document that has the vector field, plus the first block's parent.
+                w.deleteDocuments(new FieldExistsQuery(fieldName), new TermQuery(new Term("id", encodeInts(new int[] { 1, 2, 3, 4, 5 }))));
+                w.flush();
+                w.commit();
+                w.forceMerge(1);
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    IndexSearcher searcher = new IndexSearcher(reader);
+                    // Named differently from the static parentFilter(...) helper to avoid shadowing it.
+                    BitSetProducer parents = parentFilter(searcher.getIndexReader());
+                    Query query = new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, null, 1, parents);
+                    TopDocs hits = searcher.search(query, 1);
+                    // assertEquals reports the actual hit count on failure, unlike assertTrue(x == 0).
+                    assertEquals(0, hits.scoreDocs.length);
+                }
+            }
+        }
+    }
|
|
|
+
|
|
|
public void testSearch() throws Exception {
|
|
|
String fieldName = "field";
|
|
|
int numVectors = random().nextInt(99, 500);
|