Browse Source

[ML] Hybrid retrieval for Semantic search. (#91348)

Adds the query option to the _semantic_search endpoint for hybrid retrieval. 
Scoring is controlled by the boost fields of the knn search and the query.
David Kyle 3 years ago
parent
commit
b46ee9caaa

+ 56 - 21
docs/reference/search/semantic-search.asciidoc

@@ -10,25 +10,6 @@ The resulting dense vector is then used in a <<knn-search,k-nearest neighbor (kn
 created with the same text embedding model. The search results are semantically similar as learned
 by the model.
 
-////
-[source,console]
-----
-PUT my-index
-{
-  "mappings": {
-    "properties": {
-      "text_embedding": {
-        "type": "dense_vector",
-        "dims": 512,
-        "index": true,
-        "similarity": "cosine"
-      }
-    }
-  }
-}
-----
-////
-
 [source,console]
 ----
 GET my-index/_semantic_search
@@ -110,15 +91,69 @@ value must be less than `num_candidates`.
 shard. Cannot exceed 10,000. {es} collects `num_candidates` results from each
 shard, then merges them to find the top `k` results. Increasing
 `num_candidates` tends to improve the accuracy of the final `k` results.
-====
 
 `filter`::
 (Optional, <<query-dsl,Query DSL object>>) Query to filter the documents that
 can match. The kNN search will return the top `k` documents that also match
 this filter. The value can be a single query or a list of queries. If `filter`
 is not provided, all documents are allowed to match.
+====
 
+`query`::
+(Optional, <<query-dsl,query object>>) Defines the search definition using the
+<<query-dsl,Query DSL>>.
 
+`text_embedding_config`::
+(Object, optional) Override certain setting of the text embedding model's configuration
+.Properties of text_embedding inference
+[%collapsible%open]
+=====
+`results_field`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-results-field]
+
+`tokenization`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization]
++
+.Properties of tokenization
+[%collapsible%open]
+======
+`bert`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-bert]
++
+.Properties of bert
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+`roberta`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-roberta]
++
+.Properties of roberta
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+`mpnet`::::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-mpnet]
++
+.Properties of mpnet
+[%collapsible%open]
+=======
+`truncate`::::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-tokenization-truncate]
+=======
+======
+=====
 
 include::{es-repo-dir}/search/search.asciidoc[tag=docvalue-fields-def]
 include::{es-repo-dir}/search/search.asciidoc[tag=fields-param-def]
@@ -129,5 +164,5 @@ include::{es-repo-dir}/search/search.asciidoc[tag=stored-fields-def]
 [[semantic-search-api-response-body]]
 ==== {api-response-body-title}
 
-A sementic search response has the same structure as a kNN search response.
+The semantic search response has the same structure as a kNN search response.
 

+ 54 - 36
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchAction.java

@@ -57,6 +57,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
     public static class Request extends ActionRequest implements IndicesRequest.Replaceable {
 
         public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
+        public static final ParseField TEXT_EMBEDDING_CONFIG = new ParseField("text_embedding_config");
 
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);
 
@@ -67,15 +68,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             PARSER.declareObject(
                 Request.Builder::setUpdate,
                 (p, c) -> TextEmbeddingConfigUpdate.fromXContentStrict(p),
-                InferTrainedModelDeploymentAction.Request.INFERENCE_CONFIG
+                TEXT_EMBEDDING_CONFIG
             );
