|
@@ -19,7 +19,6 @@ import org.elasticsearch.common.compress.CompressedXContent;
|
|
|
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
|
|
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
-import org.elasticsearch.core.Tuple;
|
|
|
import org.elasticsearch.index.mapper.MapperService;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
|
|
@@ -38,6 +37,7 @@ import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
+import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.instanceOf;
|
|
|
|
|
|
abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
|
|
@@ -84,8 +84,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
int numCands = randomIntBetween(1, 1000);
|
|
|
|
|
|
KnnVectorQueryBuilder queryBuilder = switch (elementType()) {
|
|
|
- case BYTE -> new KnnVectorQueryBuilder(fieldName, byteVector, numCands);
|
|
|
- case FLOAT -> new KnnVectorQueryBuilder(fieldName, vector, numCands);
|
|
|
+ case BYTE -> new KnnVectorQueryBuilder(fieldName, byteVector, numCands, randomBoolean() ? null : randomFloat());
|
|
|
+ case FLOAT -> new KnnVectorQueryBuilder(fieldName, vector, numCands, randomBoolean() ? null : randomFloat());
|
|
|
};
|
|
|
|
|
|
if (randomBoolean()) {
|
|
@@ -102,9 +102,19 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
|
|
|
@Override
|
|
|
protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
|
|
|
- switch (elementType()) {
|
|
|
- case FLOAT -> assertTrue(query instanceof KnnFloatVectorQuery);
|
|
|
- case BYTE -> assertTrue(query instanceof KnnByteVectorQuery);
|
|
|
+ if (queryBuilder.getVectorSimilarity() != null) {
|
|
|
+ assertTrue(query instanceof VectorSimilarityQuery);
|
|
|
+ Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery();
|
|
|
+ assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
|
|
|
+ switch (elementType()) {
|
|
|
+ case FLOAT -> assertTrue(knnQuery instanceof KnnFloatVectorQuery);
|
|
|
+ case BYTE -> assertTrue(knnQuery instanceof KnnByteVectorQuery);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ switch (elementType()) {
|
|
|
+ case FLOAT -> assertTrue(query instanceof KnnFloatVectorQuery);
|
|
|
+ case BYTE -> assertTrue(query instanceof KnnByteVectorQuery);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
BooleanQuery.Builder builder = new BooleanQuery.Builder();
|
|
@@ -118,19 +128,22 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
case BYTE -> new KnnByteVectorQuery(VECTOR_FIELD, queryBuilder.getByteQueryVector(), queryBuilder.numCands(), filterQuery);
|
|
|
case FLOAT -> new KnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector(), queryBuilder.numCands(), filterQuery);
|
|
|
};
|
|
|
+ if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
|
|
|
+ query = vectorSimilarityQuery.getInnerKnnQuery();
|
|
|
+ }
|
|
|
assertEquals(query, knnVectorQueryBuilt);
|
|
|
}
|
|
|
|
|
|
public void testWrongDimension() {
|
|
|
SearchExecutionContext context = createSearchExecutionContext();
|
|
|
- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10);
|
|
|
+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
|
|
|
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
|
|
|
assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
|
|
|
}
|
|
|
|
|
|
public void testNonexistentField() {
|
|
|
SearchExecutionContext context = createSearchExecutionContext();
|
|
|
- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 10);
|
|
|
+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 10, null);
|
|
|
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
|
|
|
assertThat(e.getMessage(), containsString("field [nonexistent] does not exist in the mapping"));
|
|
|
}
|
|
@@ -140,7 +153,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(
|
|
|
AbstractBuilderTestCase.KEYWORD_FIELD_NAME,
|
|
|
new float[] { 1.0f, 1.0f, 1.0f },
|
|
|
- 10
|
|
|
+ 10,
|
|
|
+ null
|
|
|
);
|
|
|
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
|
|
|
assertThat(e.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
|
|
@@ -148,7 +162,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
|
|
|
@Override
|
|
|
public void testValidOutput() {
|
|
|
- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 10);
|
|
|
+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 10, null);
|
|
|
String expected = """
|
|
|
{
|
|
|
"knn" : {
|
|
@@ -169,7 +183,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
SearchExecutionContext context = createSearchExecutionContext();
|
|
|
context.setAllowUnmappedFields(true);
|
|
|
TermQueryBuilder termQuery = new TermQueryBuilder("unmapped_field", 42);
|
|
|
- KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION);
|
|
|
+ KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION, null);
|
|
|
query.addFilterQuery(termQuery);
|
|
|
|
|
|
IllegalStateException e = expectThrows(IllegalStateException.class, () -> query.toQuery(context));
|
|
@@ -179,7 +193,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
assertThat(rewrittenQuery, instanceOf(MatchNoneQueryBuilder.class));
|
|
|
}
|
|
|
|
|
|
- public void testBWCVersionSerialization() throws IOException {
|
|
|
+ public void testBWCVersionSerializationFilters() throws IOException {
|
|
|
float[] bwcFloat = new float[VECTOR_DIMENSION];
|
|
|
KnnVectorQueryBuilder query = createTestQueryBuilder();
|
|
|
if (query.queryVector() != null) {
|
|
@@ -189,47 +203,70 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
|
|
|
bwcFloat[i] = query.getByteQueryVector()[i];
|
|
|
}
|
|
|
}
|
|
|
- KnnVectorQueryBuilder queryWithNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands()).queryName(
|
|
|
+
|
|
|
+ KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands(), null).queryName(
|
|
|
query.queryName()
|
|
|
).boost(query.boost());
|
|
|
|
|
|
- KnnVectorQueryBuilder queryNoByteQuery = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands()).queryName(
|
|
|
- query.queryName()
|
|
|
- ).boost(query.boost()).addFilterQueries(query.filterQueries());
|
|
|
+ TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
|
|
|
+ random(),
|
|
|
+ TransportVersion.V_8_0_0,
|
|
|
+ TransportVersion.V_8_1_0
|
|
|
+ );
|
|
|
+
|
|
|
+ assertBWCSerialization(query, queryNoFilters, beforeFilterVersion);
|
|
|
+ }
|
|
|
|
|
|
- TransportVersion newVersion = TransportVersionUtils.randomVersionBetween(
|
|
|
+ public void testBWCVersionSerializationSimilarity() throws IOException {
|
|
|
+ KnnVectorQueryBuilder query = createTestQueryBuilder();
|
|
|
+ KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(
|
|
|
+ query.getFieldName(),
|
|
|
+ query.getByteQueryVector(),
|
|
|
+ query.queryVector(),
|
|
|
+ query.numCands(),
|
|
|
+ null
|
|
|
+ ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
|
|
|
+ TransportVersion beforeSimilarity = TransportVersionUtils.randomVersionBetween(
|
|
|
random(),
|
|
|
TransportVersion.V_8_7_0,
|
|
|
- TransportVersion.CURRENT
|
|
|
+ TransportVersion.V_8_8_0
|
|
|
);
|
|
|
+ assertBWCSerialization(query, queryNoSimilarity, beforeSimilarity);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testBWCVersionSerializationByteQuery() throws IOException {
|
|
|
+ float[] bwcFloat = new float[VECTOR_DIMENSION];
|
|
|
+ KnnVectorQueryBuilder query = createTestQueryBuilder();
|
|
|
+ if (query.queryVector() != null) {
|
|
|
+ bwcFloat = query.queryVector();
|
|
|
+ } else {
|
|
|
+ for (int i = 0; i < query.getByteQueryVector().length; i++) {
|
|
|
+ bwcFloat[i] = query.getByteQueryVector()[i];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ KnnVectorQueryBuilder queryNoByteQuery = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands(), null)
|
|
|
+ .queryName(query.queryName())
|
|
|
+ .boost(query.boost())
|
|
|
+ .addFilterQueries(query.filterQueries());
|
|
|
+
|
|
|
TransportVersion beforeByteQueryVersion = TransportVersionUtils.randomVersionBetween(
|
|
|
random(),
|
|
|
TransportVersion.V_8_2_0,
|
|
|
TransportVersion.V_8_6_0
|
|
|
);
|
|
|
- TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
|
|
|
- random(),
|
|
|
- TransportVersion.V_8_0_0,
|
|
|
- TransportVersion.V_8_1_0
|
|
|
- );
|
|
|
+ assertBWCSerialization(query, queryNoByteQuery, beforeByteQueryVersion);
|
|
|
+ }
|
|
|
|
|
|
- assertSerialization(query, newVersion);
|
|
|
- assertSerialization(queryNoByteQuery, beforeByteQueryVersion);
|
|
|
- assertSerialization(queryWithNoFilters, beforeFilterVersion);
|
|
|
-
|
|
|
- for (var tuple : List.of(
|
|
|
- Tuple.tuple(beforeByteQueryVersion, queryNoByteQuery),
|
|
|
- Tuple.tuple(beforeFilterVersion, queryWithNoFilters)
|
|
|
- )) {
|
|
|
- try (BytesStreamOutput output = new BytesStreamOutput()) {
|
|
|
- output.setTransportVersion(tuple.v1());
|
|
|
- output.writeNamedWriteable(query);
|
|
|
- try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
|
|
|
- in.setTransportVersion(tuple.v1());
|
|
|
- KnnVectorQueryBuilder deserializedQuery = (KnnVectorQueryBuilder) in.readNamedWriteable(QueryBuilder.class);
|
|
|
- assertEquals(tuple.v2(), deserializedQuery);
|
|
|
- assertEquals(tuple.v2().hashCode(), deserializedQuery.hashCode());
|
|
|
- }
|
|
|
+ private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException {
|
|
|
+ assertSerialization(bwcQuery, version);
|
|
|
+ try (BytesStreamOutput output = new BytesStreamOutput()) {
|
|
|
+ output.setTransportVersion(version);
|
|
|
+ output.writeNamedWriteable(newQuery);
|
|
|
+ try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
|
|
|
+ in.setTransportVersion(version);
|
|
|
+ KnnVectorQueryBuilder deserializedQuery = (KnnVectorQueryBuilder) in.readNamedWriteable(QueryBuilder.class);
|
|
|
+ assertEquals(bwcQuery, deserializedQuery);
|
|
|
+ assertEquals(bwcQuery.hashCode(), deserializedQuery.hashCode());
|
|
|
}
|
|
|
}
|
|
|
}
|