Browse Source

Add new `similarity` field to `knn` clause in `_search` (#94828)

This adds a new parameter to `knn` that allows filtering out nearest neighbor results that fall outside a given similarity.

`num_candidates` and `k` are still required as they control the nearest-neighbor vector search accuracy and exploration. For each shard the query will search `num_candidates` and only keep those that are within the provided `similarity` boundary, and then finally reduce to only the global top `k` as normal.

For example, when using the `l2_norm` indexed similarity value, this could be considered a `radius` post-filter on `knn`.

relates to: https://github.com/elastic/elasticsearch/issues/84929 && https://github.com/elastic/elasticsearch/pull/93574
Benjamin Trent 2 years ago
parent
commit
f23b906891
27 changed files with 942 additions and 155 deletions
  1. 23 0
      docs/changelog/94828.yaml
  2. 19 2
      docs/reference/rest-api/common-parms.asciidoc
  3. 63 17
      docs/reference/search/search-your-data/knn-search.asciidoc
  4. 6 2
      docs/reference/search/search.asciidoc
  5. 58 0
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml
  6. 58 0
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml
  7. 2 1
      server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java
  8. 9 4
      server/src/main/java/org/elasticsearch/common/lucene/search/function/MinScoreScorer.java
  9. 48 8
      server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
  10. 52 12
      server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java
  11. 1 1
      server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java
  12. 46 8
      server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java
  13. 158 0
      server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java
  14. 10 10
      server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java
  15. 6 3
      server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java
  16. 2 2
      server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
  17. 10 10
      server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
  18. 15 6
      server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
  19. 1 1
      server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java
  20. 1 1
      server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java
  21. 77 40
      server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
  22. 36 19
      server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java
  23. 221 0
      server/src/test/java/org/elasticsearch/search/vectors/VectorSimilarityQueryTests.java
  24. 1 1
      test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java
  25. 15 3
      test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java
  26. 1 1
      x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java
  27. 3 3
      x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java

+ 23 - 0
docs/changelog/94828.yaml

@@ -0,0 +1,23 @@
+pr: 94828
+summary: Add new `similarity` field to `knn` clause in `_search`
+area: Search
+type: feature
+issues: []
+highlight:
+  title: Add new `similarity` field to `knn` clause in `_search`
+  body: |-
+    This adds a new parameter to `knn` that allows filtering out nearest
+    neighbor results that fall outside a given similarity.
+
+    `num_candidates` and `k` are still required as they control the
+    nearest-neighbor vector search accuracy and exploration. For each shard
+    the query will search `num_candidates` and only keep those that are
+    within the provided `similarity` boundary, and then finally reduce to
+    only the global top `k` as normal.
+
+    For example, when using the `l2_norm` indexed similarity value, this
+    could be considered a `radius` post-filter on `knn`.
+
+    relates to: https://github.com/elastic/elasticsearch/issues/84929 &&
+    https://github.com/elastic/elasticsearch/pull/93574
+  notable: true

+ 19 - 2
docs/reference/rest-api/common-parms.asciidoc

@@ -130,8 +130,8 @@ shards. Statuses are:
     All shards are assigned.
 
   * `yellow`:
-    All primary shards are assigned, but one or more replica shards are 
-    unassigned. If a node in the cluster fails, some data could be unavailable 
+    All primary shards are assigned, but one or more replica shards are
+    unassigned. If a node in the cluster fails, some data could be unavailable
     until that node is repaired.
 
   * `red`:
@@ -581,6 +581,23 @@ Query vector. Must have the same number of dimensions as the vector field you
 are searching against.
 end::knn-query-vector[]
 
+tag::knn-similarity[]
+The minimum similarity required for a document to be considered a match. The similarity
+value calculated relates to the raw <<dense-vector-similarity, `similarity`>> used, not the
+document score. The matched documents are then scored according to <<dense-vector-similarity, `similarity`>>
+and the provided `boost` is applied.
+
++
+--
+The `similarity` parameter is the direct vector similarity calculation.
+
+* `l2_norm`: also known as Euclidean, will include documents where the vector is within
+the `dims` dimensional hypersphere with radius `similarity` with origin at `query_vector`.
+* `cosine` & `dot_product`: Only return vectors where the cosine similarity or dot-product is at least the provided
+`similarity`.
+--
+end::knn-similarity[]
+
 tag::lenient[]
 `lenient`::
 (Optional, Boolean) If `true`, format-based query failures (such as providing

+ 63 - 17
docs/reference/search/search-your-data/knn-search.asciidoc

@@ -407,10 +407,10 @@ each score in the sum. In the example above, the scores will be calculated as
 score = 0.9 * match_score + 0.1 * knn_score
 ```
 
-The `knn` option can also be used with <<search-aggregations, `aggregations`>>. 
-In general, {es} computes aggregations over all documents that match the search. 
-So for approximate kNN search, aggregations are calculated on the top `k` 
-nearest documents. If the search also includes a `query`, then aggregations are 
+The `knn` option can also be used with <<search-aggregations, `aggregations`>>.
+In general, {es} computes aggregations over all documents that match the search.
+So for approximate kNN search, aggregations are calculated on the top `k`
+nearest documents. If the search also includes a `query`, then aggregations are
 calculated on the combined set of `knn` and `query` matches.
 
 [discrete]
@@ -419,30 +419,30 @@ calculated on the combined set of `knn` and `query` matches.
 
 experimental[]
 
-kNN search enables you to perform semantic search by using a previously deployed 
-{ml-docs}/ml-nlp-search-compare.html#ml-nlp-text-embedding[text embedding model]. 
+kNN search enables you to perform semantic search by using a previously deployed
+{ml-docs}/ml-nlp-search-compare.html#ml-nlp-text-embedding[text embedding model].
 Instead of literal matching on search terms, semantic search retrieves results
 based on the intent and the contextual meaning of a search query.
 
-Under the hood, the text embedding NLP model generates a dense vector from the 
-input query string called `model_text` you provide. Then, it is searched 
-against an index containing dense vectors created with the same text embedding 
+Under the hood, the text embedding NLP model generates a dense vector from the
+input query string called `model_text` you provide. Then, it is searched
+against an index containing dense vectors created with the same text embedding
 {ml} model. The search results are semantically similar as learned by the model.
 
 [IMPORTANT]
 =====================
 To perform semantic search:
 
-* you need an index that contains the dense vector representation of the input 
+* you need an index that contains the dense vector representation of the input
 data to search against,
 
-* you must use the same text embedding model for search that you used to create 
+* you must use the same text embedding model for search that you used to create
 the dense vectors from the input data,
 
 * the text embedding NLP model deployment must be started.
 =====================
 
-Reference the deployed text embedding model in the `query_vector_builder` object 
+Reference the deployed text embedding model in the `query_vector_builder` object
 and provide the search query as `model_text`:
 
 [source,js]
@@ -466,14 +466,14 @@ and provide the search query as `model_text`:
 // NOTCONSOLE
 
 <1> The {nlp} task to perform. It must be `text_embedding`.
-<2> The ID of the text embedding model to use to generate the dense vectors from 
-the query string. Use the same model that generated the embeddings from the 
+<2> The ID of the text embedding model to use to generate the dense vectors from
+the query string. Use the same model that generated the embeddings from the
 input text in the index you search against.
-<3> The query string from which the model generates the dense vector 
+<3> The query string from which the model generates the dense vector
 representation.
 
-For more information on how to deploy a trained model and use it to create text 
-embeddings, refer to this 
+For more information on how to deploy a trained model and use it to create text
+embeddings, refer to this
 {ml-docs}/ml-nlp-text-emb-vector-search-example.html[end-to-end example].
 
 
@@ -525,6 +525,52 @@ The scoring for a doc with the above configured boosts would be:
 score = 0.9 * match_score + 0.1 * knn_score_image-vector + 0.5 * knn_score_title-vector
 ```
 
+[discrete]
+==== Search kNN with expected similarity
+
+While kNN is a powerful tool, it always tries to return `k` nearest neighbors. Consequently, when using `knn` with
+a `filter`, you could filter out all relevant documents and only have irrelevant ones left to search. In that situation,
+`knn` will still do its best to return `k` nearest neighbors, even though those neighbors could be far away in the
+vector space.
+
+To alleviate this worry, there is a `similarity` parameter available in the `knn` clause. This value is the required
+minimum similarity for a vector to be considered a match. The `knn` search flow with this parameter is as follows:
+
++
+--
+* Apply any user provided `filter` queries
+* Explore the vector space to get `k` vectors
+* Do not return any vectors that are further away than the configured `similarity`
+--
+
+Here is an example. We search for the `k` nearest neighbors of the given `query_vector`, with a `filter` applied
+and with the requirement that the found vectors have at least the provided `similarity` to the query vector.
+[source,console]
+----
+POST image-index/_search
+{
+  "knn": {
+    "field": "image-vector",
+    "query_vector": [1, 5, -20],
+    "k": 5,
+    "num_candidates": 50,
+    "similarity": 36,
+    "filter": {
+      "term": {
+        "file-type": "png"
+      }
+    }
+  },
+  "fields": ["title"],
+  "_source": false
+}
+----
+// TEST[continued]
+
+In our data set, the only document with the file type of `png` has a vector of `[42, 8, -15]`. The `l2_norm` distance
+between `[42, 8, -15]` and `[1, 5, -20]` is `41.412`, which is greater than the configured similarity of `36`. As a result,
+this search will return no hits.
+
 [discrete]
 [[knn-indexing-considerations]]
 ==== Indexing considerations

+ 6 - 2
docs/reference/search/search.asciidoc

@@ -511,10 +511,14 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector]
 
 `query_vector_builder`::
 (Optional, object)
-A configuration object indicating how to build a query_vector before executing 
-the request. You must provide a `query_vector_builder` or `query_vector`, but 
+A configuration object indicating how to build a query_vector before executing
+the request. You must provide a `query_vector_builder` or `query_vector`, but
 not both. Refer to <<semantic-search>> to learn more.
 
+`similarity`::
+(Optional, float)
+include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-similarity]
+
 ====
 
 [[search-api-min-score]]