-            PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
-            PARSER.declareFieldArray(
-                Request.Builder::setFilters,
+            PARSER.declareObject(
+                Request.Builder::setQueryBuilder,
                 (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
-                KnnSearchBuilder.FILTER_FIELD,
-                ObjectParser.ValueType.OBJECT_ARRAY
+                SearchSourceBuilder.QUERY_FIELD
             );
+            PARSER.declareObject(Request.Builder::setKnnSearch, (p, c) -> KnnQueryOptions.fromXContent(p), SearchSourceBuilder.KNN_FIELD);
             PARSER.declareField(
                 (p, request, c) -> request.setFetchSource(FetchSourceContext.fromXContent(p)),
                 SearchSourceBuilder._SOURCE_FIELD,
@@ -99,16 +99,21 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 SearchSourceBuilder.STORED_FIELDS_FIELD,
                 ObjectParser.ValueType.STRING_ARRAY
             );
+            PARSER.declareInt(Request.Builder::setSize, SearchSourceBuilder.SIZE_FIELD);
         }
 
         public static Request parseRestRequest(RestRequest restRequest) throws IOException {
             Builder builder = new Builder(Strings.splitStringByCommaToArray(restRequest.param("index")));
-            builder.setRouting(restRequest.param("routing"));
             if (restRequest.hasContentOrSourceParam()) {
                 try (XContentParser contentParser = restRequest.contentOrSourceParamParser()) {
                     PARSER.parse(contentParser, builder, null);
                 }
             }
+            // Query parameters are preferred to body parameters.
+            if (restRequest.hasParam("size")) {
+                builder.setSize(restRequest.paramAsInt("size", -1));
+            }
+            builder.setRouting(restRequest.param("routing"));
             return builder.build();
         }
 
@@ -117,13 +122,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         private final String queryString;
         private final String modelId;
         private final TimeValue inferenceTimeout;
+        private final QueryBuilder query;
         private final KnnQueryOptions knnQueryOptions;
         private final TextEmbeddingConfigUpdate embeddingConfig;
-        private final List<QueryBuilder> filters;
         private final FetchSourceContext fetchSource;
         private final List<FieldAndFormat> fields;
         private final List<FieldAndFormat> docValueFields;
         private final StoredFieldsContext storedFields;
+        private final int size;
 
         public Request(StreamInput in) throws IOException {
             super(in);
@@ -132,17 +138,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             queryString = in.readString();
             modelId = in.readString();
             inferenceTimeout = in.readOptionalTimeValue();
+            query = in.readOptionalNamedWriteable(QueryBuilder.class);
             knnQueryOptions = new KnnQueryOptions(in);
             embeddingConfig = in.readOptionalWriteable(TextEmbeddingConfigUpdate::new);
-            if (in.readBoolean()) {
-                filters = in.readNamedWriteableList(QueryBuilder.class);
-            } else {
-                filters = null;
-            }
             fetchSource = in.readOptionalWriteable(FetchSourceContext::readFrom);
             fields = in.readOptionalList(FieldAndFormat::new);
             docValueFields = in.readOptionalList(FieldAndFormat::new);
             storedFields = in.readOptionalWriteable(StoredFieldsContext::new);
+            size = in.readInt();
         }
 
         Request(
@@ -150,27 +153,29 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             String routing,
             String queryString,
             String modelId,
+            QueryBuilder query,
             KnnQueryOptions knnQueryOptions,
             TextEmbeddingConfigUpdate embeddingConfig,
             TimeValue inferenceTimeout,
-            List<QueryBuilder> filters,
             FetchSourceContext fetchSource,
             List<FieldAndFormat> fields,
             List<FieldAndFormat> docValueFields,
-            StoredFieldsContext storedFields
+            StoredFieldsContext storedFields,
+            int size
         ) {
             this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
             this.routing = routing;
             this.queryString = queryString;
             this.modelId = modelId;
+            this.query = query;
             this.knnQueryOptions = knnQueryOptions;
             this.embeddingConfig = embeddingConfig;
             this.inferenceTimeout = inferenceTimeout;
-            this.filters = filters;
             this.fetchSource = fetchSource;
             this.fields = fields;
             this.docValueFields = docValueFields;
             this.storedFields = storedFields;
+            this.size = size;
         }
 
         @Override
@@ -181,18 +186,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             out.writeString(queryString);
             out.writeString(modelId);
             out.writeOptionalTimeValue(inferenceTimeout);
+            out.writeOptionalNamedWriteable(query);
             knnQueryOptions.writeTo(out);
             out.writeOptionalWriteable(embeddingConfig);
-            if (filters != null) {
-                out.writeBoolean(true);
-                out.writeNamedWriteableList(filters);
-            } else {
-                out.writeBoolean(false);
-            }
             out.writeOptionalWriteable(fetchSource);
             out.writeOptionalCollection(fields);
             out.writeOptionalCollection(docValueFields);
             out.writeOptionalWriteable(storedFields);
+            out.writeInt(size);
         }
 
         @Override
