|
@@ -9,8 +9,8 @@
|
|
|
|
|
|
package org.elasticsearch.index.mapper.vectors;
|
|
|
|
|
|
-import org.apache.lucene.search.KnnByteVectorQuery;
|
|
|
import org.apache.lucene.search.KnnFloatVectorQuery;
|
|
|
+import org.apache.lucene.search.PatienceKnnVectorQuery;
|
|
|
import org.apache.lucene.search.Query;
|
|
|
import org.apache.lucene.search.join.BitSetProducer;
|
|
|
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
|
|
@@ -234,7 +234,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
producer,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
|
|
|
query = rescoreKnnVectorQuery.innerQuery();
|
|
@@ -242,7 +243,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
if (field.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
|
|
|
+ assertTrue(query instanceof DiversifyingChildrenFloatKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
}
|
|
|
{
|
|
@@ -272,12 +273,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
producer,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (field.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
|
|
|
+ assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
|
|
|
vectorData = new VectorData(floatQueryVector, null);
|
|
@@ -289,12 +291,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
producer,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (field.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DiversifyingParentBlockQuery.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
|
|
|
+ assertTrue(query instanceof DiversifyingChildrenByteKnnVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -366,7 +369,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
|
|
@@ -396,7 +400,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
|
|
@@ -422,7 +427,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
|
|
@@ -453,7 +459,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (query instanceof RescoreKnnVectorQuery rescoreKnnVectorQuery) {
|
|
|
query = rescoreKnnVectorQuery.innerQuery();
|
|
@@ -461,7 +468,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
if (fieldWith4096dims.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(KnnFloatVectorQuery.class));
|
|
|
+ assertTrue(query instanceof KnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -490,12 +497,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (fieldWith4096dims.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DenseVectorQuery.Bytes.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(KnnByteVectorQuery.class));
|
|
|
+ assertTrue(query instanceof ESKnnByteVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -522,7 +530,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
|
|
@@ -548,7 +557,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
|
|
@@ -563,7 +573,8 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
)
|
|
|
);
|
|
|
assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
|
|
@@ -591,24 +602,33 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
|
|
|
if (elementType == BYTE) {
|
|
|
if (nonQuantizedField.getIndexOptions().isFlat()) {
|
|
|
assertThat(knnQuery, instanceOf(DenseVectorQuery.Bytes.class));
|
|
|
} else {
|
|
|
- ESKnnByteVectorQuery esKnnQuery = (ESKnnByteVectorQuery) knnQuery;
|
|
|
- assertThat(esKnnQuery.getK(), is(100));
|
|
|
- assertThat(esKnnQuery.kParam(), is(10));
|
|
|
+ if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
|
|
|
+ assertThat(patienceKnnVectorQuery.getK(), is(100));
|
|
|
+ } else {
|
|
|
+ ESKnnByteVectorQuery knnByteVectorQuery = (ESKnnByteVectorQuery) knnQuery;
|
|
|
+ assertThat(knnByteVectorQuery.getK(), is(100));
|
|
|
+ assertThat(knnByteVectorQuery.kParam(), is(10));
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
if (nonQuantizedField.getIndexOptions().isFlat()) {
|
|
|
assertThat(knnQuery, instanceOf(DenseVectorQuery.Floats.class));
|
|
|
} else {
|
|
|
- ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) knnQuery;
|
|
|
- assertThat(esKnnQuery.getK(), is(100));
|
|
|
- assertThat(esKnnQuery.kParam(), is(10));
|
|
|
+ if (knnQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
|
|
|
+ assertThat(patienceKnnVectorQuery.getK(), is(100));
|
|
|
+ } else {
|
|
|
+ ESKnnFloatVectorQuery knnFloatVectorQuery = (ESKnnFloatVectorQuery) knnQuery;
|
|
|
+ assertThat(knnFloatVectorQuery.getK(), is(100));
|
|
|
+ assertThat(knnFloatVectorQuery.kParam(), is(10));
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -655,12 +675,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
if (fieldType.getIndexOptions().isFlat()) {
|
|
|
assertThat(query, instanceOf(DenseVectorQuery.Floats.class));
|
|
|
} else {
|
|
|
- assertThat(query, instanceOf(ESKnnFloatVectorQuery.class));
|
|
|
+ assertTrue(query instanceof ESKnnFloatVectorQuery || query instanceof PatienceKnnVectorQuery);
|
|
|
}
|
|
|
|
|
|
// verify we can override a `0` to a positive number
|
|
@@ -683,20 +704,23 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
assertTrue(query instanceof RescoreKnnVectorQuery);
|
|
|
- assertThat(((RescoreKnnVectorQuery) query).k(), equalTo(10));
|
|
|
- ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) ((RescoreKnnVectorQuery) query).innerQuery();
|
|
|
- assertThat(esKnnQuery.kParam(), equalTo(20));
|
|
|
-
|
|
|
+ RescoreKnnVectorQuery rescoreKnnVectorQuery = (RescoreKnnVectorQuery) query;
|
|
|
+ assertThat(rescoreKnnVectorQuery.k(), equalTo(10));
|
|
|
+ Query innerQuery = rescoreKnnVectorQuery.innerQuery();
|
|
|
+ if (innerQuery instanceof ESKnnFloatVectorQuery esKnnFloatVectorQuery) {
|
|
|
+ assertThat(esKnnFloatVectorQuery.kParam(), equalTo(20));
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
public void testFilterSearchThreshold() {
|
|
|
List<Tuple<DenseVectorFieldMapper.ElementType, Function<Query, KnnSearchStrategy>>> cases = List.of(
|
|
|
- Tuple.tuple(FLOAT, q -> ((ESKnnFloatVectorQuery) q).getStrategy()),
|
|
|
- Tuple.tuple(BYTE, q -> ((ESKnnByteVectorQuery) q).getStrategy()),
|
|
|
- Tuple.tuple(BIT, q -> ((ESKnnByteVectorQuery) q).getStrategy())
|
|
|
+ Tuple.tuple(FLOAT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnFloatVectorQuery) q).getStrategy()),
|
|
|
+ Tuple.tuple(BYTE, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy()),
|
|
|
+ Tuple.tuple(BIT, q -> q instanceof PatienceKnnVectorQuery ? null : ((ESKnnByteVectorQuery) q).getStrategy())
|
|
|
);
|
|
|
for (var tuple : cases) {
|
|
|
DenseVectorFieldType fieldType = new DenseVectorFieldType(
|
|
@@ -720,25 +744,31 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- DenseVectorFieldMapper.FilterHeuristic.FANOUT
|
|
|
+ DenseVectorFieldMapper.FilterHeuristic.FANOUT,
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
KnnSearchStrategy strategy = tuple.v2().apply(query);
|
|
|
- assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
|
|
|
- assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0));
|
|
|
-
|
|
|
- query = fieldType.createKnnQuery(
|
|
|
- VectorData.fromFloats(new float[] { 1, 4, 10 }),
|
|
|
- 10,
|
|
|
- 100,
|
|
|
- 0f,
|
|
|
- null,
|
|
|
- null,
|
|
|
- null,
|
|
|
- DenseVectorFieldMapper.FilterHeuristic.ACORN
|
|
|
- );
|
|
|
- strategy = tuple.v2().apply(query);
|
|
|
- assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
|
|
|
- assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60));
|
|
|
+ if (strategy != null) {
|
|
|
+ assertTrue(strategy instanceof KnnSearchStrategy.Hnsw);
|
|
|
+ assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(0));
|
|
|
+
|
|
|
+ query = fieldType.createKnnQuery(
|
|
|
+ VectorData.fromFloats(new float[] { 1, 4, 10 }),
|
|
|
+ 10,
|
|
|
+ 100,
|
|
|
+ 0f,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ DenseVectorFieldMapper.FilterHeuristic.ACORN,
|
|
|
+ randomBoolean()
|
|
|
+ );
|
|
|
+ strategy = tuple.v2().apply(query);
|
|
|
+ if (strategy != null) {
|
|
|
+ assertThat(strategy, instanceOf(KnnSearchStrategy.Hnsw.class));
|
|
|
+ assertThat(((KnnSearchStrategy.Hnsw) strategy).filteredSearchThreshold(), equalTo(60));
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -759,12 +789,18 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
|
|
|
null,
|
|
|
null,
|
|
|
null,
|
|
|
- randomFrom(DenseVectorFieldMapper.FilterHeuristic.values())
|
|
|
+ randomFrom(DenseVectorFieldMapper.FilterHeuristic.values()),
|
|
|
+ randomBoolean()
|
|
|
);
|
|
|
RescoreKnnVectorQuery rescoreQuery = (RescoreKnnVectorQuery) query;
|
|
|
- ESKnnFloatVectorQuery esKnnQuery = (ESKnnFloatVectorQuery) rescoreQuery.innerQuery();
|
|
|
- assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
|
|
|
- assertThat("Unexpected k parameter", esKnnQuery.kParam(), equalTo(expectedK));
|
|
|
- assertThat("Unexpected candidates", esKnnQuery.getK(), equalTo(expectedCandidates));
|
|
|
+ Query innerQuery = rescoreQuery.innerQuery();
|
|
|
+ if (innerQuery instanceof PatienceKnnVectorQuery patienceKnnVectorQuery) {
|
|
|
+ assertThat("Unexpected candidates", patienceKnnVectorQuery.getK(), equalTo(expectedCandidates));
|
|
|
+ } else {
|
|
|
+ ESKnnFloatVectorQuery knnQuery = (ESKnnFloatVectorQuery) innerQuery;
|
|
|
+ assertThat("Unexpected total results", rescoreQuery.k(), equalTo(expectedResults));
|
|
|
+ assertThat("Unexpected candidates", knnQuery.getK(), equalTo(expectedCandidates));
|
|
|
+ assertThat("Unexpected k parameter", knnQuery.kParam(), equalTo(expectedK));
|
|
|
+ }
|
|
|
}
|
|
|
}
|