Browse Source

[ML] Hybrid retrieval for Semantic search. (#91348)

Adds the query option to the _semantic_search endpoint for hybrid retrieval. 
Scoring is controlled by the boost fields of the knn search and the query.
David Kyle 3 years ago
parent
commit
b46ee9caaa

+ 56 - 21
docs/reference/search/semantic-search.asciidoc

@@ -10,25 +10,6 @@ The resulting dense vector is then used in a <<knn-search,k-nearest neighbor (kn
 created with the same text embedding model. The search results are semantically similar as learned
 created with the same text embedding model. The search results are semantically similar as learned
 by the model.
 by the model.
 
 
-////
-[source,console]
-----
-PUT my-index
-{
-  "mappings": {
-    "properties": {
-      "text_embedding": {
-        "type": "dense_vector",
-        "dims": 512,
-        "index": true,
-        "similarity": "cosine"
-      }
-    }
-  }
-}
-----
-////
-
 [source,console]
 [source,console]
 ----
 ----
 GET my-index/_semantic_search
 GET my-index/_semantic_search
@@ -110,15 +91,69 @@ value must be less than `num_candidates`.
 shard. Cannot exceed 10,000. {es} collects `num_candidates` results from each
 shard. Cannot exceed 10,000. {es} collects `num_candidates` results from each
 shard, then merges them to find the top `k` results. Increasing
 shard, then merges them to find the top `k` results. Increasing
 `num_candidates` tends to improve the accuracy of the final `k` results.
 `num_candidates` tends to improve the accuracy of the final `k` results.
-====
 
 
 `filter`::
 `filter`::
 (Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
 (Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
 can match. The kNN search will return the top `k` documents that also match
 can match. The kNN search will return the top `k` documents that also match
 this filter. The value can be a single query or a list of queries. If `filter`
 this filter. The value can be a single query or a list of queries. If `filter`
 is not provided, all documents are allowed to match.
 is not provided, all documents are allowed to match.
+====
 
 
+`query`::
+(Optional, <<query-dsl,query object>>) Defines the search definition using the
+<<query-dsl,Query DSL>>.
 
 
+`text_embedding_config`::
+(Object, optional) Override certain setting of the text embedding model's configuration
+.Properties of text_embedding inference
+[%collapsible%open]
+=====
+`results_field`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+`roberta`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
++
+.Properties of roberta
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+======
+=====
 
 
 include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
 include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
 include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
 include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
@@ -129,5 +164,5 @@ include::{es-repo-dir}/search/search.asciidoc[tag=stored-fields-def]
 [[semantic-search-api-response-body]]
 [[semantic-search-api-response-body]]
 ==== {api-response-body-title}
 ==== {api-response-body-title}
 
 
-A sementic search response has the same structure as a kNN search response.
+The semantic search response has the same structure as a kNN search response.
 
 

+ 54 - 36
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchAction.java

@@ -57,6 +57,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
     public static class Request extends ActionRequest implements IndicesRequest.Replaceable {
     public static class Request extends ActionRequest implements IndicesRequest.Replaceable {
 
 
         public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
         public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
+        public static final ParseField TEXT_EMBEDDING_CONFIG = new ParseField("text_embedding_config");
 
 
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);
 
 
@@ -67,15 +68,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             PARSER.declareObject(
             PARSER.declareObject(
                 Request.Builder::setUpdate,
                 Request.Builder::setUpdate,
                 (p, c) -> TextEmbeddingConfigUpdate.fromXContentStrict(p),
                 (p, c) -> TextEmbeddingConfigUpdate.fromXContentStrict(p),
-                InferTrainedModelDeploymentAction.Request.INFERENCE_CONFIG
+                TEXT_EMBEDDING_CONFIG
             );
             );
-            PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
-            PARSER.declareFieldArray(
-                Request.Builder::setFilters,
+            PARSER.declareObject(
+                Request.Builder::setQueryBuilder,
                 (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
                 (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
-                KnnSearchBuilder.FILTER_FIELD,
-                ObjectParser.ValueType.OBJECT_ARRAY
+                SearchSourceBuilder.QUERY_FIELD
             );
             );
+            PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
             PARSER.declareField(
             PARSER.declareField(
                 (p, request, c) -> request.setFetchSource(FetchSourceContext.fromXContent(p)),
                 (p, request, c) -> request.setFetchSource(FetchSourceContext.fromXContent(p)),
                 SearchSourceBuilder._SOURCE_FIELD,
                 SearchSourceBuilder._SOURCE_FIELD,
@@ -99,16 +99,21 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 SearchSourceBuilder.STORED_FIELDS_FIELD,
                 SearchSourceBuilder.STORED_FIELDS_FIELD,
                 ObjectParser.ValueType.STRING_ARRAY
                 ObjectParser.ValueType.STRING_ARRAY
             );
             );
+            PARSER.declareInt(Request.Builder::setSize, SearchSourceBuilder.SIZE_FIELD);
         }
         }
 
 
         public static Request parseRestRequest(RestRequest restRequest) throws IOException {
         public static Request parseRestRequest(RestRequest restRequest) throws IOException {
             Builder builder = new Builder(Strings.splitStringByCommaToArray(restRequest.param("index")));
             Builder builder = new Builder(Strings.splitStringByCommaToArray(restRequest.param("index")));
-            builder.setRouting(restRequest.param("routing"));
             if (restRequest.hasContentOrSourceParam()) {
             if (restRequest.hasContentOrSourceParam()) {
                 try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
                 try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
                     PARSER.parse(contentParser, builder, null);
                     PARSER.parse(contentParser, builder, null);
                 }
                 }
             }
             }
+            // Query parameters are preferred to body parameters.
+            if (restRequest.hasParam("size")) {
+                builder.setSize(restRequest.paramAsInt("size", -1));
+            }
+            builder.setRouting(restRequest.param("routing"));
             return builder.build();
             return builder.build();
         }
         }
 
 
@@ -117,13 +122,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         private final String queryString;
         private final String queryString;
         private final String modelId;
         private final String modelId;
         private final TimeValue inferenceTimeout;
         private final TimeValue inferenceTimeout;
+        private final QueryBuilder query;
         private final KnnQueryOptions knnQueryOptions;
         private final KnnQueryOptions knnQueryOptions;
         private final TextEmbeddingConfigUpdate embeddingConfig;
         private final TextEmbeddingConfigUpdate embeddingConfig;
-        private final List<QueryBuilder> filters;
         private final FetchSourceContext fetchSource;
         private final FetchSourceContext fetchSource;
         private final List<FieldAndFormat> fields;
         private final List<FieldAndFormat> fields;
         private final List<FieldAndFormat> docValueFields;
         private final List<FieldAndFormat> docValueFields;
         private final StoredFieldsContext storedFields;
         private final StoredFieldsContext storedFields;
+        private final int size;
 
 
         public Request(StreamInput in) throws IOException {
         public Request(StreamInput in) throws IOException {
             super(in);
             super(in);
@@ -132,17 +138,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             queryString = in.readString();
             queryString = in.readString();
             modelId = in.readString();
             modelId = in.readString();
             inferenceTimeout = in.readOptionalTimeValue();
             inferenceTimeout = in.readOptionalTimeValue();
+            query = in.readOptionalNamedWriteable(QueryBuilder.class);
             knnQueryOptions = new KnnQueryOptions(in);
             knnQueryOptions = new KnnQueryOptions(in);
             embeddingConfig = in.readOptionalWriteable(TextEmbeddingConfigUpdate::new);
             embeddingConfig = in.readOptionalWriteable(TextEmbeddingConfigUpdate::new);
-            if (in.readBoolean()) {
-                filters = in.readNamedWriteableList(QueryBuilder.class);
-            } else {
-                filters = null;
-            }
             fetchSource = in.readOptionalWriteable(FetchSourceContext::readFrom);
             fetchSource = in.readOptionalWriteable(FetchSourceContext::readFrom);
             fields = in.readOptionalList(FieldAndFormat::new);
             fields = in.readOptionalList(FieldAndFormat::new);
             docValueFields = in.readOptionalList(FieldAndFormat::new);
             docValueFields = in.readOptionalList(FieldAndFormat::new);
             storedFields = in.readOptionalWriteable(StoredFieldsContext::new);
             storedFields = in.readOptionalWriteable(StoredFieldsContext::new);
+            size = in.readInt();
         }
         }
 
 
         Request(
         Request(
@@ -150,27 +153,29 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             String routing,
             String routing,
             String queryString,
             String queryString,
             String modelId,
             String modelId,
+            QueryBuilder query,
             KnnQueryOptions knnQueryOptions,
             KnnQueryOptions knnQueryOptions,
             TextEmbeddingConfigUpdate embeddingConfig,
             TextEmbeddingConfigUpdate embeddingConfig,
             TimeValue inferenceTimeout,
             TimeValue inferenceTimeout,
-            List<QueryBuilder> filters,
             FetchSourceContext fetchSource,
             FetchSourceContext fetchSource,
             List<FieldAndFormat> fields,
             List<FieldAndFormat> fields,
             List<FieldAndFormat> docValueFields,
             List<FieldAndFormat> docValueFields,
-            StoredFieldsContext storedFields
+            StoredFieldsContext storedFields,
+            int size
         ) {
         ) {
             this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
             this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
             this.routing = routing;
             this.routing = routing;
             this.queryString = queryString;
             this.queryString = queryString;
             this.modelId = modelId;
             this.modelId = modelId;
+            this.query = query;
             this.knnQueryOptions = knnQueryOptions;
             this.knnQueryOptions = knnQueryOptions;
             this.embeddingConfig = embeddingConfig;
             this.embeddingConfig = embeddingConfig;
             this.inferenceTimeout = inferenceTimeout;
             this.inferenceTimeout = inferenceTimeout;
-            this.filters = filters;
             this.fetchSource = fetchSource;
             this.fetchSource = fetchSource;
             this.fields = fields;
             this.fields = fields;
             this.docValueFields = docValueFields;
             this.docValueFields = docValueFields;
             this.storedFields = storedFields;
             this.storedFields = storedFields;
+            this.size = size;
         }
         }
 
 
         @Override
         @Override
@@ -181,18 +186,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             out.writeString(queryString);
             out.writeString(queryString);
             out.writeString(modelId);
             out.writeString(modelId);
             out.writeOptionalTimeValue(inferenceTimeout);
             out.writeOptionalTimeValue(inferenceTimeout);
+            out.writeOptionalNamedWriteable(query);
             knnQueryOptions.writeTo(out);
             knnQueryOptions.writeTo(out);
             out.writeOptionalWriteable(embeddingConfig);
             out.writeOptionalWriteable(embeddingConfig);
-            if (filters != null) {
-                out.writeBoolean(true);
-                out.writeNamedWriteableList(filters);
-            } else {
-                out.writeBoolean(false);
-            }
             out.writeOptionalWriteable(fetchSource);
             out.writeOptionalWriteable(fetchSource);
             out.writeOptionalCollection(fields);
             out.writeOptionalCollection(fields);
             out.writeOptionalCollection(docValueFields);
             out.writeOptionalCollection(docValueFields);
             out.writeOptionalWriteable(storedFields);
             out.writeOptionalWriteable(storedFields);
+            out.writeInt(size);
         }
         }
 
 
         @Override
         @Override
@@ -231,6 +232,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return inferenceTimeout;
             return inferenceTimeout;
         }
         }
 
 