@@ -231,6 +232,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return inferenceTimeout;
         }
 
+        public QueryBuilder getQuery() {
+            return query;
+        }
+
         public KnnQueryOptions getKnnQueryOptions() {
             return knnQueryOptions;
         }
@@ -239,10 +244,6 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return embeddingConfig;
         }
 
-        public List<QueryBuilder> getFilters() {
-            return filters;
-        }
-
         public FetchSourceContext getFetchSource() {
             return fetchSource;
         }
@@ -259,6 +260,10 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return storedFields;
         }
 
+        public int getSize() {
+            return size;
+        }
+
         @Override
         public boolean equals(Object o) {
             if (this == o) return true;
@@ -269,13 +274,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 && Objects.equals(queryString, request.queryString)
                 && Objects.equals(modelId, request.modelId)
                 && Objects.equals(inferenceTimeout, request.inferenceTimeout)
+                && Objects.equals(query, request.query)
                 && Objects.equals(knnQueryOptions, request.knnQueryOptions)
                 && Objects.equals(embeddingConfig, request.embeddingConfig)
-                && Objects.equals(filters, request.filters)
                 && Objects.equals(fetchSource, request.fetchSource)
                 && Objects.equals(fields, request.fields)
                 && Objects.equals(docValueFields, request.docValueFields)
-                && Objects.equals(storedFields, request.storedFields);
+                && Objects.equals(storedFields, request.storedFields)
+                && size == request.size;
         }
 
         @Override
@@ -285,13 +291,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 queryString,
                 modelId,
                 inferenceTimeout,
+                query,
                 knnQueryOptions,
                 embeddingConfig,
-                filters,
                 fetchSource,
                 fields,
                 docValueFields,
-                storedFields
+                storedFields,
+                size
             );
             result = 31 * result + Arrays.hashCode(indices);
             return result;
@@ -321,12 +328,13 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             private String queryString;
             private TimeValue timeout;
             private TextEmbeddingConfigUpdate update;
+            private QueryBuilder queryBuilder;
             private KnnQueryOptions knnSearchBuilder;
-            private List<QueryBuilder> filters;
             private FetchSourceContext fetchSource;
             private List<FieldAndFormat> fields;
             private List<FieldAndFormat> docValueFields;
             private StoredFieldsContext storedFields;
+            private int size = -1;
 
             Builder(String[] indices) {
                 this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
@@ -360,8 +368,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 this.knnSearchBuilder = knnSearchBuilder;
             }
 
-            private void setFilters(List<QueryBuilder> filters) {
-                this.filters = filters;
+            void setQueryBuilder(QueryBuilder queryBuilder) {
+                this.queryBuilder = queryBuilder;
             }
 
             private void setFetchSource(FetchSourceContext fetchSource) {
@@ -380,20 +388,25 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 this.storedFields = storedFields;
             }
 
+            private void setSize(int size) {
+                this.size = size;
+            }
+
             Request build() {
                 return new Request(
                     indices,
                     routing,
                     queryString,
                     modelId,
+                    queryBuilder,
                     knnSearchBuilder,
                     update,
                     timeout,
-                    filters,
                     fetchSource,
                     fields,
                     docValueFields,
-                    storedFields
+                    storedFields,
+                    size
                 );
             }
         }
@@ -528,7 +541,12 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             if (queryVector == null) {
                 throw new IllegalStateException("[query_vector] not set on the Knn query");
             }
-            return new KnnSearchBuilder(field, queryVector, k, numCands);
+            var builder = new KnnSearchBuilder(field, queryVector, k, numCands);
+            builder.boost(boost);
+            if (filterQueries.isEmpty() == false) {
+                builder.addFilterQueries(filterQueries);
+            }
+            return builder;
         }
 
         @Override

+ 17 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchActionKnnQueryOptionsTests.java

@@ -19,6 +19,8 @@ import org.junit.Before;
 import java.util.List;
 
 import static java.util.Collections.emptyList;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.sameInstance;
 
 public class SemanticSearchActionKnnQueryOptionsTests extends AbstractWireSerializingTestCase<SemanticSearchAction.KnnQueryOptions> {
 
@@ -65,4 +67,19 @@ public class SemanticSearchActionKnnQueryOptionsTests extends AbstractWireSerial
     protected SemanticSearchAction.KnnQueryOptions createTestInstance() {
         return randomInstance();
     }
+
+    public void testToKnnSearchBuilder() {
+        var knnOptions = new SemanticSearchAction.KnnQueryOptions("foo", 5, 100);
+        knnOptions.boost(20.0f);
+        var termsQuery = QueryBuilders.termQuery("foo", "bar");
+        knnOptions.addFilterQueries(List.of(termsQuery));
+
+        var knnSearch = knnOptions.toKnnSearchBuilder(new float[] { 0.1f, 0.2f });
+        assertEquals(5, knnSearch.k());
+        var knnQuery = knnSearch.toQueryBuilder();
+        assertEquals(100, knnQuery.numCands());
+        assertEquals("foo", knnQuery.getFieldName());
+        assertThat(knnQuery.filterQueries(), contains(sameInstance(termsQuery)));
+        assertEquals(20.0f, knnQuery.boost(), 0.001);
+    }
 }

+ 5 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchActionRequestTests.java

@@ -67,14 +67,15 @@ public class SemanticSearchActionRequestTests extends AbstractWireSerializingTes
             randomBoolean() ? null : randomAlphaOfLength(5),
             randomAlphaOfLength(5),
             randomAlphaOfLength(5),
+            randomBoolean() ? null : new TermsQueryBuilder("foo", "bar"),
             SemanticSearchActionKnnQueryOptionsTests.randomInstance(),
             TextEmbeddingConfigUpdateTests.randomUpdate(),
             randomBoolean() ? null : TimeValue.timeValueSeconds(randomIntBetween(1, 10)),
-            randomBoolean() ? null : List.of(new TermsQueryBuilder("foo", "bar", "cat")),
             randomBoolean() ? null : FetchSourceContext.of(randomBoolean()),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
             randomBoolean() ? null : List.of(new FieldAndFormat("foo", null)),
-            randomBoolean() ? null : StoredFieldsContext.fromList(List.of("A", "B"))
+            randomBoolean() ? null : StoredFieldsContext.fromList(List.of("A", "B")),
+            randomBoolean() ? -1 : randomIntBetween(1, 10)
         );
     }
 
@@ -94,7 +95,8 @@ public class SemanticSearchActionRequestTests extends AbstractWireSerializingTes
             null,
             null,
             null,
-            null
+            null,
+            -1
         );
         var validation = action.validate();
         assertNotNull(validation);

+ 47 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -262,7 +262,7 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         return client().performRequest(request);
     }
 
-    protected Response semanticSearch(String index, String query, String deploymentId, String denseVectorFieldName) throws IOException {
+    protected Response semanticSearch(String index, String queryText, String modelId, String denseVectorFieldName) throws IOException {
         Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
 
         request.setJsonEntity(String.format(Locale.ROOT, """
@@ -274,7 +274,52 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
                   "k": 5,
                   "num_candidates": 10
               }
-            }""", deploymentId, query, denseVectorFieldName));
+            }""", modelId, queryText, denseVectorFieldName));
+        return client().performRequest(request);
+    }
+
+    protected Response semanticSearchWithTermsFilter(
+        String index,
+        String queryText,
+        String filter,
+        String modelId,
+        String denseVectorFieldName
+    ) throws IOException {
+        Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
+
+        String termsFilter = String.format(Locale.ROOT, """
+            {"term": {"filter_field": "%s"}}
+            """, filter);
+
+        request.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "model_id": "%s",
+              "query_string": "%s",
+              "knn": {
+                  "field": "%s",
+                  "k": 5,
+                  "num_candidates": 10,
+                  "filter": %s
+              }
+            }""", modelId, queryText, denseVectorFieldName, termsFilter));
+        return client().performRequest(request);
+    }
+
+    protected Response semanticSearchWithQuery(String index, String queryText, String query, String modelId, String denseVectorFieldName)
+        throws IOException {
+        Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
+
+        request.setJsonEntity(String.format(Locale.ROOT, """
+            {
+              "model_id": "%s",
+              "query_string": "%s",
+              "knn": {
+                  "field": "%s",
+                  "k": 5,
+                  "num_candidates": 10
+              },
+              "query": %s
+            }""", modelId, queryText, denseVectorFieldName, query));
         return client().performRequest(request);
     }
 

+ 143 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticSearchIT.java

@@ -15,8 +15,12 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Base64;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.hasSize;
+
 /**
  * This test uses a tiny text embedding model to simulate an trained
  * NLP model.The output tensor is randomly generated but the RNG is
@@ -93,7 +97,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    public void testModel() throws IOException {
+    public void testSemanticSearch() throws IOException {
         String modelId = "semantic-search-test";
         String indexName = modelId + "-index";
 
@@ -114,6 +118,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             "the octopus comforter is leaking",
             "washing machine smells"
         );
+        List<String> filters = List.of("foo", "bar", "baz", "foo", "bar", "baz", "foo");
         List<List<Double>> embeddings = new ArrayList<>();
 
         // Generate the text embeddings via the inference API
@@ -128,7 +133,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
 
         // index dense vectors
         createVectorSearchIndex(indexName);
-        bulkIndexDocs(inputs, embeddings, indexName);
+        bulkIndexDocs(inputs, filters, embeddings, indexName);
         forceMergeIndex(indexName);
 
         // Test semantic search against the indexed vectors
@@ -143,6 +148,133 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             String sourceText = (String) MapHelper.dig("_source.source_text", topHit);
             assertEquals(inputs.get(randomInput), sourceText);
         }
+
+        // Test semantic search with filters
+        {
+            var semanticSearchResponse = semanticSearchWithTermsFilter(indexName, inputs.get(0), "foo", modelId, "embedding");
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            assertThat(hits, hasSize(3));
+            for (var hit : hits) {
+                String filter = (String) MapHelper.dig("_source.filter_field", hit);
+                assertEquals("foo", filter);
+            }
+        }
+        {
+            var semanticSearchResponse = semanticSearchWithTermsFilter(indexName, inputs.get(2), "baz", modelId, "embedding");
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            assertThat(hits, hasSize(2));
+            for (var hit : hits) {
+                String filter = (String) MapHelper.dig("_source.filter_field", hit);
+                assertEquals("baz", filter);
+            }
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testHybridSearch() throws IOException {
+        String modelId = "hybrid-semantic-search-test";
+        String indexName = modelId + "-index";
+
+        createTextEmbeddingModel(modelId);
+        putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
+        putVocabulary(
+            List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
+            modelId
+        );
+        startDeployment(modelId);
+
+        List<String> inputs = List.of(
+            "my words",
+            "the machine is leaking",
+            "washing machine",
+            "these are my words",
+            "the octopus comforter smells",
+            "the octopus comforter is leaking",
+            "washing machine smells"
+        );
+        List<String> filters = List.of("foo", "bar", "baz", "foo", "bar", "baz", "foo");
+        List<List<Double>> embeddings = new ArrayList<>();
+
+        // Generate the text embeddings via the inference API
+        // then index them for search
+        for (var input : inputs) {
+            Response inference = infer(input, modelId);
+            List<Map<String, Object>> responseMap = (List<Map<String, Object>>) entityAsMap(inference).get("inference_results");
+            Map<String, Object> inferenceResult = responseMap.get(0);
+            List<Double> embedding = (List<Double>) inferenceResult.get("predicted_value");
+            embeddings.add(embedding);
+        }
+
+        // index dense vectors
+        createVectorSearchIndex(indexName);
+        bulkIndexDocs(inputs, filters, embeddings, indexName);
+        forceMergeIndex(indexName);
+
+        String queryTemplate = """
+            {"match": {"source_text": {"query": "%s"}}}
+            """;
+
+        {
+            // combined query should return size documents where size > k
+            Request request = new Request("GET", indexName + "/_semantic_search");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {
+                  "model_id": "%s",
+                  "query_string": "my words",
+                  "knn": {
+                      "field": "embedding",
+                      "k": 3,
+                      "num_candidates": 10,
+                      "boost": 10.0
+                  },
+                  "query": {"match_all": {}},
+                  "size": 7
+                }""", modelId));
+            var semanticSearchResponse = client().performRequest(request);
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            int hitCount = (Integer) MapHelper.dig("hits.total.value", responseMap);
+            assertEquals(7, hitCount);
+        }
+        {
+            // boost the knn score, as the query is an exact match the unboosted
+            // score should be close to 1.0. Use an unrelated query so scores are
+            // not combined
+            Request request = new Request("GET", indexName + "/_semantic_search");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {
+                  "model_id": "%s",
+                  "query_string": "my words",
+                  "knn": {
+                      "field": "embedding",
+                      "k": 3,
+                      "num_candidates": 10,
+                      "boost": 10.0
+                  },
+                  "query": {"match": {"source_text": {"query": "apricot unrelated"}}}
+                }""", modelId));
+            var semanticSearchResponse = client().performRequest(request);
+            assertOkWithErrorMessage(semanticSearchResponse);
+
+            Map<String, Object> responseMap = responseAsMap(semanticSearchResponse);
+            List<Map<String, Object>> hits = (List<Map<String, Object>>) MapHelper.dig("hits.hits", responseMap);
+            boolean found = false;
+            for (var hit : hits) {
+                String source = (String) MapHelper.dig("_source.source_text", hit);
+                if (source.equals("my words")) {
+                    assertThat((Double) MapHelper.dig("_score", hit), closeTo(10.0, 0.01));
+                    found = true;
+                }
+            }
+            assertTrue("should have found hit for string 'my words'", found);
+        }
     }
 
     private void createVectorSearchIndex(String indexName) throws IOException {
@@ -154,6 +286,9 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
                   "source_text": {
                     "type": "text"
                   },
+                  "filter_field": {
+                    "type": "keyword"
+                  },
                   "embedding": {
                     "type": "dense_vector",
                     "dims": 100,
@@ -167,15 +302,18 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
         assertOkWithErrorMessage(response);
     }
 
-    private void bulkIndexDocs(List<String> inputs, List<List<Double>> embeddings, String indexName) throws IOException {
+    private void bulkIndexDocs(List<String> sourceText, List<String> filters, List<List<Double>> embeddings, String indexName)
+        throws IOException {
         String createAction = "{\"create\": {\"_index\": \"" + indexName + "\"}}\n";
 
         StringBuilder bulkBuilder = new StringBuilder();
 
-        for (int i = 0; i < inputs.size(); i++) {
+        for (int i = 0; i < sourceText.size(); i++) {
             bulkBuilder.append(createAction);
             bulkBuilder.append("{\"source_text\": \"")
-                .append(inputs.get(i))
+                .append(sourceText.get(i))
+                .append("\", \"filter_field\":\"")
+                .append(filters.get(i))
                 .append("\", \"embedding\":")
                 .append(embeddings.get(i))
                 .append("}\n");

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

@@ -99,8 +99,12 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         sourceBuilder.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
         sourceBuilder.knnSearch(knnSearchBuilder);
-        sourceBuilder.size(knnSearchBuilder.k());
-
+        if (request.getSize() != -1) {
+            sourceBuilder.size(request.getSize());
+        }
+        if (request.getQuery() != null) {
+            sourceBuilder.query(request.getQuery());
+        }
         if (request.getFetchSource() != null) {
             sourceBuilder.fetchSource(request.getFetchSource());
         }

+ 19 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/semantic_search.yml

@@ -145,6 +145,25 @@ setup:
     # See comment at the top of the file for the reason why
     # the test cannot match on the text field
 
+  - do:
+      semantic_search:
+        index: embedded_text
+        body:
+          model_id: text_embedding_model
+          query_string: "the octopus comforter smells"
+          text_embedding_config:
+            tokenization:
+              bert:
+                truncate: none
+          knn:
+            field: embedding
+            k: 3
+            num_candidates: 10
+  - gte: { inference_took: 0 }
+  - match: { hits.total.value: 3 }
+    # See comment at the top of the file for the reason why
+    # the test cannot match on the text field
+
 ---
 "Knn field is not a vector":
   - do: