Explorar o código

[ML] Rename semantic search input parameter (#91787)

Formerly `query_string` now `model_text`
David Kyle %!s(int64=2) %!d(string=hai) anos
pai
achega
751dc244f1

+ 7 - 7
docs/reference/search/semantic-search.asciidoc

@@ -14,7 +14,7 @@ by the model.
 ----
 GET my-index/_semantic_search
 {
-  "query_string": "A picture of a snow capped mountain",
+  "model_text": "A picture of a snow capped mountain",
   "model_id": "my-text-embedding-model",
   "knn": {
     "field": "text_embedding",
@@ -69,8 +69,8 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=routing]
 (Required, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
-`query_string`::
-(Required, string) The input text to embed.
+`model_text`::
+(Required, string) The input to the text embedding model.
 
 `knn`::
 (Required, object)
@@ -80,7 +80,7 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn]
 [%collapsible%open]
 ====
 `field`::
-(Required, string) 
+(Required, string)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-field]
 
 `filter`::
@@ -88,11 +88,11 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-field]
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-filter]
 
 `k`::
-(Required, integer) 
+(Required, integer)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k]
 
 `num_candidates`::
-(Required, integer) 
+(Required, integer)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
 ====
 
@@ -101,7 +101,7 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
 <<query-dsl,Query DSL>>.
 
 `text_embedding_config`::
-(Object, optional) Override certain setting of the text embedding model's 
+(Object, optional) Override certain setting of the text embedding model's
 configuration.
 +
 .Properties of text_embedding inference

+ 17 - 17
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchAction.java

@@ -56,14 +56,14 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
 
     public static class Request extends ActionRequest implements IndicesRequest.Replaceable {
 
-        public static final ParseField QUERY_STRING = new ParseField("query_string"); // TODO a better name and update docs when changed
+        public static final ParseField MODEL_TEXT = new ParseField("model_text"); // TODO a better name and update docs when changed
         public static final ParseField TEXT_EMBEDDING_CONFIG = new ParseField("text_embedding_config");
 
         static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME);
 
         static {
             PARSER.declareString(Request.Builder::setModelId, InferModelAction.Request.MODEL_ID);
-            PARSER.declareString(Request.Builder::setQueryString, QUERY_STRING);
+            PARSER.declareString(Request.Builder::setModelText, MODEL_TEXT);
             PARSER.declareString(Request.Builder::setTimeout, SearchSourceBuilder.TIMEOUT_FIELD);
             PARSER.declareObject(
                 Request.Builder::setUpdate,
@@ -119,7 +119,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
 
         private String[] indices;
         private final String routing;
-        private final String queryString;
+        private final String modelText;
         private final String modelId;
         private final TimeValue inferenceTimeout;
         private final QueryBuilder query;
@@ -135,7 +135,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             super(in);
             indices = in.readStringArray();
             routing = in.readOptionalString();
-            queryString = in.readString();
+            modelText = in.readString();
             modelId = in.readString();
             inferenceTimeout = in.readOptionalTimeValue();
             query = in.readOptionalNamedWriteable(QueryBuilder.class);
@@ -151,7 +151,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         Request(
             String[] indices,
             String routing,
-            String queryString,
+            String modelText,
             String modelId,
             QueryBuilder query,
             KnnQueryOptions knnQueryOptions,
@@ -165,7 +165,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         ) {
             this.indices = Objects.requireNonNull(indices, "[indices] must not be null");
             this.routing = routing;
-            this.queryString = queryString;
+            this.modelText = modelText;
             this.modelId = modelId;
             this.query = query;
             this.knnQueryOptions = knnQueryOptions;
@@ -183,7 +183,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             super.writeTo(out);
             out.writeStringArray(indices);
             out.writeOptionalString(routing);
-            out.writeString(queryString);
+            out.writeString(modelText);
             out.writeString(modelId);
             out.writeOptionalTimeValue(inferenceTimeout);
             out.writeOptionalNamedWriteable(query);
@@ -220,8 +220,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             return routing;
         }
 
-        public String getQueryString() {
-            return queryString;
+        public String getModelText() {
+            return modelText;
         }
 
         public String getModelId() {
@@ -271,7 +271,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             Request request = (Request) o;
             return Arrays.equals(indices, request.indices)
                 && Objects.equals(routing, request.routing)
-                && Objects.equals(queryString, request.queryString)
+                && Objects.equals(modelText, request.modelText)
                 && Objects.equals(modelId, request.modelId)
                 && Objects.equals(inferenceTimeout, request.inferenceTimeout)
                 && Objects.equals(query, request.query)
@@ -288,7 +288,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         public int hashCode() {
             int result = Objects.hash(
                 routing,
-                queryString,
+                modelText,
                 modelId,
                 inferenceTimeout,
                 query,
@@ -307,8 +307,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
         @Override
         public ActionRequestValidationException validate() {
             ActionRequestValidationException error = new ActionRequestValidationException();
-            if (queryString == null) {
-                error.addValidationError("query_string cannot be null");
+            if (modelText == null) {
+                error.addValidationError("model_text cannot be null");
             }
             if (modelId == null) {
                 error.addValidationError("model_id cannot be null");
@@ -325,7 +325,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
             private final String[] indices;
             private String routing;
             private String modelId;
-            private String queryString;
+            private String modelText;
             private TimeValue timeout;
             private TextEmbeddingConfigUpdate update;
             private QueryBuilder queryBuilder;
@@ -348,8 +348,8 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 this.modelId = modelId;
             }
 
-            void setQueryString(String queryString) {
-                this.queryString = queryString;
+            void setModelText(String modelText) {
+                this.modelText = modelText;
             }
 
             void setTimeout(TimeValue timeout) {
@@ -396,7 +396,7 @@ public class SemanticSearchAction extends ActionType<SemanticSearchAction.Respon
                 return new Request(
                     indices,
                     routing,
-                    queryString,
+                    modelText,
                     modelId,
                     queryBuilder,
                     knnSearchBuilder,

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/SemanticSearchActionRequestTests.java

@@ -101,7 +101,7 @@ public class SemanticSearchActionRequestTests extends AbstractWireSerializingTes
         var validation = action.validate();
         assertNotNull(validation);
         assertThat(validation.validationErrors(), hasSize(3));
-        assertThat(validation.validationErrors().get(0), containsString("query_string cannot be null"));
+        assertThat(validation.validationErrors().get(0), containsString("model_text cannot be null"));
         assertThat(validation.validationErrors().get(1), containsString("model_id cannot be null"));
         assertThat(validation.validationErrors().get(2), containsString("knn cannot be null"));
     }

+ 5 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -272,19 +272,19 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         return client().performRequest(request);
     }
 
-    protected Response semanticSearch(String index, String queryText, String modelId, String denseVectorFieldName) throws IOException {
+    protected Response semanticSearch(String index, String modelText, String modelId, String denseVectorFieldName) throws IOException {
         Request request = new Request("GET", index + "/_semantic_search?error_trace=true");
 
         request.setJsonEntity(String.format(Locale.ROOT, """
             {
               "model_id": "%s",
-              "query_string": "%s",
+              "model_text": "%s",
               "knn": {
                   "field": "%s",
                   "k": 5,
                   "num_candidates": 10
               }
-            }""", modelId, queryText, denseVectorFieldName));
+            }""", modelId, modelText, denseVectorFieldName));
         return client().performRequest(request);
     }
 
@@ -304,7 +304,7 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         request.setJsonEntity(String.format(Locale.ROOT, """
             {
               "model_id": "%s",
-              "query_string": "%s",
+              "model_text": "%s",
               "knn": {
                   "field": "%s",
                   "k": 5,
@@ -322,7 +322,7 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         request.setJsonEntity(String.format(Locale.ROOT, """
             {
               "model_id": "%s",
-              "query_string": "%s",
+              "model_text": "%s",
               "knn": {
                   "field": "%s",
                   "k": 5,

+ 2 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticSearchIT.java

@@ -226,7 +226,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             request.setJsonEntity(String.format(Locale.ROOT, """
                 {
                   "model_id": "%s",
-                  "query_string": "my words",
+                  "model_text": "my words",
                   "knn": {
                       "field": "embedding",
                       "k": 3,
@@ -251,7 +251,7 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
             request.setJsonEntity(String.format(Locale.ROOT, """
                 {
                   "model_id": "%s",
-                  "query_string": "my words",
+                  "model_text": "my words",
                   "knn": {
                       "field": "embedding",
                       "k": 3,

+ 1 - 1
x-pack/plugin/ml/qa/semantic-search-tests/src/yamlRestTest/resources/rest-api-spec/test/semantic_search/10_basic.yml

@@ -117,7 +117,7 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "the octopus comforter smells"
+          model_text: "the octopus comforter smells"
           knn:
             field: embedding
             k: 3

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

@@ -130,7 +130,7 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
         var inferenceRequest = new InferTrainedModelDeploymentAction.Request(
             request.getModelId(),
             request.getEmbeddingConfig(),
-            request.getQueryString(),
+            request.getModelText(),
             request.getInferenceTimeout()
         );
         inferenceRequest.setParentTask(parentTask);

+ 7 - 7
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/semantic_search.yml

@@ -135,7 +135,7 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "the octopus comforter smells"
+          model_text: "the octopus comforter smells"
           knn:
             field: embedding
             k: 3
@@ -150,7 +150,7 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "the octopus comforter smells"
+          model_text: "the octopus comforter smells"
           text_embedding_config:
             tokenization:
               bert:
@@ -178,7 +178,7 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "kids pyjamas with picture of a tractor"
+          model_text: "kids pyjamas with picture of a tractor"
           knn:
             field: source_text
             k: 2
@@ -198,7 +198,7 @@ setup:
         index: embedded_text_10_dims
         body:
           model_id: text_embedding_model
-          query_string: "kids pyjamas with picture of a tractor"
+          model_text: "kids pyjamas with picture of a tractor"
           knn:
             field: embedding
             k: 2
@@ -212,7 +212,7 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "kids pyjamas with picture of a tractor"
+          model_text: "kids pyjamas with picture of a tractor"
           knn:
             field: embedding
             k: 2
@@ -226,7 +226,7 @@ setup:
         index: embedded_text
         body:
           model_id: missing_model
-          query_string: "kids pyjamas with picture of a tractor"
+          model_text: "kids pyjamas with picture of a tractor"
           knn:
             field: embedding
             k: 2
@@ -253,4 +253,4 @@ setup:
         index: embedded_text
         body:
           model_id: text_embedding_model
-          query_string: "kids pyjamas with picture of a tractor"
+          model_text: "kids pyjamas with picture of a tractor"