+        public QueryBuilder getQuery() {
+            return query;
+        }
+
         public KnnQueryOptions getKnnQueryOptions() {
         public KnnQueryOptions getKnnQueryOptions() {
             return knnQueryOptions;
             return knnQueryOptions;
         }
         }
@@ -239,10 +244,6 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return embeddingConfig;
             return embeddingConfig;
         }
         }
 
 
-        public List<QueryBuilder> getFilters() {
-            return filters;
-        }
-
         public FetchSourceContext getFetchSource() {
         public FetchSourceContext getFetchSource() {
             return fetchSource;
             return fetchSource;
         }
         }
@@ -259,6 +260,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return storedFields;
             return storedFields;
         }
         }
 
 
+        public int getSize() {
+            return size;
+        }
+
         @Override
         @Override
         public boolean equals(Object o) {
         public boolean equals(Object o) {
             if (this == o) return true;
             if (this == o) return true;
@@ -269,13 +274,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 && Objects.equals(queryString, request.queryString)
                 && Objects.equals(queryString, request.queryString)
                 && Objects.equals(modelId, request.modelId)
                 && Objects.equals(modelId, request.modelId)
                 && Objects.equals(inferenceTimeout, request.inferenceTimeout)
                 && Objects.equals(inferenceTimeout, request.inferenceTimeout)
+                && Objects.equals(query, request.query)
                 && Objects.equals(knnQueryOptions, request.knnQueryOptions)
                 && Objects.equals(knnQueryOptions, request.knnQueryOptions)
                 && Objects.equals(embeddingConfig, request.embeddingConfig)
                 && Objects.equals(embeddingConfig, request.embeddingConfig)
-                && Objects.equals(filters, request.filters)
                 && Objects.equals(fetchSource, request.fetchSource)
                 && Objects.equals(fetchSource, request.fetchSource)
                 && Objects.equals(fields, request.fields)
                 && Objects.equals(fields, request.fields)
                 && Objects.equals(docValueFields, request.docValueFields)
                 && Objects.equals(docValueFields, request.docValueFields)
-                && Objects.equals(storedFields, request.storedFields);
+                && Objects.equals(storedFields, request.storedFields)
+                && size == request.size;
         }
         }
 
 
         @Override
         @Override