+ 58 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml

@@ -310,3 +310,61 @@ setup:
               num_candidates: 1
   - match: { error.root_cause.0.type: "illegal_argument_exception" }
   - match: { error.root_cause.0.reason: "[knn] queries cannot be provided directly, use the [knn] body parameter instead" }
+---
+"KNN Vector similarity search only":
+  - skip:
+      version: ' - 8.7.99'
+      reason: 'kNN similarity added in 8.8'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 11
+            query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
+
+  - length: {hits.hits: 1}
+
+  - match: {hits.hits.0._id: "2"}
+  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
+---
+"Vector similarity with filter only":
+  - skip:
+      version: ' - 8.7.99'
+      reason: 'kNN similarity added in 8.8'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 11
+            query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
+            filter: {"term": {"name": "moose.jpg"}}
+
+  - length: {hits.hits: 1}
+
+  - match: {hits.hits.0._id: "2"}
+  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
+
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 110
+            query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
+            filter: {"term": {"name": "cow.jpg"}}
+
+  - length: {hits.hits: 0}

+ 58 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml

@@ -176,3 +176,61 @@ setup:
               num_candidates: 1
   - match: { error.root_cause.0.type: "illegal_argument_exception" }
   - match: { error.root_cause.0.reason: "[knn] queries cannot be provided directly, use the [knn] body parameter instead" }
+---
+"Vector similarity search only":
+  - skip:
+      version: ' - 8.7.99'
+      reason: 'kNN similarity added in 8.8'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 1.0
+            query_vector: [5, 4.0, 3, 2.0, 127]
+
+  - length: {hits.hits: 1}
+
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
+---
+"Vector similarity with filter only":
+  - skip:
+      version: ' - 8.7.99'
+      reason: 'kNN similarity added in 8.8'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 1.0
+            query_vector: [5, 4.0, 3, 2.0, 127]
+            filter: {"term": {"name": "rabbit.jpg"}}
+
+  - length: {hits.hits: 1}
+
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
+
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            num_candidates: 3
+            k: 3
+            field: vector
+            similarity: 1
+            query_vector: [5, 4.0, 3, 2.0, 127]
+            filter: {"term": {"name": "cow.jpg"}}
+
+  - length: {hits.hits: 0}

+ 2 - 1
server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java

@@ -80,7 +80,8 @@ public class DfsProfilerIT extends ESIntegTestCase {
                             vectorField,
                             new float[] { randomFloat(), randomFloat(), randomFloat() },
                             randomIntBetween(5, 10),
-                            50
+                            50,
+                            randomBoolean() ? null : randomFloat()
                         )
                     )
                 )

+ 9 - 4
server/src/main/java/org/elasticsearch/common/lucene/search/function/MinScoreScorer.java

@@ -17,17 +17,22 @@ import java.io.IOException;
 
 /** A {@link Scorer} that filters out documents that have a score that is
  *  lower than a configured constant. */
-final class MinScoreScorer extends Scorer {
+public final class MinScoreScorer extends Scorer {
 
     private final Scorer in;
     private final float minScore;
-
     private float curScore;
+    private final float boost;
+
+    public MinScoreScorer(Weight weight, Scorer scorer, float minScore) {
+        this(weight, scorer, minScore, 1f);
+    }
 
-    MinScoreScorer(Weight weight, Scorer scorer, float minScore) {
+    public MinScoreScorer(Weight weight, Scorer scorer, float minScore, float boost) {
         super(weight);
         this.in = scorer;
         this.minScore = minScore;
+        this.boost = boost;
     }
 
     @Override
@@ -37,7 +42,7 @@ final class MinScoreScorer extends Scorer {
 
     @Override
     public float score() {
-        return curScore;
+        return curScore * boost;
     }
 
     @Override

+ 48 - 8
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -43,6 +43,7 @@ import org.elasticsearch.index.mapper.ValueFetcher;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
+import org.elasticsearch.search.vectors.VectorSimilarityQuery;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
@@ -571,9 +572,31 @@ public class DenseVectorFieldMapper extends FieldMapper {
     );
 
     enum VectorSimilarity {
-        L2_NORM(VectorSimilarityFunction.EUCLIDEAN),
-        COSINE(VectorSimilarityFunction.COSINE),
-        DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT);
+        L2_NORM(VectorSimilarityFunction.EUCLIDEAN) {
+            @Override
+            float score(float similarity, ElementType elementType, int dim) {
+                return switch (elementType) {
+                    case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
+                };
+            }
+        },
+        COSINE(VectorSimilarityFunction.COSINE) {
+            @Override
+            float score(float similarity, ElementType elementType, int dim) {
+                return switch (elementType) {
+                    case BYTE, FLOAT -> (1 + similarity) / 2f;
+                };
+            }
+        },
+        DOT_PRODUCT(VectorSimilarityFunction.DOT_PRODUCT) {
+            @Override
+            float score(float similarity, ElementType elementType, int dim) {
+                return switch (elementType) {
+                    case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
+                    case FLOAT -> (1 + similarity) / 2f;
+                };
+            }
+        };
 
         public final VectorSimilarityFunction function;
 
@@ -585,6 +608,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public final String toString() {
             return name().toLowerCase(Locale.ROOT);
         }
+
+        abstract float score(float similarity, ElementType elementType, int dim);
     }
 
     private abstract static class IndexOptions implements ToXContent {
@@ -723,7 +748,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support term queries");
         }
 
-        public Query createKnnQuery(byte[] queryVector, int numCands, Query filter) {
+        public Query createKnnQuery(byte[] queryVector, int numCands, Query filter, Float similarityThreshold) {
             if (isIndexed() == false) {
                 throw new IllegalArgumentException(
                     "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
@@ -749,11 +774,18 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
-
-            return new KnnByteVectorQuery(name(), queryVector, numCands, filter);
+            Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
+            if (similarityThreshold != null) {
+                knnQuery = new VectorSimilarityQuery(
+                    knnQuery,
+                    similarityThreshold,
+                    similarity.score(similarityThreshold, elementType, dims)
+                );
+            }
+            return knnQuery;
         }
 
-        public Query createKnnQuery(float[] queryVector, int numCands, Query filter) {
+        public Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold) {
             if (isIndexed() == false) {
                 throw new IllegalArgumentException(
                     "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
@@ -774,7 +806,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
             }
-            return switch (elementType) {
+            Query knnQuery = switch (elementType) {
                 case BYTE -> {
                     byte[] bytes = new byte[queryVector.length];
                     for (int i = 0; i < queryVector.length; i++) {
@@ -784,6 +816,14 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 case FLOAT -> new KnnFloatVectorQuery(name(), queryVector, numCands, filter);
             };
+            if (similarityThreshold != null) {
+                knnQuery = new VectorSimilarityQuery(
+                    knnQuery,
+                    similarityThreshold,
+                    similarity.score(similarityThreshold, elementType, dims)
+                );
+            }
+            return knnQuery;
         }
     }
 

+ 52 - 12
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

@@ -46,11 +46,12 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
     public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
     public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
+    public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");
     public static final ParseField FILTER_FIELD = new ParseField("filter");
     public static final ParseField BOOST_FIELD = AbstractQueryBuilder.BOOST_FIELD;
 
+    @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<KnnSearchBuilder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
-        @SuppressWarnings("unchecked")
         // TODO optimize parsing for when BYTE values are provided
         List<Float> vector = (List<Float>) args[1];
         final float[] vectorArray;
@@ -62,7 +63,14 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         } else {
             vectorArray = null;
         }
-        return new KnnSearchBuilder((String) args[0], vectorArray, (QueryVectorBuilder) args[4], (int) args[2], (int) args[3]);
+        return new KnnSearchBuilder(
+            (String) args[0],
+            vectorArray,
+            (QueryVectorBuilder) args[4],
+            (int) args[2],
+            (int) args[3],
+            (Float) args[5]
+        );
     });
 
     static {
@@ -75,6 +83,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
             QUERY_VECTOR_BUILDER_FIELD
         );
+        PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY);
         PARSER.declareFieldArray(
             KnnSearchBuilder::addFilterQueries,
             (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
@@ -94,6 +103,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     private final Supplier<float[]> querySupplier;
     final int k;
     final int numCands;
+    final Float similarity;
     final List<QueryBuilder> filterQueries;
     float boost = AbstractQueryBuilder.DEFAULT_BOOST;
 
@@ -105,8 +115,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
      * @param k           the final number of nearest neighbors to return as top hits
      * @param numCands    the number of nearest neighbor candidates to consider per shard
      */
-    public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands) {
-        this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands);
+    public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) {
+        this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, similarity);
     }
 
     /**
@@ -116,17 +126,25 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
      * @param k                  the final number of nearest neighbors to return as top hits
      * @param numCands           the number of nearest neighbor candidates to consider per shard
      */
-    public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
+    public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int k, int numCands, Float similarity) {
         this(
             field,
             null,
             Objects.requireNonNull(queryVectorBuilder, format("[%s] cannot be null", QUERY_VECTOR_BUILDER_FIELD.getPreferredName())),
             k,
-            numCands
+            numCands,
+            similarity
         );
     }
 
-    private KnnSearchBuilder(String field, float[] queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands) {
+    private KnnSearchBuilder(
+        String field,
+        float[] queryVector,
+        QueryVectorBuilder queryVectorBuilder,
+        int k,
+        int numCands,
+        Float similarity
+    ) {
         if (k < 1) {
             throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
         }
@@ -163,9 +181,17 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         this.numCands = numCands;
         this.filterQueries = new ArrayList<>();
         this.querySupplier = null;
+        this.similarity = similarity;
     }
 
-    private KnnSearchBuilder(String field, Supplier<float[]> querySupplier, int k, int numCands, List<QueryBuilder> filterQueries) {
+    private KnnSearchBuilder(
+        String field,
+        Supplier<float[]> querySupplier,
+        int k,
+        int numCands,
+        List<QueryBuilder> filterQueries,
+        Float similarity
+    ) {
         this.field = field;
         this.queryVector = new float[0];
         this.queryVectorBuilder = null;
@@ -173,6 +199,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         this.numCands = numCands;
         this.filterQueries = filterQueries;
         this.querySupplier = querySupplier;
+        this.similarity = similarity;
     }
 
     public KnnSearchBuilder(StreamInput in) throws IOException {
@@ -188,6 +215,11 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             this.queryVectorBuilder = null;
         }
         this.querySupplier = null;
+        if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            this.similarity = in.readOptionalFloat();
+        } else {
+            this.similarity = null;
+        }
     }
 
     public int k() {
@@ -229,7 +261,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             if (querySupplier.get() == null) {
                 return this;
             }
-            return new KnnSearchBuilder(field, querySupplier.get(), k, numCands).boost(boost).addFilterQueries(filterQueries);
+            return new KnnSearchBuilder(field, querySupplier.get(), k, numCands, similarity).boost(boost).addFilterQueries(filterQueries);
         }
         if (queryVectorBuilder != null) {
             SetOnce<float[]> toSet = new SetOnce<>();
@@ -249,7 +281,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
                 }
                 l.onResponse(null);
             }, l::onFailure)));
-            return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries).boost(boost);
+            return new KnnSearchBuilder(field, toSet::get, k, numCands, filterQueries, similarity).boost(boost);
         }
         boolean changed = false;
         List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
@@ -261,7 +293,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             rewrittenQueries.add(rewrittenQuery);
         }
         if (changed) {
-            return new KnnSearchBuilder(field, queryVector, k, numCands).boost(boost).addFilterQueries(rewrittenQueries);
+            return new KnnSearchBuilder(field, queryVector, k, numCands, similarity).boost(boost).addFilterQueries(rewrittenQueries);
         }
         return this;
     }
@@ -270,7 +302,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         if (queryVectorBuilder != null) {
             throw new IllegalArgumentException("missing rewrite");
         }
-        return new KnnVectorQueryBuilder(field, queryVector, numCands).boost(boost).addFilterQueries(filterQueries);
+        return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost).addFilterQueries(filterQueries);
     }
 
     @Override
@@ -285,6 +317,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             && Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
             && Objects.equals(querySupplier, that.querySupplier)
             && Objects.equals(filterQueries, that.filterQueries)
+            && Objects.equals(similarity, that.similarity)
             && boost == that.boost;
     }
 
@@ -296,6 +329,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             numCands,
             querySupplier,
             queryVectorBuilder,
+            similarity,
             Arrays.hashCode(queryVector),
             Objects.hashCode(filterQueries),
             boost
@@ -314,6 +348,9 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         } else {
             builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
         }
+        if (similarity != null) {
+            builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity);
+        }
 
         if (filterQueries.isEmpty() == false) {
             builder.startArray(FILTER_FIELD.getPreferredName());
@@ -353,5 +390,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_7_0)) {
             out.writeOptionalNamedWriteable(queryVectorBuilder);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            out.writeOptionalFloat(similarity);
+        }
     }
 }

+ 1 - 1
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java

@@ -255,7 +255,7 @@ public class KnnSearchRequestParser {
             if (numCands > NUM_CANDS_LIMIT) {
                 throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
             }
-            return new KnnVectorQueryBuilder(field, queryVector, numCands);
+            return new KnnVectorQueryBuilder(field, queryVector, numCands, null);
         }
 
         @Override

+ 46 - 8
server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

@@ -45,21 +45,35 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     private final byte[] byteQueryVector;
     private final int numCands;
     private final List<QueryBuilder> filterQueries;
+    private final Float vectorSimilarity;
 
-    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands) {
+    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands, Float vectorSimilarity) {
         this.fieldName = fieldName;
         this.queryVector = Objects.requireNonNull(queryVector);
         this.byteQueryVector = null;
         this.numCands = numCands;
         this.filterQueries = new ArrayList<>();
+        this.vectorSimilarity = vectorSimilarity;
     }
 
-    public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, int numCands) {
+    public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, int numCands, Float vectorSimilarity) {
         this.fieldName = fieldName;
         this.queryVector = null;
         this.byteQueryVector = Objects.requireNonNull(queryVector);
         this.numCands = numCands;
         this.filterQueries = new ArrayList<>();
+        this.vectorSimilarity = vectorSimilarity;
+    }
+
+    // Tests only: accepts both vector kinds at once; here the byte[] parameter is the one named queryVector
+    KnnVectorQueryBuilder(String fieldName, byte[] queryVector, float[] floatQueryVector, int numCands, Float vectorSimilarity) {
+        assert queryVector != null ^ floatQueryVector != null; // exactly one of the two vectors must be non-null
+        this.fieldName = fieldName;
+        this.queryVector = floatQueryVector; // the queryVector field holds the float vector
+        this.byteQueryVector = queryVector; // the queryVector parameter is the byte vector
+        this.numCands = numCands;
+        this.filterQueries = new ArrayList<>(); // starts with no filters
+        this.vectorSimilarity = vectorSimilarity; // may be null
+    }
 
     public KnnVectorQueryBuilder(StreamInput in) throws IOException {
@@ -78,6 +92,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         } else {
             this.filterQueries = readQueries(in);
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            this.vectorSimilarity = in.readOptionalFloat();
+        } else {
+            this.vectorSimilarity = null;
+        }
     }
 
     public String getFieldName() {
@@ -94,6 +113,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         return byteQueryVector;
     }
 
+    @Nullable
+    public Float getVectorSimilarity() {
+        return vectorSimilarity; // may be null, e.g. when read from a stream older than V_8_8_0
+    }
+
     public int numCands() {
         return numCands;
     }
@@ -144,6 +168,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_2_0)) {
             writeQueries(out, filterQueries);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            out.writeOptionalFloat(vectorSimilarity);
+        }
     }
 
     @Override
@@ -152,6 +179,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             .field("field", fieldName)
             .field("vector", queryVector != null ? queryVector : byteQueryVector)
             .field("num_candidates", numCands);
+        if (vectorSimilarity != null) {
+            builder.field("similarity", vectorSimilarity);
+        }
         if (filterQueries.isEmpty() == false) {
             builder.startArray("filters");
             for (QueryBuilder filterQuery : filterQueries) {
@@ -184,8 +214,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         }
         if (changed) {
             return byteQueryVector != null
-                ? new KnnVectorQueryBuilder(fieldName, byteQueryVector, numCands).addFilterQueries(rewrittenQueries)
-                : new KnnVectorQueryBuilder(fieldName, queryVector, numCands).addFilterQueries(rewrittenQueries);
+                ? new KnnVectorQueryBuilder(fieldName, byteQueryVector, numCands, vectorSimilarity).addFilterQueries(rewrittenQueries)
+                : new KnnVectorQueryBuilder(fieldName, queryVector, numCands, vectorSimilarity).addFilterQueries(rewrittenQueries);
         }
         return this;
     }
@@ -212,13 +242,20 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
 
         DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
         return queryVector != null
-            ? vectorFieldType.createKnnQuery(queryVector, numCands, filterQuery)
-            : vectorFieldType.createKnnQuery(byteQueryVector, numCands, filterQuery);
+            ? vectorFieldType.createKnnQuery(queryVector, numCands, filterQuery, vectorSimilarity)
+            : vectorFieldType.createKnnQuery(byteQueryVector, numCands, filterQuery, vectorSimilarity);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(fieldName, Arrays.hashCode(queryVector), Arrays.hashCode(byteQueryVector), numCands, filterQueries);
+        return Objects.hash(
+            fieldName,
+            Arrays.hashCode(queryVector),
+            Arrays.hashCode(byteQueryVector),
+            numCands,
+            filterQueries,
+            vectorSimilarity
+        );
     }
 
     @Override
@@ -227,7 +264,8 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             && Arrays.equals(queryVector, other.queryVector)
             && Arrays.equals(byteQueryVector, other.byteQueryVector)
             && numCands == other.numCands
-            && Objects.equals(filterQueries, other.filterQueries);
+            && Objects.equals(filterQueries, other.filterQueries)
+            && Objects.equals(vectorSimilarity, other.vectorSimilarity);
     }
 
     @Override

+ 158 - 0
server/src/main/java/org/elasticsearch/search/vectors/VectorSimilarityQuery.java

@@ -0,0 +1,158 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.FilterWeight;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchNoDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.elasticsearch.common.lucene.search.function.MinScoreScorer;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.Strings.format;
+
+/**
+ * Post-filter for the provided Query (assumed to be a Knn(Float|Byte)VectorQuery): its scorer is wrapped in a MinScoreScorer.
+ */
+public class VectorSimilarityQuery extends Query {
+    private final float similarity; // caller-supplied threshold; reported in explain/toString, not used when scoring
+    private final float docScore; // minimum score handed to MinScoreScorer and checked in explain
+    private final Query innerKnnQuery; // wrapped query that weight creation and rewriting delegate to
+
+    /**
+     * @param innerKnnQuery A {@link org.apache.lucene.search.KnnFloatVectorQuery} or {@link org.apache.lucene.search.KnnByteVectorQuery}
+     * @param similarity The similarity threshold originally provided (used in explanations)
+     * @param docScore The similarity transformed into a score threshold applied after gathering the nearest neighbors
+     */
+    public VectorSimilarityQuery(Query innerKnnQuery, float similarity, float docScore) {
+        this.similarity = similarity;
+        this.docScore = docScore;
+        this.innerKnnQuery = innerKnnQuery;
+    }
+
+    // For testing: package-private accessors
+    Query getInnerKnnQuery() {
+        return innerKnnQuery;
+    }
+
+    float getSimilarity() {
+        return similarity;
+    }
+
+    float getDocScore() {
+        return docScore;
+    }
+
+    @Override
+    public Query rewrite(IndexReader reader) throws IOException {
+        Query rewrittenInnerQuery = innerKnnQuery.rewrite(reader);
+        if (rewrittenInnerQuery instanceof MatchNoDocsQuery) { // nothing to filter: return the inner result without the wrapper
+            return rewrittenInnerQuery;
+        }
+        if (rewrittenInnerQuery == innerKnnQuery) { // inner query unchanged: keep this instance
+            return this;
+        }
+        return new VectorSimilarityQuery(rewrittenInnerQuery, similarity, docScore); // re-wrap with the same thresholds
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+        final Weight innerWeight; // always created with boost 1.0f; the outer boost is given to MinScoreWeight instead
+        if (scoreMode.isExhaustive()) {
+            innerWeight = innerKnnQuery.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
+        } else {
+            innerWeight = innerKnnQuery.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);
+        }
+        return new MinScoreWeight(innerWeight, docScore, similarity, this, boost);
+    }
+
+    @Override
+    public String toString(String field) {
+        return "VectorSimilarityQuery["
+            + "similarity="
+            + similarity
+            + ", docScore="
+            + docScore
+            + ", innerKnnQuery="
+            + innerKnnQuery.toString(field)
+            + ']';
+    }
+
+    @Override
+    public void visit(QueryVisitor visitor) {
+        visitor.visitLeaf(this); // reported as a leaf; the inner query is not visited here
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (sameClassAs(obj) == false) {
+            return false;
+        }
+        VectorSimilarityQuery other = (VectorSimilarityQuery) obj;
+        return Objects.equals(innerKnnQuery, other.innerKnnQuery) && docScore == other.docScore && similarity == other.similarity;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(innerKnnQuery, docScore, similarity); // NOTE(review): 0.0f/-0.0f hash differently but are == in equals
+    }
+
+    private static class MinScoreWeight extends FilterWeight {
+
+        private final float similarity, docScore, boost; // boost scales the explained score and is passed to MinScoreScorer
+
+        private MinScoreWeight(Weight innerWeight, float docScore, float similarity, Query parent, float boost) {
+            super(parent, innerWeight);
+            this.docScore = docScore;
+            this.similarity = similarity;
+            this.boost = boost;
+        }
+
+        @Override
+        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+            Explanation explanation = in.explain(context, doc);
+            if (explanation.isMatch()) { // inner non-matches fall through and are returned unchanged
+                float score = explanation.getValue().floatValue();
+                if (score >= docScore) { // threshold is inclusive
+                    return Explanation.match(explanation.getValue().floatValue() * boost, "vector similarity within limit", explanation);
+                } else {
+                    return Explanation.noMatch(
+                        format(
+                            "vector found, but score [%f] is less than matching minimum score [%f] from similarity [%f]",
+                            explanation.getValue().floatValue(),
+                            docScore,
+                            similarity
+                        ),
+                        explanation
+                    );
+                }
+            }
+            return explanation;
+        }
+
+        @Override
+        public Scorer scorer(LeafReaderContext context) throws IOException {
+            Scorer innerScorer = in.scorer(context);
+            if (innerScorer == null) { // no inner scorer for this leaf: nothing to wrap
+                return null;
+            }
+            return new MinScoreScorer(this, innerScorer, docScore, boost);
+        }
+    }
+
+}

+ 10 - 10
server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java

@@ -61,7 +61,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().prepareUpdate("index", "0").setDoc("vector", (Object) null).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
 
         float[] queryVector = randomVector();
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50).boost(5.0f);
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 20, 50, null).boost(5.0f);
         SearchResponse response = client().prepareSearch("index")
             .setKnnSearch(List.of(knnSearch))
             .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
@@ -103,7 +103,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().admin().indices().prepareRefresh("index").get();
 
         float[] queryVector = randomVector();
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f);
         SearchResponse response = client().prepareSearch("index")
             .setKnnSearch(List.of(knnSearch))
             .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
@@ -147,7 +147,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().admin().indices().prepareRefresh("index").get();
 
         float[] queryVector = randomVector();
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery(
             QueryBuilders.termsQuery("field", "second")
         );
         SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();
@@ -190,7 +190,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().admin().indices().prepareRefresh("index").get();
 
         float[] queryVector = randomVector();
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).addFilterQuery(
             QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
         );
         SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10).get();
@@ -237,8 +237,8 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().admin().indices().prepareRefresh("index").get();
 
         float[] queryVector = randomVector(20f, 21f);
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
-        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50).boost(10.0f);
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null).boost(5.0f);
+        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null).boost(10.0f);
         SearchResponse response = client().prepareSearch("index")
             .setKnnSearch(List.of(knnSearch, knnSearch2))
             .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
@@ -296,8 +296,8 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 
         float[] queryVector = randomVector();
         // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50);
-        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50);
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50, null);
+        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50, null);
         SearchResponse responseOneKnn = client().prepareSearch("index")
             .setKnnSearch(List.of(knnSearch))
             .addFetchField("*")
@@ -365,7 +365,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         client().admin().indices().prepareRefresh("index").get();
 
         float[] queryVector = randomVector();
-        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50);
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50, null);
         SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10).get();
 
         assertHitCount(response, expectedHits);
@@ -400,7 +400,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = randomVector();
         SearchResponse response = client().prepareSearch("index1", "index2")
-            .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5))
+            .setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, null))
             .setSize(2)
             .get();
 

+ 6 - 3
server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java

@@ -95,8 +95,8 @@ public class SearchRequestTests extends AbstractSearchTestCase {
         searchRequest.source()
             .knnSearch(
                 List.of(
-                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10),
-                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5)
+                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat()),
+                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5, randomBoolean() ? null : randomFloat())
                 )
             );
         expectThrows(
@@ -109,7 +109,10 @@ public class SearchRequestTests extends AbstractSearchTestCase {
             )
         );
 
-        searchRequest.source().knnSearch(List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10)));
+        searchRequest.source()
+            .knnSearch(
+                List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10, randomBoolean() ? null : randomFloat()))
+            );
         // Shouldn't throw because its just one KNN request
         copyWriteable(
             searchRequest,

+ 2 - 2
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -1104,7 +1104,7 @@ public class TransportSearchActionTests extends ESTestCase {
         {
             SearchRequest searchRequest = new SearchRequest();
             SearchSourceBuilder source = new SearchSourceBuilder();
-            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50)));
+            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null)));
             searchRequest.source(source);
 
             searchRequest.setCcsMinimizeRoundtrips(true);
@@ -1119,7 +1119,7 @@ public class TransportSearchActionTests extends ESTestCase {
             // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH
             SearchRequest searchRequest = new SearchRequest();
             SearchSourceBuilder source = new SearchSourceBuilder();
-            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50)));
+            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50, null)));
             searchRequest.source(source);
 
             TransportSearchAction.adjustSearchType(searchRequest, randomBoolean());

+ 10 - 10
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

@@ -580,7 +580,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         Exception e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 128, 0, 0 }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 128, 0, 0 }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -589,7 +589,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0f, -129.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0f, -129.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -598,7 +598,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0.5f, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0.5f, 0.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -607,7 +607,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, 0.0f, -0.25f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, 0.0f, -0.25f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -616,13 +616,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null, null)
         );
         assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];"));
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -631,7 +631,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -657,13 +657,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         Exception e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null, null)
         );
         assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];"));
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -672,7 +672,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null, null)
         );
         assertThat(
             e.getMessage(),

+ 15 - 6
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -118,7 +118,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null)
+            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null)
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
 
@@ -131,7 +131,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             VectorSimilarity.DOT_PRODUCT,
             Collections.emptyMap()
         );
-        e = expectThrows(IllegalArgumentException.class, () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null));
+        e = expectThrows(
+            IllegalArgumentException.class,
+            () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null)
+        );
         assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
 
         DenseVectorFieldType cosineField = new DenseVectorFieldType(
@@ -143,7 +146,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             VectorSimilarity.COSINE,
             Collections.emptyMap()
         );
-        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null));
+        e = expectThrows(
+            IllegalArgumentException.class,
+            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null, null)
+        );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }
 
@@ -159,7 +165,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null)
+            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null)
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
 
@@ -172,10 +178,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             VectorSimilarity.COSINE,
             Collections.emptyMap()
         );
-        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null));
+        e = expectThrows(
+            IllegalArgumentException.class,
+            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null, null)
+        );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
 
-        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null));
+        e = expectThrows(IllegalArgumentException.class, () -> cosineField.createKnnQuery(new byte[] { 0, 0, 0 }, 10, null, null));
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }
 }

+ 1 - 1
server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java

@@ -103,7 +103,7 @@ public class RestSearchActionTests extends RestActionTestCase {
             ).withMethod(RestRequest.Method.GET).withPath("/some_index/_search").withParams(params).build();
 
             SearchRequest searchRequest = new SearchRequest();
-            KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100);
+            KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100, null);
             searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch)));
 
             Exception ex = expectThrows(

+ 1 - 1
server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java

@@ -744,7 +744,7 @@ public class SearchSourceBuilderTests extends AbstractSearchTestCase {
         searchSourceBuilder.fetchField("field");
         // these are not correct runtime mappings but they are counted compared to empty object
         searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword"));
-        searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5)));
+        searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5, null)));
         searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder("pitid"));
         searchSourceBuilder.docValueField("field");
         searchSourceBuilder.storedField("field");

+ 77 - 40
server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

@@ -19,7 +19,6 @@ import org.elasticsearch.common.compress.CompressedXContent;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
@@ -38,6 +37,7 @@ import java.util.ArrayList;
 import java.util.List;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 
 abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
@@ -84,8 +84,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         int numCands = randomIntBetween(1, 1000);
 
         KnnVectorQueryBuilder queryBuilder = switch (elementType()) {
-            case BYTE -> new KnnVectorQueryBuilder(fieldName, byteVector, numCands);
-            case FLOAT -> new KnnVectorQueryBuilder(fieldName, vector, numCands);
+            case BYTE -> new KnnVectorQueryBuilder(fieldName, byteVector, numCands, randomBoolean() ? null : randomFloat());
+            case FLOAT -> new KnnVectorQueryBuilder(fieldName, vector, numCands, randomBoolean() ? null : randomFloat());
         };
 
         if (randomBoolean()) {
@@ -102,9 +102,19 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     @Override
     protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query query, SearchExecutionContext context) throws IOException {
-        switch (elementType()) {
-            case FLOAT -> assertTrue(query instanceof KnnFloatVectorQuery);
-            case BYTE -> assertTrue(query instanceof KnnByteVectorQuery);
+        if (queryBuilder.getVectorSimilarity() != null) {
+            assertTrue(query instanceof VectorSimilarityQuery);
+            Query knnQuery = ((VectorSimilarityQuery) query).getInnerKnnQuery();
+            assertThat(((VectorSimilarityQuery) query).getSimilarity(), equalTo(queryBuilder.getVectorSimilarity()));
+            switch (elementType()) {
+                case FLOAT -> assertTrue(knnQuery instanceof KnnFloatVectorQuery);
+                case BYTE -> assertTrue(knnQuery instanceof KnnByteVectorQuery);
+            }
+        } else {
+            switch (elementType()) {
+                case FLOAT -> assertTrue(query instanceof KnnFloatVectorQuery);
+                case BYTE -> assertTrue(query instanceof KnnByteVectorQuery);
+            }
         }
 
         BooleanQuery.Builder builder = new BooleanQuery.Builder();
@@ -118,19 +128,22 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
             case BYTE -> new KnnByteVectorQuery(VECTOR_FIELD, queryBuilder.getByteQueryVector(), queryBuilder.numCands(), filterQuery);
             case FLOAT -> new KnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector(), queryBuilder.numCands(), filterQuery);
         };
+        if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
+            query = vectorSimilarityQuery.getInnerKnnQuery();
+        }
         assertEquals(query, knnVectorQueryBuilt);
     }
 
     public void testWrongDimension() {
         SearchExecutionContext context = createSearchExecutionContext();
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
         assertThat(e.getMessage(), containsString("the query vector has a different dimension [2] than the index vectors [3]"));
     }
 
     public void testNonexistentField() {
         SearchExecutionContext context = createSearchExecutionContext();
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 10);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
         assertThat(e.getMessage(), containsString("field [nonexistent] does not exist in the mapping"));
     }
@@ -140,7 +153,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(
             AbstractBuilderTestCase.KEYWORD_FIELD_NAME,
             new float[] { 1.0f, 1.0f, 1.0f },
-            10
+            10,
+            null
         );
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
         assertThat(e.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
@@ -148,7 +162,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     @Override
     public void testValidOutput() {
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 10);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 10, null);
         String expected = """
             {
               "knn" : {
@@ -169,7 +183,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         SearchExecutionContext context = createSearchExecutionContext();
         context.setAllowUnmappedFields(true);
         TermQueryBuilder termQuery = new TermQueryBuilder("unmapped_field", 42);
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION, null);
         query.addFilterQuery(termQuery);
 
         IllegalStateException e = expectThrows(IllegalStateException.class, () -> query.toQuery(context));
@@ -179,7 +193,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         assertThat(rewrittenQuery, instanceOf(MatchNoneQueryBuilder.class));
     }
 
-    public void testBWCVersionSerialization() throws IOException {
+    public void testBWCVersionSerializationFilters() throws IOException {
         float[] bwcFloat = new float[VECTOR_DIMENSION];
         KnnVectorQueryBuilder query = createTestQueryBuilder();
         if (query.queryVector() != null) {
@@ -189,47 +203,70 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
                 bwcFloat[i] = query.getByteQueryVector()[i];
             }
         }
-        KnnVectorQueryBuilder queryWithNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands()).queryName(
+
+        KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands(), null).queryName(
             query.queryName()
         ).boost(query.boost());
 
-        KnnVectorQueryBuilder queryNoByteQuery = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands()).queryName(
-            query.queryName()
-        ).boost(query.boost()).addFilterQueries(query.filterQueries());
+        TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
+            random(),
+            TransportVersion.V_8_0_0,
+            TransportVersion.V_8_1_0
+        );
+
+        assertBWCSerialization(query, queryNoFilters, beforeFilterVersion);
+    }
 
-        TransportVersion newVersion = TransportVersionUtils.randomVersionBetween(
+    public void testBWCVersionSerializationSimilarity() throws IOException {
+        KnnVectorQueryBuilder query = createTestQueryBuilder();
+        // Expected result after a round trip through an older wire version: same query, but with the similarity dropped (null).
+        KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(
+            query.getFieldName(),
+            query.getByteQueryVector(),
+            query.queryVector(),
+            query.numCands(),
+            null
+        ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
+        // The upper bound must be a version that does NOT know about similarity. Like the sibling BWC tests (filters: 8.0.0-8.1.0,
+        // byte query: 8.2.0-8.6.0), stop at the last version before the feature; including 8.8.0 would let the similarity
+        // survive the round trip and make the equality check below fail intermittently.
+        TransportVersion beforeSimilarity = TransportVersionUtils.randomVersionBetween(
             random(),
             TransportVersion.V_8_7_0,
-            TransportVersion.CURRENT
+            TransportVersion.V_8_7_0
         );
+        assertBWCSerialization(query, queryNoSimilarity, beforeSimilarity);
+    }
+
+    public void testBWCVersionSerializationByteQuery() throws IOException {
+        float[] bwcFloat = new float[VECTOR_DIMENSION];
+        KnnVectorQueryBuilder query = createTestQueryBuilder();
+        if (query.queryVector() != null) {
+            bwcFloat = query.queryVector();
+        } else {
+            for (int i = 0; i < query.getByteQueryVector().length; i++) {
+                bwcFloat[i] = query.getByteQueryVector()[i];
+            }
+        }
+        KnnVectorQueryBuilder queryNoByteQuery = new KnnVectorQueryBuilder(query.getFieldName(), bwcFloat, query.numCands(), null)
+            .queryName(query.queryName())
+            .boost(query.boost())
+            .addFilterQueries(query.filterQueries());
+
         TransportVersion beforeByteQueryVersion = TransportVersionUtils.randomVersionBetween(
             random(),
             TransportVersion.V_8_2_0,
             TransportVersion.V_8_6_0
         );
-        TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
-            random(),
-            TransportVersion.V_8_0_0,
-            TransportVersion.V_8_1_0
-        );
+        assertBWCSerialization(query, queryNoByteQuery, beforeByteQueryVersion);
+    }
 
-        assertSerialization(query, newVersion);
-        assertSerialization(queryNoByteQuery, beforeByteQueryVersion);
-        assertSerialization(queryWithNoFilters, beforeFilterVersion);
-
-        for (var tuple : List.of(
-            Tuple.tuple(beforeByteQueryVersion, queryNoByteQuery),
-            Tuple.tuple(beforeFilterVersion, queryWithNoFilters)
-        )) {
-            try (BytesStreamOutput output = new BytesStreamOutput()) {
-                output.setTransportVersion(tuple.v1());
-                output.writeNamedWriteable(query);
-                try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
-                    in.setTransportVersion(tuple.v1());
-                    KnnVectorQueryBuilder deserializedQuery = (KnnVectorQueryBuilder) in.readNamedWriteable(QueryBuilder.class);
-                    assertEquals(tuple.v2(), deserializedQuery);
-                    assertEquals(tuple.v2().hashCode(), deserializedQuery.hashCode());
-                }
+    private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery, TransportVersion version) throws IOException {
+        assertSerialization(bwcQuery, version);
+        try (BytesStreamOutput output = new BytesStreamOutput()) {
+            output.setTransportVersion(version);
+            output.writeNamedWriteable(newQuery);
+            try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry())) {
+                in.setTransportVersion(version);
+                KnnVectorQueryBuilder deserializedQuery = (KnnVectorQueryBuilder) in.readNamedWriteable(QueryBuilder.class);
+                assertEquals(bwcQuery, deserializedQuery);
+                assertEquals(bwcQuery.hashCode(), deserializedQuery.hashCode());
             }
         }
     }

+ 36 - 19
server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

@@ -51,7 +51,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
         int k = randomIntBetween(1, 100);
         int numCands = randomIntBetween(k + 20, 1000);
 
-        KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands);
+        KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat());
         if (randomBoolean()) {
             builder.boost(randomFloat());
         }
@@ -98,27 +98,41 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
 
     @Override
     protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) {
-        switch (random().nextInt(6)) {
-
+        switch (random().nextInt(7)) {
             case 0:
                 String newField = randomValueOtherThan(instance.field, () -> randomAlphaOfLength(5));
-                return new KnnSearchBuilder(newField, instance.queryVector, instance.k, instance.numCands + 3).boost(instance.boost);
+                return new KnnSearchBuilder(newField, instance.queryVector, instance.k, instance.numCands + 3, instance.similarity).boost(
+                    instance.boost
+                );
             case 1:
                 float[] newVector = randomValueOtherThan(instance.queryVector, () -> randomVector(5));
-                return new KnnSearchBuilder(instance.field, newVector, instance.k + 3, instance.numCands).boost(instance.boost);
+                return new KnnSearchBuilder(instance.field, newVector, instance.k + 3, instance.numCands, instance.similarity).boost(
+                    instance.boost
+                );
             case 2:
-                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k + 3, instance.numCands).boost(instance.boost);
+                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k + 3, instance.numCands, instance.similarity)
+                    .boost(instance.boost);
             case 3:
-                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands + 3).boost(instance.boost);
+                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands + 3, instance.similarity)
+                    .boost(instance.boost);
             case 4:
-                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands).addFilterQueries(
-                    instance.filterQueries
-                ).addFilterQuery(QueryBuilders.termQuery("new_field", "new-value")).boost(instance.boost);
+                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity)
+                    .addFilterQueries(instance.filterQueries)
+                    .addFilterQuery(QueryBuilders.termQuery("new_field", "new-value"))
+                    .boost(instance.boost);
             case 5:
                 float newBoost = randomValueOtherThan(instance.boost, ESTestCase::randomFloat);
-                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands).addFilterQueries(
-                    instance.filterQueries
-                ).boost(newBoost);
+                return new KnnSearchBuilder(instance.field, instance.queryVector, instance.k, instance.numCands, instance.similarity)
+                    .addFilterQueries(instance.filterQueries)
+                    .boost(newBoost);
+            case 6:
+                return new KnnSearchBuilder(
+                    instance.field,
+                    instance.queryVector,
+                    instance.k,
+                    instance.numCands,
+                    randomValueOtherThan(instance.similarity, ESTestCase::randomFloat)
+                ).addFilterQueries(instance.filterQueries).boost(instance.boost);
             default:
                 throw new IllegalStateException();
         }
@@ -129,7 +143,8 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
         float[] vector = randomVector(randomIntBetween(2, 30));
         int k = randomIntBetween(1, 100);
         int numCands = randomIntBetween(k, 1000);
-        KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands);
+        Float similarity = randomBoolean() ? null : randomFloat();
+        KnnSearchBuilder builder = new KnnSearchBuilder(field, vector, k, numCands, similarity);
 
         float boost = AbstractQueryBuilder.DEFAULT_BOOST;
         if (randomBoolean()) {
@@ -145,14 +160,14 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
             builder.addFilterQuery(filter);
         }
 
-        QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, numCands).addFilterQueries(filterQueries).boost(boost);
+        QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, numCands, similarity).addFilterQueries(filterQueries).boost(boost);
         assertEquals(expected, builder.toQueryBuilder());
     }
 
     public void testNumCandsLessThanK() {
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> new KnnSearchBuilder("field", randomVector(3), 50, 10)
+            () -> new KnnSearchBuilder("field", randomVector(3), 50, 10, null)
         );
         assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
     }
@@ -160,7 +175,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
     public void testNumCandsExceedsLimit() {
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002)
+            () -> new KnnSearchBuilder("field", randomVector(3), 100, 10002, null)
         );
         assertThat(e.getMessage(), containsString("[num_candidates] cannot exceed [10000]"));
     }
@@ -168,7 +183,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
     public void testInvalidK() {
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> new KnnSearchBuilder("field", randomVector(3), 0, 100)
+            () -> new KnnSearchBuilder("field", randomVector(3), 0, 100, null)
         );
         assertThat(e.getMessage(), containsString("[k] must be greater than 0"));
     }
@@ -179,7 +194,8 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
             "field",
             new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray),
             5,
-            10
+            10,
+            1f
         );
         searchBuilder.boost(randomFloat());
         searchBuilder.addFilterQueries(List.of(new RewriteableQuery()));
@@ -194,6 +210,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
         assertThat(rewritten.queryVector, equalTo(expectedArray));
         assertThat(rewritten.queryVectorBuilder, nullValue());
         assertThat(rewritten.filterQueries, hasSize(1));
+        assertThat(rewritten.similarity, equalTo(1f));
         assertThat(((RewriteableQuery) rewritten.filterQueries.get(0)).rewrites, equalTo(1));
     }
 

+ 221 - 0
server/src/test/java/org/elasticsearch/search/vectors/VectorSimilarityQueryTests.java

@@ -0,0 +1,221 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.elasticsearch.common.lucene.LuceneTests;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+
+public class VectorSimilarityQueryTests extends ESTestCase {
+
+    public void testSimpleEuclidean() throws Exception { // indexes 5 fixed vectors and checks hit counts for two similarity cut-offs
+        try (Directory d = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+                Document document = new Document();
+                KnnFloatVectorField vectorField = new KnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
+                document.add(vectorField);
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 1, 1 }); // the same Document/field instance is reused; only the value changes
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 2, 1 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 1, 2 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 2, 2 });
+                w.addDocument(document);
+
+                w.commit();
+            }
+            try (IndexReader reader = DirectoryReader.open(d)) {
+                IndexSearcher searcher = LuceneTests.newSearcher(reader);
+                // Should match all 5 docs: the farthest vector, { 2, 2, 2 }, has a squared distance of 3 from the query.
+                // 0.25f == 1 / (1 + 3); NOTE(review): the third argument looks like the score equivalent of the cut-off - confirm.
+                TopDocs docs = searcher.search(
+                    new VectorSimilarityQuery(new KnnFloatVectorQuery("float_vector", new float[] { 1, 1, 1 }, 5), 3f, 0.25f),
+                    5
+                );
+                assertThat(docs.totalHits.value, equalTo(5L));
+
+                // Should match only 4: with the cut-off at 1 (score 0.5f == 1 / (1 + 1)), { 2, 2, 2 } is dropped
+                docs = searcher.search(
+                    new VectorSimilarityQuery(new KnnFloatVectorQuery("float_vector", new float[] { 1, 1, 1 }, 5), 1f, 0.5f),
+                    5
+                );
+                assertThat(docs.totalHits.value, equalTo(4L));
+            }
+        }
+    }
+
+    public void testEuclideanInvariant() throws Exception { // every hit returned for a random query must lie within the radius
+        int dim = 4;
+        int L = 1; // NOTE(review): presumably the per-dimension value range of the random vectors - confirm
+        int n = 100;
+        String fieldName = "vector";
+        Supplier<float[]> vectorValue = () -> new float[] { randomFloat(), randomFloat(), randomFloat(), randomFloat() };
+        try (Directory d = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+                KnnFloatVectorField vectorField = new KnnFloatVectorField(fieldName, vectorValue.get());
+                Document document = new Document();
+                document.add(vectorField);
+                for (int i = 0; i < n; i++) {
+                    w.addDocument(document); // indexes the current value first...
+                    vectorField.setVectorValue(vectorValue.get()); // ...then draws the next one, so the final draw is never indexed
+                }
+                w.commit();
+            }
+            try (IndexReader reader = DirectoryReader.open(d)) {
+                float radius = 0.25f * (float) Math.sqrt(dim) * L; // 0.25 * sqrt(4) * 1 == 0.5
+                float score = 1f / (1f + radius); // score that the radius maps to
+                IndexSearcher searcher = LuceneTests.newSearcher(reader);
+                for (int i = 0; i < 10; i++) { // 10 independent random query vectors
+                    TopDocs docs = searcher.search(
+                        new VectorSimilarityQuery(new KnnFloatVectorQuery(fieldName, vectorValue.get(), n), radius, score),
+                        n
+                    );
+                    for (ScoreDoc scoreDoc : docs.scoreDocs) {
+                        float dist = (1 / scoreDoc.score) - 1; // inverts score = 1 / (1 + dist)
+                        assertThat(dist, lessThanOrEqualTo(radius));
+                    }
+                }
+            }
+        }
+    }
+
+    public void testSimpleCosine() throws IOException { // same shape as testSimpleEuclidean, but the field is indexed with COSINE
+        try (Directory d = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+                Document document = new Document();
+                KnnFloatVectorField vectorField = new KnnFloatVectorField(
+                    "float_vector",
+                    new float[] { 1, 1, 1 },
+                    VectorSimilarityFunction.COSINE
+                );
+                document.add(vectorField);
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 1, 1 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 2, 1 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 1, 2 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 0, 2 }); // the outlier: cosine with { 1, 1, 1 } is 4 / (sqrt(3) * sqrt(8)) ~= 0.82
+                w.addDocument(document);
+
+                w.commit();
+            }
+            try (IndexReader reader = DirectoryReader.open(d)) {
+                IndexSearcher searcher = LuceneTests.newSearcher(reader);
+                // Should match all 5 docs: the worst cosine similarity is ~0.82 (for { 2, 0, 2 }), which is above the 0.8 cut-off
+                TopDocs docs = searcher.search(
+                    new VectorSimilarityQuery(new KnnFloatVectorQuery("float_vector", new float[] { 1, 1, 1 }, 5), .8f, .9f),
+                    5
+                );
+                assertThat(docs.totalHits.value, equalTo(5L));
+
+                // Should match only 4: { 2, 0, 2 } (~0.82) falls below the 0.9 cut-off; the others are ~0.94 or 1.0
+                docs = searcher.search(
+                    new VectorSimilarityQuery(new KnnFloatVectorQuery("float_vector", new float[] { 1, 1, 1 }, 5), .9f, 0.95f),
+                    5
+                );
+                assertThat(docs.totalHits.value, equalTo(4L));
+            }
+        }
+    }
+
+    public void testCosineInvariant() throws Exception { // every hit returned for a random query must score at or above the cut-off
+        int dim = 4;
+        int L = 1; // NOTE(review): presumably the per-dimension value range of the random vectors - confirm
+        int n = 100;
+        String fieldName = "vector";
+        Supplier<float[]> vectorValue = () -> new float[] { randomFloat(), randomFloat(), randomFloat(), randomFloat() };
+        try (Directory d = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+                KnnFloatVectorField vectorField = new KnnFloatVectorField(fieldName, vectorValue.get(), VectorSimilarityFunction.COSINE);
+                Document document = new Document();
+                document.add(vectorField);
+                for (int i = 0; i < n; i++) {
+                    w.addDocument(document); // indexes the current value first...
+                    vectorField.setVectorValue(vectorValue.get()); // ...then draws the next one, so the final draw is never indexed
+                }
+                w.commit();
+            }
+            try (IndexReader reader = DirectoryReader.open(d)) {
+                float radius = 0.25f * (float) Math.sqrt(dim) * L; // 0.25 * sqrt(4) * 1 == 0.5; passed below as the similarity cut-off
+                float cos = (float) Math.cos(radius);
+                float score = (1 + radius) / 2f; // score that the cut-off maps to: (1 + similarity) / 2
+                IndexSearcher searcher = LuceneTests.newSearcher(reader);
+                for (int i = 0; i < 10; i++) { // 10 independent random query vectors
+                    TopDocs docs = searcher.search(
+                        new VectorSimilarityQuery(new KnnFloatVectorQuery(fieldName, vectorValue.get(), n), radius, score),
+                        n
+                    );
+                    for (ScoreDoc scoreDoc : docs.scoreDocs) {
+                        float dist = (2 * scoreDoc.score) - 1; // inverts score = (1 + similarity) / 2
+                        float distCos = (float) Math.cos(dist); // NOTE(review): cos() of the recovered similarity itself - confirm intent
+                        assertThat(scoreDoc.score, greaterThanOrEqualTo(score));
+                        assertThat(distCos, lessThanOrEqualTo(cos)); // follows from dist >= radius only while cos is decreasing on that range
+                    }
+                }
+            }
+        }
+    }
+
+    public void testExplain() throws IOException { // explain() must report a match only for docs inside the similarity cut-off
+        try (Directory d = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { // NOTE(review): no explicit commit here; presumably relies on close - confirm
+                Document document = new Document();
+                KnnFloatVectorField vectorField = new KnnFloatVectorField("float_vector", new float[] { 1, 1, 1 });
+                document.add(vectorField);
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 1, 1 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 2, 1 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 1, 1, 2 });
+                w.addDocument(document);
+                vectorField.setVectorValue(new float[] { 2, 2, 2 });
+                w.addDocument(document);
+            }
+            try (IndexReader reader = DirectoryReader.open(d)) {
+                IndexSearcher searcher = LuceneTests.newSearcher(reader);
+                Query q = searcher.rewrite(
+                    new VectorSimilarityQuery(new KnnFloatVectorQuery("float_vector", new float[] { 1, 1, 1 }, 5), 1f, 0.5f)
+                );
+                Weight w = q.createWeight(searcher, ScoreMode.COMPLETE, 1.0f);
+                Explanation ex = w.explain(searcher.getIndexReader().leaves().get(0), 1); // assumes one segment with docIDs in insertion order: doc 1 is { 2, 1, 1 }
+                assertTrue(ex.isMatch());
+
+                ex = w.explain(searcher.getIndexReader().leaves().get(0), 4); // doc 4 is { 2, 2, 2 } under the same assumption; squared distance 3 > 1
+                assertFalse(ex.isMatch());
+            }
+        }
+    }
+
+}

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java

@@ -259,7 +259,7 @@ public class RandomSearchRequestGenerator {
                 }
                 int k = randomIntBetween(1, 100);
                 int numCands = randomIntBetween(k, 1000);
-                knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands));
+                knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands, randomBoolean() ? null : randomFloat()));
             }
             builder.knnSearch(knnSearchBuilders);
         }

+ 15 - 3
test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java

@@ -70,7 +70,7 @@ public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBu
     public final void testKnnSearchBuilderXContent() throws Exception {
         AbstractXContentTestCase.XContentTester<KnnSearchBuilder> tester = AbstractXContentTestCase.xContentTester(
             this::createParser,
-            () -> new KnnSearchBuilder(randomAlphaOfLength(10), createTestInstance(), 5, 10),
+            () -> new KnnSearchBuilder(randomAlphaOfLength(10), createTestInstance(), 5, 10, randomBoolean() ? null : randomFloat()),
             getToXContentParams(),
             KnnSearchBuilder::fromXContent
         );
@@ -79,7 +79,13 @@ public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBu
 
     public final void testKnnSearchBuilderWireSerialization() throws IOException {
         for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
-            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(randomAlphaOfLength(10), createTestInstance(), 5, 10);
+            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(
+                randomAlphaOfLength(10),
+                createTestInstance(),
+                5,
+                10,
+                randomBoolean() ? null : randomFloat()
+            );
             KnnSearchBuilder serialized = copyWriteable(
                 searchBuilder,
                 getNamedWriteableRegistry(),
@@ -95,7 +101,13 @@ public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBu
         for (int i = 0; i < NUMBER_OF_TEST_RUNS; i++) {
             float[] expected = randomVector(randomIntBetween(10, 1024));
             T queryVectorBuilder = createTestInstance(expected);
-            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(randomAlphaOfLength(10), queryVectorBuilder, 5, 10);
+            KnnSearchBuilder searchBuilder = new KnnSearchBuilder(
+                randomAlphaOfLength(10),
+                queryVectorBuilder,
+                5,
+                10,
+                randomBoolean() ? null : randomFloat()
+            );
             KnnSearchBuilder serialized = copyWriteable(
                 searchBuilder,
                 getNamedWriteableRegistry(),

+ 1 - 1
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java

@@ -893,7 +893,7 @@ public class DocumentLevelSecurityTests extends SecurityIntegTestCase {
         // Since there's no kNN search action at the transport layer, we just emulate
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, null);
 
         if (randomBoolean()) {
             query.addFilterQuery(new WildcardQueryBuilder("other", "value*"));

+ 3 - 3
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java

@@ -400,7 +400,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
         // Since there's no kNN search action at the transport layer, we just emulate
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, null);
 
         // user1 has access to vector field, so the query should match with the document:
         SearchResponse response = client().filterWithHeader(
@@ -426,7 +426,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
         assertNull(response.getHits().getAt(0).field("vector"));
 
         // user1 can access field1, so the filtered query should match with the document:
-        KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
+        KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, null).addFilterQuery(
             QueryBuilders.matchQuery("field1", "value1")
         );
         response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))
@@ -436,7 +436,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
         assertHitCount(response, 1);
 
         // user1 cannot access field2, so the filtered query should not match with the document:
-        KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10).addFilterQuery(
+        KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, null).addFilterQuery(
             QueryBuilders.matchQuery("field2", "value2")
         );
         response = client().filterWithHeader(Collections.singletonMap(BASIC_AUTH_HEADER, basicAuthHeaderValue("user1", USERS_PASSWD)))