@@ -285,13 +291,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 queryString,
                 queryString,
                 modelId,
                 modelId,
                 inferenceTimeout,
                 inferenceTimeout,
+                query,
                 knnQueryOptions,
                 knnQueryOptions,
                 embeddingConfig,
                 embeddingConfig,
-                filters,
                 fetchSource,
                 fetchSource,
                 fields,
                 fields,
                 docValueFields,
                 docValueFields,
-                storedFields
+                storedFields,
+                size
             );
             );
             result = 31 * result + Arrays.hashCode(indices);
             result = 31 * result + Arrays.hashCode(indices);
             return result;
             return result;
@@ -321,12 +328,13 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             private String queryString;
             private String queryString;
             private TimeValue timeout;
             private TimeValue timeout;
             private TextEmbeddingConfigUpdate update;
             private TextEmbeddingConfigUpdate update;
+            private QueryBuilder queryBuilder;
             private KnnQueryOptions knnSearchBuilder;
             private KnnQueryOptions knnSearchBuilder;
-            private List<QueryBuilder> filters;
             private FetchSourceContext fetchSource;
             private FetchSourceContext fetchSource;
             private List<FieldAndFormat> fields;
             private List<FieldAndFormat> fields;
             private List<FieldAndFormat> docValueFields;
             private List<FieldAndFormat> docValueFields;
             private StoredFieldsContext storedFields;
             private StoredFieldsContext storedFields;
+            private int size = -1;
 
 
             Builder(String[] indices) {
             Builder(String[] indices) {
                 this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
                 this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
@@ -360,8 +368,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 this.knnSearchBuilder = knnSearchBuilder;
                 this.knnSearchBuilder = knnSearchBuilder;
             }
             }
 
 
-            private void setFilters(List<QueryBuilder> filters) {
-                this.filters = filters;
+            void setQueryBuilder(QueryBuilder queryBuilder) {
+                this.queryBuilder = queryBuilder;
             }
             }
 
 
             private void setFetchSource(FetchSourceContext fetchSource) {
             private void setFetchSource(FetchSourceContext fetchSource) {
@@ -380,20 +388,25 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 this.storedFields = storedFields;
                 this.storedFields = storedFields;
             }
             }
 
 
+            private void setSize(int size) {
+                this.size = size;
+            }
+
             Request build() {
             Request build() {
                 return new Request(
                 return new Request(
                     indices,
                     indices,
                     routing,
                     routing,
                     queryString,
                     queryString,
                     modelId,
                     modelId,
+                    queryBuilder,
                     knnSearchBuilder,
                     knnSearchBuilder,
                     update,
                     update,
                     timeout,
                     timeout,
-                    filters,
                     fetchSource,
                     fetchSource,
                     fields,
                     fields,
                     docValueFields,
                     docValueFields,
-                    storedFields
+                    storedFields,
+                    size
                 );
                 );
             }
             }
         }
         }
@@ -528,7 +541,12 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             if (queryVector == null) {
             if (queryVector == null) {
                 throw new IllegalStateException("[query_vector] not set on the Knn query");
                 throw new IllegalStateException("[query_vector] not set on the Knn query");
             }
             }
-            return new KnnSearchBuilder(field, queryVector, k, numCands);
+            var builder = new KnnSearchBuilder(field, queryVector, k, numCands);
+            builder.boost(boost);
+            if (filterQueries.isEmpty() == false) {
+                builder.addFilterQueries(filterQueries);
+            }
+            return builder;
         }
         }
 
 
         @Override
         @Override

+ 17 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchActionKnnQueryOptionsTests.java

@@ -19,6 +19,8 @@ import org.junit.Before;
 import java.util.List;
 import java.util.List;
 
 
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyList;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.sameInstance;
 
 
 public class SemanticSearchActionKnnQueryOptionsTests extends AbstractWireSerializingTestCase<SemanticSearchAction.KnnQueryOptions> {
 public class SemanticSearchActionKnnQueryOptionsTests extends AbstractWireSerializingTestCase<SemanticSearchAction.KnnQueryOptions> {
 
 
@@ -65,4 +67,19 @@ public class SemanticSearchActionKnnQueryOptionsTests extends AbstractWireSerial
     protected SemanticSearchAction.KnnQueryOptions createTestInstance() {
     protected SemanticSearchAction.KnnQueryOptions createTestInstance() {
         return randomInstance();
         return randomInstance();
     }
     }
+
+    public void testToKnnSearchBuilder() {
+        var knnOptions = new SemanticSearchAction.KnnQueryOptions("foo", 5, 100);
+        knnOptions.boost(20.0f);
+        var termsQuery = QueryBuilders.termQuery("foo", "bar");
+        knnOptions.addFilterQueries(List.of(termsQuery));
+
+        var knnSearch = knnOptions.toKnnSearchBuilder(new float[] { 0.1f, 0.2f });
+        assertEquals(5, knnSearch.k());
+        var knnQuery = knnSearch.toQueryBuilder();
+        assertEquals(100, knnQuery.numCands());
+        assertEquals("foo", knnQuery.getFieldName());
+        assertThat(knnQuery.filterQueries(), contains(sameInstance(termsQuery)));
+        assertEquals(20.0f, knnQuery.boost(), 0.001);
+    }
 }
 }

+ 5 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchActionRequestTests.java

@@ -67,14 +67,15 @@ public class SemanticSearchActionRequestTests extends AbstractWireSerializingTes
             randomBoolean() ? null : randomAlphaOfLength(5),
             randomBoolean() ? null : randomAlphaOfLength(5),
             randomAlphaOfLength(5),
             randomAlphaOfLength(5),
             randomAlphaOfLength(5),
             randomAlphaOfLength(5),
+            randomBoolean() ? null : new TermsQueryBuilder("foo", "bar"),
             SemanticSearchActionKnnQueryOptionsTests.randomInstance(),
             SemanticSearchActionKnnQueryOptionsTests.randomInstance(),
             TextEmbeddingConfigUpdateTests.randomUpdate(),
             TextEmbeddingConfigUpdateTests.randomUpdate(),
             randomBoolean() ? null : TimeValue.timeValueSeconds(randomIntBetween(1, 10)),
             randomBoolean() ? null : TimeValue.timeValueSeconds(randomIntBetween(1, 10)),
-            randomBoolean() ? null : List.of(new TermsQueryBuilder("foo", "bar", "cat")),
             randomBoolean() ? null : FetchSourceContext.of(randomBoolean()),
             randomBoolean() ? null : FetchSourceContext.of(randomBoolean()),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
-            randomBoolean() ? null : StoredFieldsContext.fromList(List.of("A", "B"))
+            randomBoolean() ? null : StoredFieldsContext.fromList(List.of("A", "B")),
+            randomBoolean() ? -1 : randomIntBetween(1, 10)
         );
         );
     }
     }
 
 
@@ -94,7 +95,8 @@ public class SemanticSearchActionRequestTests extends AbstractWireSerializingTes
             null,
             null,
             null,
             null,
             null,
             null,
-            null
+            null,
+            -1
         );
         );
         var validation = action.validate();
         var validation = action.validate();
         assertNotNull(validation);
         assertNotNull(validation);

+ 47 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -262,7 +262,7 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         return client().performRequest(request);
         return client().performRequest(request);
     }
     }
 
 
-    protected Response semanticSearch(String index, String query, String deploymentId, String denseVectorFieldName) throws IOException {
+    protected Response semanticSearch(String index, String queryText, String modelId, String denseVectorFieldName) throws IOException {
         Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
         Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
 
 
         request.setJsonEntity(String.format(Locale.ROOT, """
         request.setJsonEntity(String.format(Locale.ROOT, """
@@ -274,7 +274,52 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
                   "k": 5,
                   "k": 5,
                   "num_candidates": 10
                   "num_candidates": 10
               }
               }
-            }""", deploymentId, query, denseVectorFieldName));
+            }""", modelId, queryText, denseVectorFieldName));
+        return client().performRequest(request);
+    }
+
+    protected Response semanticSearchWithTermsFilter(
+        String index,
+        String queryText,
+        String filter,
+        String modelId,
+        String denseVectorFieldName
+    ) throws IOException {
+        Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
+
+        String termsFilter = String.format(Locale.ROOT, """
+            {"term": {"filter_field": "%s"}}
+            """, filter);
+
+        request.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "model_id": "%s",
+              "query_string": "%s",
+              "knn": {
+                  "field": "%s",
+                  "k": 5,
+                  "num_candidates": 10,
+                  "filter": %s
+              }
+            }""", modelId, queryText, denseVectorFieldName, termsFilter));
+        return client().performRequest(request);
+    }
+
+    protected Response semanticSearchWithQuery(String index, String queryText, String query, String modelId, String denseVectorFieldName)
+        throws IOException {
+        Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
+
+        request.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "model_id": "%s",
+              "query_string": "%s",
+              "knn": {
+                  "field": "%s",
+                  "k": 5,
+                  "num_candidates": 10
+              },
+              "query": %s
+            }""", modelId, queryText, denseVectorFieldName, query));
         return client().performRequest(request);
         return client().performRequest(request);
     }
     }
 
 

+ 143 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticSearchIT.java

@@ -15,8 +15,12 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Base64;
 import java.util.Base64;
 import java.util.List;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
 
 
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.hasSize;
+
 /**
 /**
  * This test uses a tiny text embedding model to simulate an trained
  * This test uses a tiny text embedding model to simulate an trained
  * NLP model.The output tensor is randomly generated but the RNG is
  * NLP model.The output tensor is randomly generated but the RNG is
@@ -93,7 +97,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
     }
     }
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
-    public void testModel() throws IOException {
+    public void testSemanticSearch() throws IOException {
         String modelId = "semantic-search-test";
         String modelId = "semantic-search-test";
         String indexName = modelId + "-index";
         String indexName = modelId + "-index";
 
 
@@ -114,6 +118,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             "the octopus comforter is leaking",
             "the octopus comforter is leaking",
             "washing machine smells"
             "washing machine smells"
         );
         );
+        List<String> filters = List.of("foo", "bar", "baz", "foo", "bar", "baz", "foo");
         List<List<Double>> embeddings = new ArrayList<>();
         List<List<Double>> embeddings = new ArrayList<>();
 
 
         // Generate the text embeddings via the inference API
         // Generate the text embeddings via the inference API
@@ -128,7 +133,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
 
 
         // index dense vectors
         // index dense vectors
         createVectorSearchIndex(indexName);
         createVectorSearchIndex(indexName);
-        bulkIndexDocs(inputs, embeddings, indexName);
+        bulkIndexDocs(inputs, filters, embeddings, indexName);
         forceMergeIndex(indexName);
         forceMergeIndex(indexName);
 
 
         // Test semantic search against the indexed vectors
         // Test semantic search against the indexed vectors
@@ -143,6 +148,133 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             String sourceText = (String) MapHelper.dig("_source.source_text", topHit);
             String sourceText = (String) MapHelper.dig("_source.source_text", topHit);
             assertEquals(inputs.get(randomInput), sourceText);
             assertEquals(inputs.get(randomInput), sourceText);
         }
         }
+
+        // Test semantic search with filters
+        {
+            var semanticSearchResponse = semanticSearchWithTermsFilter(indexName, inputs.get(0), "foo", modelId, "embedding");
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            assertThat(hits, hasSize(3));
+            for (var hit : hits) {
+                String filter = (String) MapHelper.dig("_source.filter_field", hit);
+                assertEquals("foo", filter);
+            }
+        }
+        {
+            var semanticSearchResponse = semanticSearchWithTermsFilter(indexName, inputs.get(2), "baz", modelId, "embedding");
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            assertThat(hits, hasSize(2));
+            for (var hit : hits) {
+                String filter = (String) MapHelper.dig("_source.filter_field", hit);
+                assertEquals("baz", filter);
+            }
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testHybridSearch() throws IOException {
+        String modelId = "hybrid-semantic-search-test";
+        String indexName = modelId + "-index";
+
+        createTextEmbeddingModel(modelId);
+        putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
+        putVocabulary(
+            List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
+            modelId
+        );
+        startDeployment(modelId);
+
+        List<String> inputs = List.of(
+            "my words",
+            "the machine is leaking",
+            "washing machine",
+            "these are my words",
+            "the octopus comforter smells",
+            "the octopus comforter is leaking",
+            "washing machine smells"
+        );
+        List<String> filters = List.of("foo", "bar", "baz", "foo", "bar", "baz", "foo");
+        List<List<Double>> embeddings = new ArrayList<>();
+
+        // Generate the text embeddings via the inference API
+        // then index them for search
+        for (var input : inputs) {
+            Response inference = infer(input, modelId);
+            List<Map<String, Object>> responseMap = (List<Map<String, Object>>) entityAsMap(inference).get("inference_results");
+            Map<String, Object> inferenceResult = responseMap.get(0);
+            List<Double> embedding = (List<Double>) inferenceResult.get("predicted_value");
+            embeddings.add(embedding);
+        }
+
+        // index dense vectors
+        createVectorSearchIndex(indexName);
+        bulkIndexDocs(inputs, filters, embeddings, indexName);
+        forceMergeIndex(indexName);
+
+        String queryTemplate = """
+            {"match": {"source_text": {"query": "%s"}}}
+            """;
+
+        {
+            // combined query should return size documents where size > k
+            Request request = new Request("GET", indexName + "/_semantic_search");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {
+                  "model_id": "%s",
+                  "query_string": "my words",
+                  "knn": {
+                      "field": "embedding",
+                      "k": 3,
+                      "num_candidates": 10,
+                      "boost": 10.0
+                  },
+                  "query": {"match_all": {}},
+                  "size": 7
+                }""", modelId));
+            var semanticSearchResponse = client().performRequest(request);
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            int hitCount = (Integer) MapHelper.dig("hits.total.value", responseMap);
+            assertEquals(7, hitCount);
+        }
+        {
+            // boost the knn score, as the query is an exact match the unboosted
+            // score should be close to 1.0. Use an unrelated query so scores are
+            // not combined
+            Request request = new Request("GET", indexName + "/_semantic_search");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {
+                  "model_id": "%s",
+                  "query_string": "my words",
+                  "knn": {
+                      "field": "embedding",
+                      "k": 3,
+                      "num_candidates": 10,
+                      "boost": 10.0
+                  },
+                  "query": {"match": {"source_text": {"query": "apricot unrelated"}}}
+                }""", modelId));
+            var semanticSearchResponse = client().performRequest(request);
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            boolean found = false;
+            for (var hit : hits) {
+                String source = (String) MapHelper.dig("_source.source_text", hit);
+                if (source.equals("my words")) {
+                    assertThat((Double) MapHelper.dig("_score", hit), closeTo(10.0, 0.01));
+                    found = true;
+                }
+            }
+            assertTrue("should have found hit for string 'my words'", found);
+        }
     }
     }
 
 
     private void createVectorSearchIndex(String indexName) throws IOException {
     private void createVectorSearchIndex(String indexName) throws IOException {
@@ -154,6 +286,9 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
                   "source_text": {
                   "source_text": {
                     "type": "text"
                     "type": "text"
                   },
                   },
+                  "filter_field": {
+                    "type": "keyword"
+                  },
                   "embedding": {
                   "embedding": {
                     "type": "dense_vector",
                     "type": "dense_vector",
                     "dims": 100,
                     "dims": 100,
@@ -167,15 +302,18 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
         assertOkWithErrorMessage(response);
         assertOkWithErrorMessage(response);
     }
     }
 
 
-    private void bulkIndexDocs(List<String> inputs, List<List<Double>> embeddings, String indexName) throws IOException {
+    private void bulkIndexDocs(List<String> sourceText, List<String> filters, List<List<Double>> embeddings, String indexName)
+        throws IOException {
         String createAction = "{\"create\": {\"_index\": \"" + indexName + "\"}}\n";
         String createAction = "{\"create\": {\"_index\": \"" + indexName + "\"}}\n";
 
 
         StringBuilder bulkBuilder = new StringBuilder();
         StringBuilder bulkBuilder = new StringBuilder();
 
 
-        for (int i = 0; i < inputs.size(); i++) {
+        for (int i = 0; i < sourceText.size(); i++) {
             bulkBuilder.append(createAction);
             bulkBuilder.append(createAction);
             bulkBuilder.append("{\"source_text\": \"")
             bulkBuilder.append("{\"source_text\": \"")
-                .append(inputs.get(i))
+                .append(sourceText.get(i))
+                .append("\", \"filter_field\":\"")
+                .append(filters.get(i))
                 .append("\", \"embedding\":")
                 .append("\", \"embedding\":")
                 .append(embeddings.get(i))
                 .append(embeddings.get(i))
                 .append("}\n");
                 .append("}\n");

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

@@ -99,8 +99,12 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         sourceBuilder.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
         sourceBuilder.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
         sourceBuilder.knnSearch(knnSearchBuilder);
         sourceBuilder.knnSearch(knnSearchBuilder);
-        sourceBuilder.size(knnSearchBuilder.k());
-
+        if (request.getSize() != -1) {
+            sourceBuilder.size(request.getSize());
+        }
+        if (request.getQuery() != null) {
+            sourceBuilder.query(request.getQuery());
+        }
         if (request.getFetchSource() != null) {
         if (request.getFetchSource() != null) {
             sourceBuilder.fetchSource(request.getFetchSource());
             sourceBuilder.fetchSource(request.getFetchSource());
         }
         }

+ 19 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/semantic_search.yml

@@ -145,6 +145,25 @@ setup:
     # See comment at the top of the file for the reason why
     # See comment at the top of the file for the reason why
     # the test cannot match on the text field
     # the test cannot match on the text field
 
 
+  - do:
+      semantic_search:
+        index: embedded_text
+        body:
+          model_id: text_embedding_model
+          query_string: "the octopus comforter smells"
+          text_embedding_config:
+            tokenization:
+              bert:
+                truncate: none
+          knn:
+            field: embedding
+            k: 3
+            num_candidates: 10
+  - gte: { inference_took: 0 }
+  - match: { hits.total.value: 3 }
+    # See comment at the top of the file for the reason why
+    # the test cannot match on the text field
+
 ---
 ---
 "Knn field is not a vector":
 "Knn field is not a vector":
   - do:
   - do: