
[ML] Utilise parallel allocations where the inference request contains multiple documents (#92359)

Divide work from the _infer API among all allocations
David Kyle 2 years ago
parent
commit
6acfbbcd8b
23 changed files with 593 additions and 365 deletions
  1. 2 2
      build.gradle
  2. 6 0
      docs/changelog/92359.yaml
  3. 6 6
      docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc
  4. 7 9
      docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc
  5. 78 42
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java
  6. 62 48
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java
  7. 32 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java
  8. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java
  9. 21 10
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java
  10. 4 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java
  11. 12 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java
  12. 29 29
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java
  13. 77 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  14. 10 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticSearchIT.java
  15. 4 4
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java
  16. 15 10
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java
  17. 21 86
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java
  18. 145 61
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  19. 19 12
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java
  20. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  21. 10 11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java
  22. 3 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java
  23. 27 11
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java

+ 2 - 2
build.gradle

@@ -137,9 +137,9 @@ tasks.register("verifyVersions") {
  * after the backport of the backcompat code is complete.
  */
 
-boolean bwc_tests_enabled = true
+boolean bwc_tests_enabled = false
 // place a PR link here when committing bwc changes:
-String bwc_tests_disabled_issue = ""
+String bwc_tests_disabled_issue = "https://github.com/elastic/elasticsearch/pull/92359"
 if (bwc_tests_enabled == false) {
   if (bwc_tests_disabled_issue.isEmpty()) {
     throw new GradleException("bwc_tests_disabled_issue must be set when bwc_tests_enabled == false")

+ 6 - 0
docs/changelog/92359.yaml

@@ -0,0 +1,6 @@
+pr: 92359
+summary: Utilise parallel allocations where the inference request contains multiple
+  documents
+area: Machine Learning
+type: bug
+issues: []

+ 6 - 6
docs/reference/ml/trained-models/apis/infer-trained-model-deployment.asciidoc

@@ -46,8 +46,8 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
 `docs`::
 (Required, array)
 An array of objects to pass to the model for inference. The objects should
-contain a field matching your configured trained model input. Typically, the 
-field name is `text_field`. Currently, only a single value is allowed.
+contain a field matching your configured trained model input. Typically, the
+field name is `text_field`.
 
 ////
 [[infer-trained-model-deployment-results]]
@@ -62,7 +62,7 @@ field name is `text_field`. Currently, only a single value is allowed.
 [[infer-trained-model-deployment-example]]
 == {api-examples-title}
 
-The response depends on the task the model is trained for. If it is a text 
+The response depends on the task the model is trained for. If it is a text
 classification task, the response is the score. For example:
 
 [source,console]
@@ -123,7 +123,7 @@ The API returns in this case:
 ----
 // NOTCONSOLE
 
-Zero-shot classification tasks require extra configuration defining the class 
+Zero-shot classification tasks require extra configuration defining the class
 labels. These labels are passed in the zero-shot inference config.
 
 [source,console]
@@ -150,7 +150,7 @@ POST _ml/trained_models/model2/deployment/_infer
 --------------------------------------------------
 // TEST[skip:TBD]
 
-The API returns the predicted label and the confidence, as well as the top 
+The API returns the predicted label and the confidence, as well as the top
 classes:
 
 [source,console-result]
@@ -205,7 +205,7 @@ POST _ml/trained_models/model2/deployment/_infer
 --------------------------------------------------
 // TEST[skip:TBD]
 
-When the input has been truncated due to the limit imposed by the model's 
+When the input has been truncated due to the limit imposed by the model's
 `max_sequence_length` the `is_truncated` field appears in the response.
 
 [source,console-result]

+ 7 - 9
docs/reference/ml/trained-models/apis/infer-trained-model.asciidoc

@@ -6,10 +6,10 @@
 <titleabbrev>Infer trained model</titleabbrev>
 ++++
 
-Evaluates a trained model. The model may be any supervised model either trained 
+Evaluates a trained model. The model may be any supervised model either trained
 by {dfanalytics} or imported.
 
-NOTE: For model deployments with caching enabled, results may be returned 
+NOTE: For model deployments with caching enabled, results may be returned
 directly from the {infer} cache.
 
 [[infer-trained-model-request]]
@@ -49,9 +49,7 @@ Controls the amount of time to wait for {infer} results. Defaults to 10 seconds.
 (Required, array)
 An array of objects to pass to the model for inference. The objects should
 contain the fields matching your configured trained model input. Typically for
-NLP models, the field name is `text_field`. Currently for NLP models, only a
-single value is allowed. For {dfanalytics} or imported classification or
-regression models, more than one value is allowed.
+NLP models, the field name is `text_field`.
 
 //Begin inference_config
 `inference_config`::
@@ -104,7 +102,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-nlp-fill-mask]
 =====
 `num_top_classes`::::
 (Optional, integer)
-Number of top predicted tokens to return for replacing the mask token. Defaults 
+Number of top predicted tokens to return for replacing the mask token. Defaults
 to `0`.
 
 `results_field`::::
@@ -275,7 +273,7 @@ The maximum amount of words in the answer. Defaults to `15`.
 
 `num_top_classes`::::
 (Optional, integer)
-The number the top found answers to return. Defaults to `0`, meaning only the 
+The number the top found answers to return. Defaults to `0`, meaning only the
 best found answer is returned.
 
 `question`::::
@@ -372,7 +370,7 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-text-classific
 
 `num_top_classes`::::
 (Optional, integer)
-Specifies the number of top class predictions to return. Defaults to all classes 
+Specifies the number of top class predictions to return. Defaults to all classes
 (-1).
 
 `results_field`::::
@@ -884,7 +882,7 @@ POST _ml/trained_models/model2/_infer
 --------------------------------------------------
 // TEST[skip:TBD]
 
-When the input has been truncated due to the limit imposed by the model's 
+When the input has been truncated due to the limit imposed by the model's
 `max_sequence_length` the `is_truncated` field appears in the response.
 
 [source,console-result]

+ 78 - 42
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java

@@ -23,11 +23,12 @@ import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -80,45 +81,55 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
         private final List<Map<String, Object>> objectsToInfer;
         private final InferenceConfigUpdate update;
         private final boolean previouslyLicensed;
-        private final TimeValue timeout;
+        private TimeValue inferenceTimeout;
+        // textInput added for uses that accept a query string
+        // and do not know which field the model expects to find its
+        // input and so cannot construct a document.
+        private final List<String> textInput;
 
-        public Request(String modelId, boolean previouslyLicensed) {
-            this(modelId, Collections.emptyList(), RegressionConfigUpdate.EMPTY_PARAMS, TimeValue.MAX_VALUE, previouslyLicensed);
-        }
-
-        public Request(
-            String modelId,
-            List<Map<String, Object>> objectsToInfer,
-            InferenceConfigUpdate inferenceConfig,
-            TimeValue timeout,
-            boolean previouslyLicensed
-        ) {
-            this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
-            this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, DOCS.getPreferredName()));
-            this.update = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
-            this.previouslyLicensed = previouslyLicensed;
-            this.timeout = timeout;
-        }
-
-        public Request(
+        public static Request forDocs(
             String modelId,
-            List<Map<String, Object>> objectsToInfer,
-            InferenceConfigUpdate inferenceConfig,
+            List<Map<String, Object>> docs,
+            InferenceConfigUpdate update,
             boolean previouslyLicensed
         ) {
-            this(modelId, objectsToInfer, inferenceConfig, TimeValue.MAX_VALUE, previouslyLicensed);
+            return new Request(
+                ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID),
+                update,
+                ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS),
+                null,
+                DEFAULT_TIMEOUT,
+                previouslyLicensed
+            );
         }
 
-        public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfigUpdate update, boolean previouslyLicensed) {
-            this(
+        public static Request forTextInput(String modelId, InferenceConfigUpdate update, List<String> textInput) {
+            return new Request(
                 modelId,
-                Collections.singletonList(ExceptionsHelper.requireNonNull(objectToInfer, DOCS.getPreferredName())),
                 update,
-                TimeValue.MAX_VALUE,
-                previouslyLicensed
+                List.of(),
+                ExceptionsHelper.requireNonNull(textInput, "inference text input"),
+                DEFAULT_TIMEOUT,
+                false
             );
         }
 
+        Request(
+            String modelId,
+            InferenceConfigUpdate inferenceConfigUpdate,
+            List<Map<String, Object>> docs,
+            List<String> textInput,
+            TimeValue inferenceTimeout,
+            boolean previouslyLicensed
+        ) {
+            this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+            this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(docs, DOCS.getPreferredName()));
+            this.update = ExceptionsHelper.requireNonNull(inferenceConfigUpdate, "inference_config");
+            this.textInput = textInput;
+            this.previouslyLicensed = previouslyLicensed;
+            this.inferenceTimeout = inferenceTimeout;
+        }
+
         public Request(StreamInput in) throws IOException {
             super(in);
             this.modelId = in.readString();
@@ -126,9 +137,22 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             this.update = in.readNamedWriteable(InferenceConfigUpdate.class);
             this.previouslyLicensed = in.readBoolean();
             if (in.getVersion().onOrAfter(Version.V_8_3_0)) {
-                this.timeout = in.readTimeValue();
+                this.inferenceTimeout = in.readTimeValue();
             } else {
-                this.timeout = TimeValue.MAX_VALUE;
+                this.inferenceTimeout = TimeValue.MAX_VALUE;
+            }
+            if (in.getVersion().onOrAfter(Version.V_8_7_0)) {
+                textInput = in.readOptionalStringList();
+            } else {
+                textInput = null;
+            }
+        }
+
+        public int numberOfDocuments() {
+            if (textInput != null) {
+                return textInput.size();
+            } else {
+                return objectsToInfer.size();
             }
         }
 
@@ -140,6 +164,10 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             return objectsToInfer;
         }
 
+        public List<String> getTextInput() {
+            return textInput;
+        }
+
         public InferenceConfigUpdate getUpdate() {
             return update;
         }
@@ -148,8 +176,12 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             return previouslyLicensed;
         }
 
-        public TimeValue getTimeout() {
-            return timeout;
+        public TimeValue getInferenceTimeout() {
+            return inferenceTimeout;
+        }
+
+        public void setInferenceTimeout(TimeValue inferenceTimeout) {
+            this.inferenceTimeout = inferenceTimeout;
         }
 
         @Override
@@ -165,7 +197,10 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             out.writeNamedWriteable(update);
             out.writeBoolean(previouslyLicensed);
             if (out.getVersion().onOrAfter(Version.V_8_3_0)) {
-                out.writeTimeValue(timeout);
+                out.writeTimeValue(inferenceTimeout);
+            }
+            if (out.getVersion().onOrAfter(Version.V_8_7_0)) {
+                out.writeOptionalStringCollection(textInput);
             }
         }
 
@@ -177,8 +212,9 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             return Objects.equals(modelId, that.modelId)
                 && Objects.equals(update, that.update)
                 && Objects.equals(previouslyLicensed, that.previouslyLicensed)
-                && Objects.equals(timeout, that.timeout)
-                && Objects.equals(objectsToInfer, that.objectsToInfer);
+                && Objects.equals(inferenceTimeout, that.inferenceTimeout)
+                && Objects.equals(objectsToInfer, that.objectsToInfer)
+                && Objects.equals(textInput, that.textInput);
         }
 
         @Override
@@ -188,7 +224,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
 
         @Override
         public int hashCode() {
-            return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, timeout);
+            return Objects.hash(modelId, objectsToInfer, update, previouslyLicensed, inferenceTimeout, textInput);
         }
 
         public static class Builder {
@@ -196,7 +232,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             private String modelId;
             private List<Map<String, Object>> docs;
             private TimeValue timeout;
-            private InferenceConfigUpdate update;
+            private InferenceConfigUpdate update = new EmptyConfigUpdate();
 
             private Builder() {}
 
@@ -229,7 +265,7 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
             }
 
             public Request build() {
-                return new Request(modelId, docs, update, timeout, false);
+                return new Request(modelId, update, docs, null, timeout, false);
             }
         }
 
@@ -302,12 +338,12 @@ public class InferModelAction extends ActionType<InferModelAction.Response> {
         }
 
         public static class Builder {
-            private List<InferenceResults> inferenceResults;
+            private List<InferenceResults> inferenceResults = new ArrayList<>();
             private String modelId;
             private boolean isLicensed;
 
-            public Builder setInferenceResults(List<InferenceResults> inferenceResults) {
-                this.inferenceResults = inferenceResults;
+            public Builder addInferenceResults(List<InferenceResults> inferenceResults) {
+                this.inferenceResults.addAll(inferenceResults);
                 return this;
             }
 

+ 62 - 48
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -44,7 +44,14 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
 
     public static final InferTrainedModelDeploymentAction INSTANCE = new InferTrainedModelDeploymentAction();
 
-    // TODO Review security level
+    /**
+     * Do not call this action directly; use InferModelAction instead,
+     * which will perform various checks and set the node the request
+     * should execute on.
+     *
+     * The action is poorly named: it was once publicly accessible and
+     * exposed through a REST API, but now it _must_ only be called internally.
+     */
     public static final String NAME = "cluster:monitor/xpack/ml/trained_models/deployment/infer";
 
     public InferTrainedModelDeploymentAction() {
@@ -94,22 +101,38 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
         // textInput added for uses that accept a query string
         // and do not know which field the model expects to find its
         // input and so cannot construct a document.
-        private final String textInput;
+        private final List<String> textInput;
 
-        public Request(String modelId, InferenceConfigUpdate update, List<Map<String, Object>> docs, TimeValue inferenceTimeout) {
-            this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID);
-            this.docs = ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS);
-            this.update = update;
-            this.inferenceTimeout = inferenceTimeout;
-            this.textInput = null;
+        public static Request forDocs(
+            String modelId,
+            InferenceConfigUpdate update,
+            List<Map<String, Object>> docs,
+            TimeValue inferenceTimeout
+        ) {
+            return new Request(
+                ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID),
+                update,
+                ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS),
+                null,
+                false,
+                inferenceTimeout
+            );
         }
 
-        public Request(String modelId, InferenceConfigUpdate update, String textInput, TimeValue inferenceTimeout) {
-            this.modelId = ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID);
-            this.docs = List.of();
-            this.textInput = ExceptionsHelper.requireNonNull(textInput, "inference text input");
-            this.update = update;
-            this.inferenceTimeout = inferenceTimeout;
+        public static Request forTextInput(
+            String modelId,
+            InferenceConfigUpdate update,
+            List<String> textInput,
+            TimeValue inferenceTimeout
+        ) {
+            return new Request(
+                ExceptionsHelper.requireNonNull(modelId, InferModelAction.Request.MODEL_ID),
+                update,
+                List.of(),
+                ExceptionsHelper.requireNonNull(textInput, "inference text input"),
+                false,
+                inferenceTimeout
+            );
         }
 
         // for tests
@@ -117,7 +140,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             String modelId,
             InferenceConfigUpdate update,
             List<Map<String, Object>> docs,
-            String textInput,
+            List<String> textInput,
             boolean skipQueue,
             TimeValue inferenceTimeout
         ) {
@@ -139,7 +162,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 skipQueue = in.readBoolean();
             }
             if (in.getVersion().onOrAfter(Version.V_8_7_0)) {
-                textInput = in.readOptionalString();
+                textInput = in.readOptionalStringList();
             } else {
                 textInput = null;
             }
@@ -153,7 +176,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return docs;
         }
 
-        public String getTextInput() {
+        public List<String> getTextInput() {
             return textInput;
         }
 
@@ -196,10 +219,6 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 if (docs.isEmpty() && textInput == null) {
                     validationException = addValidationError("at least one document is required ", validationException);
                 }
-                if (docs.size() > 1) {
-                    // TODO support multiple docs
-                    validationException = addValidationError("multiple documents are not supported", validationException);
-                }
             }
             return validationException;
         }
@@ -215,7 +234,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 out.writeBoolean(skipQueue);
             }
             if (out.getVersion().onOrAfter(Version.V_8_7_0)) {
-                out.writeOptionalString(textInput);
+                out.writeOptionalStringCollection(textInput);
             }
         }
 
@@ -254,7 +273,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             private TimeValue timeout;
             private InferenceConfigUpdate update;
             private boolean skipQueue = false;
-            private String textInput;
+            private List<String> textInput;
 
             private Builder() {}
 
@@ -282,7 +301,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName()));
             }
 
-            public Builder setTextInput(String textInput) {
+            public Builder setTextInput(List<String> textInput) {
                 this.textInput = textInput;
                 return this;
             }
@@ -300,50 +319,45 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
 
     public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject {
 
-        private final InferenceResults results;
-        private long tookMillis;
+        private final List<InferenceResults> results;
 
-        public Response(InferenceResults result, long tookMillis) {
+        public Response(List<InferenceResults> results) {
             super(Collections.emptyList(), Collections.emptyList());
-            this.results = Objects.requireNonNull(result);
-            this.tookMillis = tookMillis;
+            this.results = Objects.requireNonNull(results);
         }
 
         public Response(StreamInput in) throws IOException {
             super(in);
-            results = in.readNamedWriteable(InferenceResults.class);
+
+            // Multiple results added in 8.7.0
             if (in.getVersion().onOrAfter(Version.V_8_7_0)) {
-                tookMillis = in.readVLong();
+                results = in.readNamedWriteableList(InferenceResults.class);
+            } else {
+                results = List.of(in.readNamedWriteable(InferenceResults.class));
             }
         }
 
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            results.toXContent(builder, params);
-            builder.endObject();
-            return builder;
-        }
-
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
-            out.writeNamedWriteable(results);
+
             if (out.getVersion().onOrAfter(Version.V_8_7_0)) {
-                out.writeVLong(tookMillis);
+                out.writeNamedWriteableList(results);
+            } else {
+                out.writeNamedWriteable(results.get(0));
             }
         }
 
-        public InferenceResults getResults() {
+        public List<InferenceResults> getResults() {
             return results;
         }
 
-        public long getTookMillis() {
-            return tookMillis;
-        }
-
-        public void setTookMillis(long tookMillis) {
-            this.tookMillis = tookMillis;
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            results.get(0).toXContent(builder, params);
+            builder.endObject();
+            return builder;
         }
     }
 }
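
The `Response` serialization above keeps backwards compatibility by gating the wire format on the peer's version: nodes on 8.7.0 or later exchange the full results list, while older nodes only understand a single result. The pattern can be sketched in plain Java; `DataOutputStream` stands in for Elasticsearch's `StreamOutput`, and the class name and version constants are illustrative, not the actual API:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class VersionGatedResults {

    // Illustrative stand-in for Version.V_8_7_0.
    public static final int V_8_7_0 = 8_07_00;

    /** Writes results in the format the receiving peer understands. */
    public static byte[] writeResults(List<String> results, int peerVersion) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            if (peerVersion >= V_8_7_0) {
                // New format: length-prefixed list of results.
                out.writeInt(results.size());
                for (String r : results) {
                    out.writeUTF(r);
                }
            } else {
                // Old peers expect exactly one result; send only the first.
                out.writeUTF(results.get(0));
            }
        }
        return bytes.toByteArray();
    }

    /** Reads results written by a peer of the given version. */
    public static List<String> readResults(byte[] data, int peerVersion) throws IOException {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            if (peerVersion >= V_8_7_0) {
                int size = in.readInt();
                List<String> results = new ArrayList<>(size);
                for (int i = 0; i < size; i++) {
                    results.add(in.readUTF());
                }
                return results;
            }
            // Old format carries a single result; wrap it in a list.
            return List.of(in.readUTF());
        }
    }
}
```

Note that on the old wire format any results beyond the first are silently dropped, which is why the PR also disables the bwc tests until the change is backported.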

+ 32 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java

@@ -14,6 +14,7 @@ import org.elasticsearch.cluster.SimpleDiffable;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
@@ -177,7 +178,7 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
             .toArray(String[]::new);
     }
 
-    public Optional<String> selectRandomStartedNodeWeighedOnAllocations() {
+    public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) {
         List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
         List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
         int allocationSum = 0;
@@ -189,18 +190,42 @@ public class TrainedModelAssignment implements SimpleDiffable<TrainedModelAssign
             }
         }
 
+        if (nodeIds.isEmpty()) {
+            return List.of();
+        }
+
         if (allocationSum == 0) {
             // If we are in a mixed cluster where there are assignments prior to introducing allocation distribution
             // we could have a zero-sum of allocations. We fall back to returning a random started node.
-            return nodeIds.isEmpty() ? Optional.empty() : Optional.of(nodeIds.get(Randomness.get().nextInt(nodeIds.size())));
+            int[] counts = new int[nodeIds.size()];
+            for (int i = 0; i < numberOfRequests; i++) {
+                counts[Randomness.get().nextInt(nodeIds.size())]++;
+            }
+
+            var nodeCounts = new ArrayList<Tuple<String, Integer>>();
+            for (int i = 0; i < counts.length; i++) {
+                nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
+            }
+            return nodeCounts;
+        }
+
+        int[] counts = new int[nodeIds.size()];
+        var randomIter = Randomness.get().ints(numberOfRequests, 1, allocationSum + 1).iterator();
+        for (int i = 0; i < numberOfRequests; i++) {
+            int randomInt = randomIter.nextInt();
+            int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
+            if (nodeIndex < 0) {
+                nodeIndex = -nodeIndex - 1;
+            }
+
+            counts[nodeIndex]++;
         }
 
-        int randomInt = Randomness.get().ints(1, 1, allocationSum + 1).iterator().nextInt();
-        int nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
-        if (nodeIndex < 0) {
-            nodeIndex = -nodeIndex - 1;
+        var nodeCounts = new ArrayList<Tuple<String, Integer>>();
+        for (int i = 0; i < counts.length; i++) {
+            nodeCounts.add(new Tuple<>(nodeIds.get(i), counts[i]));
         }
-        return Optional.of(nodeIds.get(nodeIndex));
+        return nodeCounts;
     }
 
     public Optional<String> getReason() {
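
The weighted routing introduced above can be sketched as a standalone program (class and method names here are illustrative, and a plain map of allocations per node stands in for the routing table): each of the N requests in a batch is assigned to a started node with probability proportional to that node's allocation count, by drawing a random integer in `[1, allocationSum]` and binary-searching the cumulative allocation list; a zero allocation sum (mixed cluster) falls back to a uniform random node.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class WeightedNodeSelection {

    /** Returns how many of the requests each node should receive, weighted by allocations. */
    public static Map<String, Integer> selectNodesForRequests(
        Map<String, Integer> allocationsPerNode, int numberOfRequests, Random random
    ) {
        List<String> nodeIds = new ArrayList<>();
        List<Integer> cumulativeAllocations = new ArrayList<>();
        int allocationSum = 0;
        for (Map.Entry<String, Integer> entry : allocationsPerNode.entrySet()) {
            nodeIds.add(entry.getKey());
            allocationSum += entry.getValue();
            cumulativeAllocations.add(allocationSum);
        }

        Map<String, Integer> counts = new HashMap<>();
        if (nodeIds.isEmpty()) {
            return counts;
        }
        for (String nodeId : nodeIds) {
            counts.put(nodeId, 0);
        }

        for (int i = 0; i < numberOfRequests; i++) {
            int nodeIndex;
            if (allocationSum == 0) {
                // Zero-sum fallback: pick a started node uniformly at random.
                nodeIndex = random.nextInt(nodeIds.size());
            } else {
                // Draw in [1, allocationSum] and locate the owning node via
                // binary search over the cumulative allocation counts.
                int randomInt = random.nextInt(allocationSum) + 1;
                nodeIndex = Collections.binarySearch(cumulativeAllocations, randomInt);
                if (nodeIndex < 0) {
                    nodeIndex = -nodeIndex - 1;
                }
            }
            counts.merge(nodeIds.get(nodeIndex), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        Map<String, Integer> allocations = new LinkedHashMap<>();
        allocations.put("node-a", 3);
        allocations.put("node-b", 1);
        // node-a should receive roughly three times as many requests as node-b.
        System.out.println(selectNodesForRequests(allocations, 1000, new Random(42)));
    }
}
```

Selecting all N target nodes in one pass is what lets `TransportInternalInferModelAction` split a multi-document request into per-node sub-requests instead of routing the whole batch to a single allocation.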

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdate.java

@@ -28,6 +28,8 @@ public class TextEmbeddingConfigUpdate extends NlpConfigUpdate implements NamedX
 
     public static final String NAME = TextEmbeddingConfig.NAME;
 
+    public static TextEmbeddingConfigUpdate EMPTY_INSTANCE = new TextEmbeddingConfigUpdate(null, null);
+
     public static TextEmbeddingConfigUpdate fromMap(Map<String, Object> map) {
         Map<String, Object> options = new HashMap<>(map);
         String resultsField = (String) options.remove(RESULTS_FIELD.getPreferredName());

+ 21 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassifica
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
@@ -46,14 +47,17 @@ public class InferModelActionRequestTests extends AbstractBWCWireSerializationTe
     @Override
     protected Request createTestInstance() {
         return randomBoolean()
-            ? new Request(
+            ? Request.forDocs(
                 randomAlphaOfLength(10),
                 Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
                 randomInferenceConfigUpdate(),
-                TimeValue.parseTimeValue(randomTimeValue(), null, "test"),
                 randomBoolean()
             )
-            : new Request(randomAlphaOfLength(10), randomMap(), randomInferenceConfigUpdate(), randomBoolean());
+            : Request.forTextInput(
+                randomAlphaOfLength(10),
+                randomInferenceConfigUpdate(),
+                Arrays.asList(generateRandomStringArray(3, 5, false))
+            );
     }
 
     private static InferenceConfigUpdate randomInferenceConfigUpdate() {
@@ -115,20 +119,27 @@ public class InferModelActionRequestTests extends AbstractBWCWireSerializationTe
         } else {
             adjustedUpdate = currentUpdate;
         }
-        return version.before(Version.V_8_3_0)
-            ? new Request(
+
+        if (version.before(Version.V_8_3_0)) {
+            return new Request(
                 instance.getModelId(),
-                instance.getObjectsToInfer(),
                 adjustedUpdate,
+                instance.getObjectsToInfer(),
+                null,
                 TimeValue.MAX_VALUE,
                 instance.isPreviouslyLicensed()
-            )
-            : new Request(
+            );
+        } else if (version.before(Version.V_8_7_0)) {
+            return new Request(
                 instance.getModelId(),
-                instance.getObjectsToInfer(),
                 adjustedUpdate,
-                instance.getTimeout(),
+                instance.getObjectsToInfer(),
+                null,
+                instance.getInferenceTimeout(),
                 instance.isPreviouslyLicensed()
             );
+        }
+
+        return instance;
     }
 }
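
The `mutateInstanceForVersion` changes above strip fields that older wire formats cannot carry (the timeout before 8.3.0, the text input before 8.7.0). A self-contained sketch of that version-gating idea, using plain `DataOutputStream` and hypothetical field names rather than the actual Elasticsearch `StreamOutput` API:

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Sketch: only write fields the destination version understands, and
// substitute defaults when reading from an older source version.
class VersionedRequest {
    static final int V_8_3_0 = 8_03_00;
    static final int V_8_7_0 = 8_07_00;

    final String modelId;
    final long timeoutMillis;  // hypothetical field added in 8.3.0
    final boolean textInput;   // hypothetical field added in 8.7.0

    VersionedRequest(String modelId, long timeoutMillis, boolean textInput) {
        this.modelId = modelId;
        this.timeoutMillis = timeoutMillis;
        this.textInput = textInput;
    }

    void writeTo(DataOutputStream out, int destVersion) throws IOException {
        out.writeUTF(modelId);
        if (destVersion >= V_8_3_0) {
            out.writeLong(timeoutMillis);
        }
        if (destVersion >= V_8_7_0) {
            out.writeBoolean(textInput);
        }
    }

    static VersionedRequest readFrom(DataInputStream in, int sourceVersion) throws IOException {
        String id = in.readUTF();
        long timeout = sourceVersion >= V_8_3_0 ? in.readLong() : Long.MAX_VALUE;
        boolean text = sourceVersion >= V_8_7_0 && in.readBoolean();
        return new VersionedRequest(id, timeout, text);
    }
}
```

A BWC round-trip at an old version is expected to drop the newer fields, which is exactly what the mutated test instances model.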

+ 4 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java

@@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpd
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
@@ -37,10 +38,10 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerial
         boolean createQueryStringRequest = randomBoolean();
 
         if (createQueryStringRequest) {
-            return new InferTrainedModelDeploymentAction.Request(
+            return InferTrainedModelDeploymentAction.Request.forTextInput(
                 randomAlphaOfLength(4),
                 randomBoolean() ? null : randomInferenceConfigUpdate(),
-                randomAlphaOfLength(6),
+                Arrays.asList(generateRandomStringArray(4, 7, false)),
                 randomBoolean() ? null : TimeValue.parseTimeValue(randomTimeValue(), "timeout")
             );
         } else {
@@ -49,7 +50,7 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerial
                 () -> randomMap(1, 3, () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7)))
             );
 
-            return new InferTrainedModelDeploymentAction.Request(
+            return InferTrainedModelDeploymentAction.Request.forDocs(
                 randomAlphaOfLength(4),
                 randomBoolean() ? null : randomInferenceConfigUpdate(),
                 docs,

+ 12 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentResponseTests.java

@@ -16,6 +16,8 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResultsTests;
 import org.junit.Before;
 
+import java.util.List;
+
 public class InferTrainedModelDeploymentResponseTests extends AbstractBWCWireSerializationTestCase<
     InferTrainedModelDeploymentAction.Response> {
 
@@ -45,7 +47,14 @@ public class InferTrainedModelDeploymentResponseTests extends AbstractBWCWireSer
 
     @Override
     protected InferTrainedModelDeploymentAction.Response createTestInstance() {
-        return new InferTrainedModelDeploymentAction.Response(TextEmbeddingResultsTests.createRandomResults(), randomLongBetween(1, 200));
+        return new InferTrainedModelDeploymentAction.Response(
+            List.of(
+                TextEmbeddingResultsTests.createRandomResults(),
+                TextEmbeddingResultsTests.createRandomResults(),
+                TextEmbeddingResultsTests.createRandomResults(),
+                TextEmbeddingResultsTests.createRandomResults()
+            )
+        );
     }
 
     @Override
@@ -53,8 +62,8 @@ public class InferTrainedModelDeploymentResponseTests extends AbstractBWCWireSer
         InferTrainedModelDeploymentAction.Response instance,
         Version version
     ) {
-        if (version.before(Version.V_8_6_0)) {
-            return new InferTrainedModelDeploymentAction.Response(instance.getResults(), 0);
+        if (version.before(Version.V_8_7_0)) {
+            return new InferTrainedModelDeploymentAction.Response(instance.getResults().subList(0, 1));
         }
 
         return instance;

+ 29 - 29
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.test.AbstractXContentSerializingTestCase;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -19,17 +20,16 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskPar
 import org.elasticsearch.xpack.core.ml.stats.CountAccumulator;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
@@ -157,62 +157,62 @@ public class TrainedModelAssignmentTests extends AbstractXContentSerializingTest
         );
     }
 
-    public void testSelectRandomStartedNodeWeighedOnAllocations_GivenNoStartedAllocations() {
+    public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoStartedAllocations() {
         TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
         builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, ""));
         builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, ""));
         TrainedModelAssignment assignment = builder.build();
 
-        assertThat(assignment.selectRandomStartedNodeWeighedOnAllocations().isEmpty(), is(true));
+        assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1).isEmpty(), is(true));
     }
 
-    public void testSelectRandomStartedNodeWeighedOnAllocations_GivenSingleStartedNode() {
+    public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() {
         TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
         builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
         TrainedModelAssignment assignment = builder.build();
 
-        Optional<String> node = assignment.selectRandomStartedNodeWeighedOnAllocations();
+        var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1);
 
-        assertThat(node.isPresent(), is(true));
-        assertThat(node.get(), equalTo("node-1"));
+        assertThat(nodes, hasSize(1));
+        assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1)));
     }
 
-    public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodes() {
+    public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() {
         TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6));
         builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
         builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, ""));
         builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, ""));
         TrainedModelAssignment assignment = builder.build();
 
-        final long selectionCount = 10000;
+        final int selectionCount = 10000;
         final CountAccumulator countsPerNodeAccumulator = new CountAccumulator();
-        for (int i = 0; i < selectionCount; i++) {
-            Optional<String> node = assignment.selectRandomStartedNodeWeighedOnAllocations();
-            assertThat(node.isPresent(), is(true));
-            countsPerNodeAccumulator.add(node.get(), 1L);
-        }
+        var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount);
 
-        Map<String, Long> countsPerNode = countsPerNodeAccumulator.asMap();
-        assertThat(countsPerNode.keySet(), contains("node-1", "node-2", "node-3"));
-        assertThat(countsPerNode.get("node-1") + countsPerNode.get("node-2") + countsPerNode.get("node-3"), equalTo(selectionCount));
+        assertThat(nodes, hasSize(3));
+        assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount));
+        var asMap = new HashMap<String, Integer>();
+        for (var node : nodes) {
+            asMap.put(node.v1(), node.v2());
+        }
 
-        assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-1"), selectionCount, 1.0 / 6.0, 0.2);
-        assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-2"), selectionCount, 2.0 / 6.0, 0.2);
-        assertValueWithinPercentageOfExpectedRatio(countsPerNode.get("node-3"), selectionCount, 3.0 / 6.0, 0.2);
+        assertValueWithinPercentageOfExpectedRatio(asMap.get("node-1"), selectionCount, 1.0 / 6.0, 0.2);
+        assertValueWithinPercentageOfExpectedRatio(asMap.get("node-2"), selectionCount, 2.0 / 6.0, 0.2);
+        assertValueWithinPercentageOfExpectedRatio(asMap.get("node-3"), selectionCount, 3.0 / 6.0, 0.2);
     }
 
-    public void testSelectRandomStartedNodeWeighedOnAllocations_GivenMultipleStartedNodesWithZeroAllocations() {
+    public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodesWithZeroAllocations() {
         TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6));
         builder.addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
         builder.addRoutingEntry("node-2", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
         builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
         TrainedModelAssignment assignment = builder.build();
-        final long selectionCount = 1000;
-        Set<String> selectedNodes = new HashSet<>();
-        for (int i = 0; i < selectionCount; i++) {
-            Optional<String> selectedNode = assignment.selectRandomStartedNodeWeighedOnAllocations();
-            assertThat(selectedNode.isPresent(), is(true));
-            selectedNodes.add(selectedNode.get());
+        final int selectionCount = 1000;
+        var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount);
+        assertThat(nodeCounts, hasSize(3));
+
+        var selectedNodes = new HashSet<String>();
+        for (var node : nodeCounts) {
+            selectedNodes.add(node.v1());
         }
 
         assertThat(selectedNodes, contains("node-1", "node-2", "node-3"));
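
The assertions above check that the per-node request counts track each node's share of allocations, and that nodes with zero allocations still get a uniform fallback. A standalone sketch of that weighted split (hypothetical names; the real logic lives in `TrainedModelAssignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests`):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

class WeightedNodeSelection {

    // Split nRequests among nodes, each node picked with probability
    // proportional to its allocation count. When every node has zero
    // allocations, fall back to a uniform pick.
    static Map<String, Integer> distribute(Map<String, Integer> allocationsPerNode, int nRequests, Random random) {
        List<String> nodes = new ArrayList<>(allocationsPerNode.keySet());
        int totalAllocations = allocationsPerNode.values().stream().mapToInt(Integer::intValue).sum();
        Map<String, Integer> requestsPerNode = new HashMap<>();
        for (int i = 0; i < nRequests; i++) {
            String selected = nodes.get(nodes.size() - 1);
            if (totalAllocations == 0) {
                selected = nodes.get(random.nextInt(nodes.size()));
            } else {
                // Walk the cumulative allocation counts until the random
                // pick falls inside a node's weight band.
                int pick = random.nextInt(totalAllocations);
                int cumulative = 0;
                for (String node : nodes) {
                    cumulative += allocationsPerNode.get(node);
                    if (pick < cumulative) {
                        selected = node;
                        break;
                    }
                }
            }
            requestsPerNode.merge(selected, 1, Integer::sum);
        }
        return requestsPerNode;
    }
}
```

With allocations of 1, 2 and 3, a large request count splits roughly 1/6, 2/6 and 3/6 across the nodes, which is the ratio the 0.2-tolerance assertions above verify.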

+ 77 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -21,8 +21,10 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Base64;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -419,6 +421,79 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         assertThat(ex.getMessage(), containsString("Could not find trained model [missing_model]"));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testInferWithMultipleDocs() throws IOException {
+        String modelId = "infer_multi_docs";
+        // Use the text embedding model from SemanticSearchIT so that each
+        // response can be linked to the originating request. The test
+        // verifies that responses are returned in the same order as the requests.
+        createTextEmbeddingModel(modelId);
+        putModelDefinition(modelId, SemanticSearchIT.BASE_64_ENCODED_MODEL, SemanticSearchIT.RAW_MODEL_SIZE);
+        putVocabulary(
+            List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
+            modelId
+        );
+        startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
+
+        List<String> inputs = List.of(
+            "my words",
+            "the machine is leaking",
+            "washing machine",
+            "these are my words",
+            "the octopus comforter smells",
+            "the octopus comforter is leaking",
+            "washing machine smells"
+        );
+
+        List<List<Double>> expectedEmbeddings = new ArrayList<>();
+
+        // Generate the text embeddings one at a time using the _infer API
+        // then index them for search
+        for (var input : inputs) {
+            Response inference = infer(input, modelId);
+            List<Map<String, Object>> responseMap = (List<Map<String, Object>>) entityAsMap(inference).get("inference_results");
+            Map<String, Object> inferenceResult = responseMap.get(0);
+            List<Double> embedding = (List<Double>) inferenceResult.get("predicted_value");
+            expectedEmbeddings.add(embedding);
+        }
+
+        // Now do the same with all documents sent at once
+        var docsBuilder = new StringBuilder();
+        int numInputs = inputs.size();
+        for (int i = 0; i < numInputs - 1; i++) {
+            docsBuilder.append("{\"input\":\"").append(inputs.get(i)).append("\"},");
+        }
+        docsBuilder.append("{\"input\":\"").append(inputs.get(numInputs - 1)).append("\"}");
+
+        {
+            Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {  "docs": [%s] }
+                """, docsBuilder));
+            Response response = client().performRequest(request);
+            var responseMap = entityAsMap(response);
+            List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
+            assertThat(inferenceResults, hasSize(numInputs));
+
+            // Check the result order matches the input order by comparing
+            // them to the pre-calculated embeddings
+            for (int i = 0; i < numInputs; i++) {
+                List<Double> embedding = (List<Double>) inferenceResults.get(i).get("predicted_value");
+                assertArrayEquals(expectedEmbeddings.get(i).toArray(), embedding.toArray());
+            }
+        }
+        {
+            // the deprecated deployment/_infer endpoint does not support multiple docs
+            Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
+            request.setJsonEntity(String.format(Locale.ROOT, """
+                {  "docs": [%s] }
+                """, docsBuilder));
+            Exception ex = expectThrows(Exception.class, () -> client().performRequest(request));
+            assertThat(ex.getMessage(), containsString("multiple documents are not supported"));
+        }
+    }
+
     public void testGetPytorchModelWithDefinition() throws IOException {
         String model = "should-fail-get";
         createPassThroughModel(model);
@@ -476,7 +551,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
         assertThat(
             response,
             allOf(
-                containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."),
+                containsString("Model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."),
                 containsString("error"),
                 not(containsString("warning"))
             )
@@ -499,7 +574,7 @@ public class PyTorchModelIT extends PyTorchModelRestTestCase {
                   }
             """);
         Exception ex = expectThrows(Exception.class, () -> client().performRequest(request));
-        assertThat(ex.getMessage(), containsString("Trained model [not-deployed] is not deployed."));
+        assertThat(ex.getMessage(), containsString("Model [not-deployed] must be deployed to use."));
     }
 
     public void testTruncation() throws IOException {

+ 10 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/SemanticSearchIT.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.integration;
 
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
+import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
@@ -19,6 +20,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 
 /**
@@ -277,6 +279,14 @@ public class SemanticSearchIT extends PyTorchModelRestTestCase {
         }
     }
 
+    public void testSearchWithMissingModel() throws IOException {
+        String modelId = "missing-model";
+        String indexName = modelId + "-index";
+
+        var e = expectThrows(ResponseException.class, () -> semanticSearch(indexName, "the machine is leaking", modelId, "embedding"));
+        assertThat(e.getMessage(), containsString("Could not find trained model [missing-model]"));
+    }
+
     private void createVectorSearchIndex(String indexName) throws IOException {
         Request createIndex = new Request("PUT", "/" + indexName);
         createIndex.setJsonEntity("""

+ 4 - 4
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java

@@ -664,7 +664,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
         client().execute(
             InferModelAction.INSTANCE,
-            new InferModelAction.Request(
+            InferModelAction.Request.forDocs(
                 modelId,
                 Collections.singletonList(Collections.emptyMap()),
                 RegressionConfigUpdate.EMPTY_PARAMS,
@@ -685,7 +685,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
             client().execute(
                 InferModelAction.INSTANCE,
-                new InferModelAction.Request(
+                InferModelAction.Request.forDocs(
                     modelId,
                     Collections.singletonList(Collections.emptyMap()),
                     RegressionConfigUpdate.EMPTY_PARAMS,
@@ -701,7 +701,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         inferModelSuccess = PlainActionFuture.newFuture();
         client().execute(
             InferModelAction.INSTANCE,
-            new InferModelAction.Request(
+            InferModelAction.Request.forDocs(
                 modelId,
                 Collections.singletonList(Collections.emptyMap()),
                 RegressionConfigUpdate.EMPTY_PARAMS,
@@ -721,7 +721,7 @@ public class MachineLearningLicensingIT extends BaseMlIntegTestCase {
         PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
         client().execute(
             InferModelAction.INSTANCE,
-            new InferModelAction.Request(
+            InferModelAction.Request.forDocs(
                 modelId,
                 Collections.singletonList(Collections.emptyMap()),
                 RegressionConfigUpdate.EMPTY_PARAMS,

+ 15 - 10
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

@@ -171,14 +171,14 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         });
 
         // Test regression
-        InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, RegressionConfigUpdate.EMPTY_PARAMS, true);
+        InferModelAction.Request request = InferModelAction.Request.forDocs(modelId1, toInfer, RegressionConfigUpdate.EMPTY_PARAMS, true);
         InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
         assertThat(
             response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults) i).value()).collect(Collectors.toList()),
             contains(1.3, 1.25)
         );
 
-        request = new InferModelAction.Request(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true);
+        request = InferModelAction.Request.forDocs(modelId1, toInfer2, RegressionConfigUpdate.EMPTY_PARAMS, true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
         assertThat(
             response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults) i).value()).collect(Collectors.toList()),
@@ -186,7 +186,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         );
 
         // Test classification
-        request = new InferModelAction.Request(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true);
+        request = InferModelAction.Request.forDocs(modelId2, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
         assertThat(
             response.getInferenceResults()
@@ -197,7 +197,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         );
 
         // Get top classes
-        request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
+        request = InferModelAction.Request.forDocs(modelId2, toInfer, new ClassificationConfigUpdate(2, null, null, null, null), true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
 
         ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults) response.getInferenceResults()
@@ -220,7 +220,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         );
 
         // Test that top classes restrict the number returned
-        request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null, null), true);
+        request = InferModelAction.Request.forDocs(modelId2, toInfer2, new ClassificationConfigUpdate(1, null, null, null, null), true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
 
         classificationInferenceResults = (ClassificationInferenceResults) response.getInferenceResults().get(0);
@@ -319,7 +319,12 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         });
 
         // Test regression
-        InferModelAction.Request request = new InferModelAction.Request(modelId, toInfer, ClassificationConfigUpdate.EMPTY_PARAMS, true);
+        InferModelAction.Request request = InferModelAction.Request.forDocs(
+            modelId,
+            toInfer,
+            ClassificationConfigUpdate.EMPTY_PARAMS,
+            true
+        );
         InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
         assertThat(
             response.getInferenceResults()
@@ -329,7 +334,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
             contains("option_0", "option_2")
         );
 
-        request = new InferModelAction.Request(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true);
+        request = InferModelAction.Request.forDocs(modelId, toInfer2, ClassificationConfigUpdate.EMPTY_PARAMS, true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
         assertThat(
             response.getInferenceResults()
@@ -340,7 +345,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         );
 
         // Get top classes
-        request = new InferModelAction.Request(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true);
+        request = InferModelAction.Request.forDocs(modelId, toInfer, new ClassificationConfigUpdate(3, null, null, null, null), true);
         response = client().execute(InferModelAction.INSTANCE, request).actionGet();
 
         ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults) response.getInferenceResults()
@@ -358,7 +363,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
 
     public void testInferMissingModel() {
         String model = "test-infer-missing-model";
-        InferModelAction.Request request = new InferModelAction.Request(
+        InferModelAction.Request request = InferModelAction.Request.forDocs(
             model,
             Collections.emptyList(),
             RegressionConfigUpdate.EMPTY_PARAMS,
@@ -404,7 +409,7 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
             }
         });
 
-        InferModelAction.Request request = new InferModelAction.Request(
+        InferModelAction.Request request = InferModelAction.Request.forDocs(
             modelId,
             toInferMissingField,
             RegressionConfigUpdate.EMPTY_PARAMS,

+ 21 - 86
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -7,38 +7,30 @@
 
 package org.elasticsearch.xpack.ml.action;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.GroupedActionListener;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
-import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
-import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
-import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
-import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
-import java.util.Optional;
-
-import static org.elasticsearch.core.Strings.format;
 
 public class TransportInferTrainedModelDeploymentAction extends TransportTasksAction<
     TrainedModelDeploymentTask,
@@ -46,10 +38,6 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
     InferTrainedModelDeploymentAction.Response,
     InferTrainedModelDeploymentAction.Response> {
 
-    private static final Logger logger = LogManager.getLogger(TransportInferTrainedModelDeploymentAction.class);
-
-    private final TrainedModelProvider provider;
-
     @Inject
     public TransportInferTrainedModelDeploymentAction(
         ClusterService clusterService,
@@ -67,63 +55,6 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
             InferTrainedModelDeploymentAction.Response::new,
             ThreadPool.Names.SAME
         );
-        this.provider = provider;
-    }
-
-    @Override
-    protected void doExecute(
-        Task task,
-        InferTrainedModelDeploymentAction.Request request,
-        ActionListener<InferTrainedModelDeploymentAction.Response> listener
-    ) {
-        TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId());
-        // Update the requests model ID if it's an alias
-        Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
-            .ifPresent(request::setModelId);
-        // We need to check whether there is at least an assigned task here, otherwise we cannot redirect to the
-        // node running the job task.
-        TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.assignmentForModelId(
-            clusterService.state(),
-            request.getModelId()
-        ).orElse(null);
-        if (assignment == null) {
-            // If there is no assignment, verify the model even exists so that we can provide a nicer error message
-            provider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), taskId, ActionListener.wrap(config -> {
-                if (config.getModelType() != TrainedModelType.PYTORCH) {
-                    listener.onFailure(
-                        ExceptionsHelper.badRequestException(
-                            "Only [pytorch] models are supported by _infer, provided model [{}] has type [{}]",
-                            config.getModelId(),
-                            config.getModelType()
-                        )
-                    );
-                    return;
-                }
-                String message = "Trained model [" + request.getModelId() + "] is not deployed";
-                listener.onFailure(ExceptionsHelper.conflictStatusException(message));
-            }, listener::onFailure));
-            return;
-        }
-        if (assignment.getAssignmentState() == AssignmentState.STOPPING) {
-            String message = "Trained model [" + request.getModelId() + "] is STOPPING";
-            listener.onFailure(ExceptionsHelper.conflictStatusException(message));
-            return;
-        }
-        logger.trace(() -> format("[%s] selecting node from routing table: %s", assignment.getModelId(), assignment.getNodeRoutingTable()));
-        assignment.selectRandomStartedNodeWeighedOnAllocations().ifPresentOrElse(node -> {
-            logger.trace(() -> format("[%s] selected node [%s]", assignment.getModelId(), node));
-            request.setNodes(node);
-            long start = System.currentTimeMillis();
-            super.doExecute(task, request, ActionListener.wrap(r -> {
-                r.setTookMillis(System.currentTimeMillis() - start);
-                listener.onResponse(r);
-            }, listener::onFailure));
-        }, () -> {
-            logger.trace(() -> format("[%s] model not allocated to any node [%s]", assignment.getModelId()));
-            listener.onFailure(
-                ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes")
-            );
-        });
     }
 
     @Override
@@ -144,6 +75,7 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
                 request.getModelId()
             );
         } else {
+            assert tasks.size() == 1;
             return tasks.get(0);
         }
     }
@@ -157,23 +89,26 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
     ) {
         assert actionTask instanceof CancellableTask : "task [" + actionTask + "] not cancellable";
 
-        NlpInferenceInput input;
+        var nlpInputs = new ArrayList<NlpInferenceInput>();
         if (request.getTextInput() != null) {
-            input = NlpInferenceInput.fromText(request.getTextInput());
+            for (var text : request.getTextInput()) {
+                nlpInputs.add(NlpInferenceInput.fromText(text));
+            }
         } else {
-            input = NlpInferenceInput.fromDoc(request.getDocs().get(0));
+            for (var doc : request.getDocs()) {
+                nlpInputs.add(NlpInferenceInput.fromDoc(doc));
+            }
         }
 
-        task.infer(
-            input,
-            request.getUpdate(),
-            request.isSkipQueue(),
-            request.getInferenceTimeout(),
-            actionTask,
-            ActionListener.wrap(
-                pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult, 0)),
-                listener::onFailure
-            )
+        // Multiple documents to infer on, wait for all results
+        ActionListener<Collection<InferenceResults>> collectingListener = ActionListener.wrap(
+            pyTorchResults -> { listener.onResponse(new InferTrainedModelDeploymentAction.Response(new ArrayList<>(pyTorchResults))); },
+            listener::onFailure
         );
+
+        GroupedActionListener<InferenceResults> groupedListener = new GroupedActionListener<>(nlpInputs.size(), collectingListener);
+        for (var input : nlpInputs) {
+            task.infer(input, request.getUpdate(), request.isSkipQueue(), request.getInferenceTimeout(), actionTask, groupedListener);
+        }
     }
 }
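The change above replaces the single-input call with a fan-out over all NLP inputs, joined by a `GroupedActionListener` that fires once every per-document inference has completed. A minimal count-down sketch of that fan-in pattern, using a hypothetical stand-in class rather than the Elasticsearch implementation:

```java
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

public class GroupedListenerDemo {
    // Hypothetical stand-in for org.elasticsearch.action.support.GroupedActionListener:
    // collects `expected` results and hands the whole batch to `onAllDone` exactly once.
    static final class GroupedListener<T> {
        private final int expected;
        private final AtomicInteger received = new AtomicInteger();
        private final ConcurrentLinkedQueue<T> results = new ConcurrentLinkedQueue<>();
        private final Consumer<List<T>> onAllDone;

        GroupedListener(int expected, Consumer<List<T>> onAllDone) {
            this.expected = expected;
            this.onAllDone = onAllDone;
        }

        void onResponse(T result) {
            results.add(result);
            // the last arrival, whatever its order, delivers the whole batch
            if (received.incrementAndGet() == expected) {
                onAllDone.accept(List.copyOf(results));
            }
        }
    }

    public static void main(String[] args) {
        List<String> inputs = List.of("doc-1", "doc-2", "doc-3");
        AtomicReference<List<String>> batch = new AtomicReference<>();
        GroupedListener<String> grouped = new GroupedListener<>(inputs.size(), batch::set);
        // each input is inferred independently; completion order does not matter
        for (String in : inputs) {
            grouped.onResponse("result-for-" + in);
        }
        assert batch.get() != null && batch.get().size() == 3;
    }
}
```

The real `GroupedActionListener` also propagates the first failure to the delegate; this sketch covers only the success path.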

+ 145 - 61
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -6,13 +6,15 @@
  */
 package org.elasticsearch.xpack.ml.action;
 
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
-import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
@@ -28,8 +30,11 @@ import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
@@ -38,9 +43,10 @@ import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
 
-import java.util.Collections;
-import java.util.Map;
+import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -132,19 +138,19 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
     ) {
         String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(clusterService.state()).getModelId(request.getModelId()))
             .orElse(request.getModelId());
-        if (isAllocatedModel(concreteModelId)) {
+
+        responseBuilder.setModelId(concreteModelId);
+
+        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
+
+        if (trainedModelAssignmentMetadata.isAssigned(concreteModelId)) {
             // It is important to use the resolved model ID here as the alias could change between transport calls.
-            inferAgainstAllocatedModel(request, concreteModelId, responseBuilder, parentTaskId, listener);
+            inferAgainstAllocatedModel(trainedModelAssignmentMetadata, request, concreteModelId, responseBuilder, parentTaskId, listener);
         } else {
             getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask) task, listener);
         }
     }
 
-    private boolean isAllocatedModel(String modelId) {
-        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(clusterService.state());
-        return trainedModelAssignmentMetadata.isAssigned(modelId);
-    }
-
     private void getModelAndInfer(
         Request request,
         Response.Builder responseBuilder,
@@ -169,75 +175,153 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
 
             typedChainTaskExecutor.execute(ActionListener.wrap(inferenceResultsInterfaces -> {
                 model.release();
-                listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).setModelId(model.getModelId()).build());
+                listener.onResponse(responseBuilder.addInferenceResults(inferenceResultsInterfaces).build());
             }, e -> {
                 model.release();
                 listener.onFailure(e);
             }));
-        }, listener::onFailure);
+        }, e -> {
+            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                listener.onFailure(e);
+                return;
+            }
 
+            // The model was found, check if a more relevant error message can be returned
+            trainedModelProvider.getTrainedModel(
+                request.getModelId(),
+                GetTrainedModelsAction.Includes.empty(),
+                parentTaskId,
+                ActionListener.wrap(trainedModelConfig -> {
+                    if (trainedModelConfig.getModelType() == TrainedModelType.PYTORCH) {
+                        // The PyTorch model cannot be allocated if we got here
+                        listener.onFailure(
+                            ExceptionsHelper.conflictStatusException(
+                                "Model ["
+                                    + request.getModelId()
+                                    + "] must be deployed to use. Please deploy with the start trained model deployment API.",
+                                request.getModelId()
+                            )
+                        );
+                    } else {
+                        // return the original error
+                        listener.onFailure(e);
+                    }
+                }, listener::onFailure)
+            );
+        });
+
+        // TODO should `getModelForInternalInference` be used here??
         modelLoadingService.getModelForPipeline(request.getModelId(), parentTaskId, getModelListener);
     }
 
     private void inferAgainstAllocatedModel(
+        TrainedModelAssignmentMetadata assignmentMeta,
         Request request,
         String concreteModelId,
         Response.Builder responseBuilder,
         TaskId parentTaskId,
         ActionListener<Response> listener
     ) {
-        TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<>(
-            client.threadPool().executor(ThreadPool.Names.SAME),
-            // run through all tasks
-            r -> true,
-            // Always fail immediately and return an error
-            ex -> true
-        );
-        request.getObjectsToInfer()
-            .forEach(
-                stringObjectMap -> typedChainTaskExecutor.add(
-                    chainedTask -> inferSingleDocAgainstAllocatedModel(
-                        concreteModelId,
-                        request.getTimeout(),
-                        request.getUpdate(),
-                        stringObjectMap,
-                        parentTaskId,
-                        chainedTask
-                    )
-                )
+        TrainedModelAssignment assignment = assignmentMeta.getModelAssignment(concreteModelId);
+
+        if (assignment.getAssignmentState() == AssignmentState.STOPPING) {
+            String message = "Trained model [" + request.getModelId() + "] is STOPPING";
+            listener.onFailure(ExceptionsHelper.conflictStatusException(message));
+            return;
+        }
+
+        // Get a list of nodes to send the requests to and the number of
+        // documents for each node.
+        var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments());
+        if (nodes.isEmpty()) {
+            logger.trace(() -> format("[%s] model not allocated to any node", assignment.getModelId()));
+            listener.onFailure(
+                ExceptionsHelper.conflictStatusException("Trained model [" + request.getModelId() + "] is not allocated to any nodes")
             );
+            return;
+        }
 
-        typedChainTaskExecutor.execute(
-            ActionListener.wrap(
-                inferenceResults -> listener.onResponse(
-                    responseBuilder.setInferenceResults(inferenceResults).setModelId(concreteModelId).build()
-                ),
-                listener::onFailure
-            )
-        );
+        assert nodes.stream().mapToInt(Tuple::v2).sum() == request.numberOfDocuments()
+            : "mismatch; sum of node requests does not match number of documents in request";
+
+        AtomicInteger count = new AtomicInteger();
+        AtomicArray<List<InferenceResults>> results = new AtomicArray<>(nodes.size());
+        AtomicReference<Exception> failure = new AtomicReference<>();
+
+        int startPos = 0;
+        int slot = 0;
+        for (var node : nodes) {
+            InferTrainedModelDeploymentAction.Request deploymentRequest;
+            if (request.getTextInput() == null) {
+                deploymentRequest = InferTrainedModelDeploymentAction.Request.forDocs(
+                    concreteModelId,
+                    request.getUpdate(),
+                    request.getObjectsToInfer().subList(startPos, startPos + node.v2()),
+                    request.getInferenceTimeout()
+                );
+            } else {
+                deploymentRequest = InferTrainedModelDeploymentAction.Request.forTextInput(
+                    concreteModelId,
+                    request.getUpdate(),
+                    request.getTextInput().subList(startPos, startPos + node.v2()),
+                    request.getInferenceTimeout()
+                );
+            }
+            deploymentRequest.setNodes(node.v1());
+            deploymentRequest.setParentTask(parentTaskId);
+
+            startPos += node.v2();
+
+            executeAsyncWithOrigin(
+                client,
+                ML_ORIGIN,
+                InferTrainedModelDeploymentAction.INSTANCE,
+                deploymentRequest,
+                collectingListener(count, results, failure, slot, nodes.size(), responseBuilder, listener)
+            );
+
+            slot++;
+        }
     }
 
-    private void inferSingleDocAgainstAllocatedModel(
-        String modelId,
-        TimeValue timeValue,
-        InferenceConfigUpdate inferenceConfigUpdate,
-        Map<String, Object> doc,
-        TaskId parentTaskId,
-        ActionListener<InferenceResults> listener
+    private ActionListener<InferTrainedModelDeploymentAction.Response> collectingListener(
+        AtomicInteger count,
+        AtomicArray<List<InferenceResults>> results,
+        AtomicReference<Exception> failure,
+        int slot,
+        int totalNumberOfResponses,
+        Response.Builder responseBuilder,
+        ActionListener<Response> finalListener
     ) {
-        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
-            modelId,
-            inferenceConfigUpdate,
-            Collections.singletonList(doc),
-            timeValue
-        );
-        request.setParentTask(parentTaskId);
-        executeAsyncWithOrigin(
-            client,
-            ML_ORIGIN,
-            InferTrainedModelDeploymentAction.INSTANCE,
-            request,
-            ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure)
-        );
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(InferTrainedModelDeploymentAction.Response response) {
+                results.setOnce(slot, response.getResults());
+                if (count.incrementAndGet() == totalNumberOfResponses) {
+                    sendResponse();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                failure.set(e);
+                if (count.incrementAndGet() == totalNumberOfResponses) {
+                    sendResponse();
+                }
+            }
+
+            private void sendResponse() {
+                if (results.nonNullLength() > 0) {
+                    for (int i = 0; i < results.length(); i++) {
+                        if (results.get(i) != null) {
+                            responseBuilder.addInferenceResults(results.get(i));
+                        }
+                    }
+                    finalListener.onResponse(responseBuilder.build());
+                } else {
+                    finalListener.onFailure(failure.get());
+                }
+            }
+        };
     }
 }
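The core of this change is partitioning the request's documents into contiguous per-node slices (advancing `startPos` by each node's share) and then merging the per-node results back in slot order so they line up with the input documents. A self-contained sketch of that split/merge, with a hypothetical `NodeShare` record standing in for the `(node id, doc count)` tuples:

```java
import java.util.ArrayList;
import java.util.List;

public class DivideWorkDemo {
    // hypothetical stand-in for the Tuple<String, Integer> (node id, doc count) pairs
    // produced by selectRandomStartedNodesWeighedOnAllocationsForNRequests
    record NodeShare(String nodeId, int count) {}

    // contiguous sub-lists, one per node, mirroring the startPos/subList loop above
    static List<List<String>> partition(List<String> docs, List<NodeShare> shares) {
        int total = shares.stream().mapToInt(NodeShare::count).sum();
        assert total == docs.size() : "sum of node shares must equal the document count";
        List<List<String>> slices = new ArrayList<>(shares.size());
        int startPos = 0;
        for (NodeShare share : shares) {
            slices.add(docs.subList(startPos, startPos + share.count()));
            startPos += share.count();
        }
        return slices;
    }

    // merge per-slot results back in slot order, as the AtomicArray-based
    // collectingListener does, so results keep the order of the input documents
    static List<String> mergeInSlotOrder(List<List<String>> perSlotResults) {
        List<String> merged = new ArrayList<>();
        for (List<String> slot : perSlotResults) {
            if (slot != null) {  // a failed slot contributes no results
                merged.addAll(slot);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<String> docs = List.of("d0", "d1", "d2", "d3", "d4");
        List<NodeShare> shares = List.of(new NodeShare("node-a", 2), new NodeShare("node-b", 3));
        List<List<String>> slices = partition(docs, shares);
        assert slices.get(0).equals(List.of("d0", "d1"));
        assert slices.get(1).equals(List.of("d2", "d3", "d4"));
        assert mergeInSlotOrder(slices).equals(docs);
    }
}
```

Keeping each slice contiguous is what makes the final merge trivial: concatenating the slots in order reproduces the original document order without any per-document bookkeeping.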

+ 19 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

@@ -22,9 +22,12 @@ import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.SemanticSearchAction;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
+
+import java.util.List;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
@@ -51,12 +54,15 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
         var parentTask = new TaskId(clusterService.localNode().getId(), task.getId());
         var originSettingClient = new OriginSettingClient(client, ML_ORIGIN);
 
+        long startMs = System.currentTimeMillis();
         // call inference as ML_ORIGIN
         originSettingClient.execute(
-            InferTrainedModelDeploymentAction.INSTANCE,
+            InferModelAction.INSTANCE,
             toInferenceRequest(request, parentTask),
             ActionListener.wrap(inferenceResults -> {
-                if (inferenceResults.getResults()instanceof TextEmbeddingResults textEmbeddingResults) {
+                // Expect 1 result
+                assert inferenceResults.getInferenceResults().size() == 1;
+                if (inferenceResults.getInferenceResults().get(0) instanceof TextEmbeddingResults textEmbeddingResults) {
 
                     var searchRequestBuilder = buildSearch(client, textEmbeddingResults, request);
                     searchRequestBuilder.request().setParentTask(parentTask);
@@ -66,7 +72,7 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
                         listener.onResponse(
                             new SemanticSearchAction.Response(
                                 searchResponse.getTook(),
-                                TimeValue.timeValueMillis(inferenceResults.getTookMillis()),
+                                TimeValue.timeValueMillis(System.currentTimeMillis() - startMs),
                                 searchResponse
                             )
                         );
@@ -77,7 +83,7 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
                             "model ["
                                 + request.getModelId()
                                 + "] must be a text_embedding model; provided ["
-                                + inferenceResults.getResults().getWriteableName()
+                                + inferenceResults.getInferenceResults().get(0).getWriteableName()
                                 + "]"
                         )
                     );
@@ -126,13 +132,14 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
         return searchBuilder;
     }
 
-    private InferTrainedModelDeploymentAction.Request toInferenceRequest(SemanticSearchAction.Request request, TaskId parentTask) {
-        var inferenceRequest = new InferTrainedModelDeploymentAction.Request(
-            request.getModelId(),
-            request.getEmbeddingConfig(),
-            request.getModelText(),
-            request.getInferenceTimeout()
-        );
+    private InferModelAction.Request toInferenceRequest(SemanticSearchAction.Request request, TaskId parentTask) {
+
+        var configUpdate = request.getEmbeddingConfig();
+        if (configUpdate == null) {
+            configUpdate = TextEmbeddingConfigUpdate.EMPTY_INSTANCE;
+        }
+        var inferenceRequest = InferModelAction.Request.forTextInput(request.getModelId(), configUpdate, List.of(request.getModelText()));
+        inferenceRequest.setInferenceTimeout(request.getInferenceTimeout());
         inferenceRequest.setParentTask(parentTask);
         return inferenceRequest;
     }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -150,7 +150,7 @@ public class InferenceProcessor extends AbstractProcessor {
             fields.put(INGEST_KEY, ingestDocument.getIngestMetadata());
         }
         LocalModel.mapFieldsIfNecessary(fields, fieldMap);
-        return new InferModelAction.Request(modelId, fields, inferenceConfig, previouslyLicensed);
+        return InferModelAction.Request.forDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed);
     }
 
     void auditWarningAboutLicenseIfNecessary() {

+ 10 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

@@ -355,10 +355,7 @@ public class ModelLoadingService implements ClusterStateListener {
                     );
                     return;
                 }
-                handleLoadFailure(
-                    modelId,
-                    new ElasticsearchStatusException("Trained model [{}] is not deployed.", RestStatus.BAD_REQUEST, modelId)
-                );
+                handleLoadFailure(modelId, modelMustBeDeployedError(modelId));
                 return;
             }
             auditNewReferencedModel(modelId);
@@ -409,13 +406,7 @@ public class ModelLoadingService implements ClusterStateListener {
                     );
                     return;
                 }
-                modelActionListener.onFailure(
-                    new ElasticsearchStatusException(
-                        "model [{}] must be deployed to use. Please deploy with the start trained model deployment API.",
-                        RestStatus.BAD_REQUEST,
-                        modelId
-                    )
-                );
+                modelActionListener.onFailure(modelMustBeDeployedError(modelId));
                 return;
             }
             // Verify we can pull the model into memory without causing OOM
@@ -483,6 +474,14 @@ public class ModelLoadingService implements ClusterStateListener {
         }
     }
 
+    private ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
+        return new ElasticsearchStatusException(
+            "Model [{}] must be deployed to use. Please deploy with the start trained model deployment API.",
+            RestStatus.BAD_REQUEST,
+            modelId
+        );
+    }
+
     private void handleLoadSuccess(
         String modelId,
         Consumer consumer,

+ 3 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelAction.java

@@ -14,9 +14,7 @@ import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.action.RestCancellableNodeClient;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
-import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -48,16 +46,13 @@ public class RestInferTrainedModelAction extends BaseRestHandler {
         }
         InferModelAction.Request.Builder request = InferModelAction.Request.parseRequest(modelId, restRequest.contentParser());
 
-        if (restRequest.hasParam(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName())) {
+        if (restRequest.hasParam(InferModelAction.Request.TIMEOUT.getPreferredName())) {
             TimeValue inferTimeout = restRequest.paramAsTime(
-                InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName(),
-                InferTrainedModelDeploymentAction.Request.DEFAULT_TIMEOUT
+                InferModelAction.Request.TIMEOUT.getPreferredName(),
+                InferModelAction.Request.DEFAULT_TIMEOUT
             );
             request.setInferenceTimeout(inferTimeout);
         }
-        if (request.getUpdate() == null) {
-            request.setUpdate(new EmptyConfigUpdate());
-        }
 
         return channel -> new RestCancellableNodeClient(client, restRequest.getHttpChannel()).execute(
             InferModelAction.EXTERNAL_INSTANCE,

+ 27 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java

@@ -7,13 +7,16 @@
 
 package org.elasticsearch.xpack.ml.rest.inference;
 
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.core.RestApiVersion;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
 import org.elasticsearch.rest.action.RestCancellableNodeClient;
 import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -60,23 +63,36 @@ public class RestInferTrainedModelDeploymentAction extends BaseRestHandler {
         if (restRequest.hasContent() == false) {
             throw ExceptionsHelper.badRequestException("requires body");
         }
-        InferTrainedModelDeploymentAction.Request.Builder request = InferTrainedModelDeploymentAction.Request.parseRequest(
-            modelId,
-            restRequest.contentParser()
-        );
+        InferModelAction.Request.Builder requestBuilder = InferModelAction.Request.parseRequest(modelId, restRequest.contentParser());
 
-        if (restRequest.hasParam(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName())) {
+        if (restRequest.hasParam(InferModelAction.Request.TIMEOUT.getPreferredName())) {
             TimeValue inferTimeout = restRequest.paramAsTime(
-                InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName(),
-                InferTrainedModelDeploymentAction.Request.DEFAULT_TIMEOUT
+                InferModelAction.Request.TIMEOUT.getPreferredName(),
+                InferModelAction.Request.DEFAULT_TIMEOUT
             );
-            request.setInferenceTimeout(inferTimeout);
+            requestBuilder.setInferenceTimeout(inferTimeout);
+        }
+
+        // Unlike the _infer API, deployment/_infer only accepts a single document
+        var request = requestBuilder.build();
+        if (request.getObjectsToInfer() != null && request.getObjectsToInfer().size() > 1) {
+            ValidationException ex = new ValidationException();
+            ex.addValidationError("multiple documents are not supported");
+            throw ex;
         }
 
         return channel -> new RestCancellableNodeClient(client, restRequest.getHttpChannel()).execute(
-            InferTrainedModelDeploymentAction.INSTANCE,
-            request.build(),
-            new RestToXContentListener<>(channel)
+            InferModelAction.EXTERNAL_INSTANCE,
+            request,
+            // This API is deprecated but refactoring makes it simpler to call
+            // the new replacement API and swap in the old response.
+            ActionListener.wrap(response -> {
+                InferTrainedModelDeploymentAction.Response oldResponse = new InferTrainedModelDeploymentAction.Response(
+                    response.getInferenceResults()
+                );
+                new RestToXContentListener<>(channel).onResponse(oldResponse);
+            }, e -> new RestToXContentListener<>(channel).onFailure(e))
+
         );
     }
 }
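The deprecated `deployment/_infer` handler above now validates that only one document was supplied, delegates to the new `InferModelAction`, and re-wraps the result in the old response type. A tiny sketch of that validate-and-adapt pattern, using hypothetical record types rather than the Elasticsearch classes:

```java
import java.util.List;

public class DeprecatedEndpointAdapterDemo {
    // hypothetical shapes for the new and old response types
    record NewResponse(List<String> inferenceResults) {}
    record OldResponse(List<String> results) {}

    // unlike the _infer API, the deprecated endpoint accepts a single document only
    static void requireSingleDoc(List<?> docs) {
        if (docs != null && docs.size() > 1) {
            throw new IllegalArgumentException("multiple documents are not supported");
        }
    }

    // call the replacement API, then swap the old response shape back in
    static OldResponse adapt(NewResponse response) {
        return new OldResponse(response.inferenceResults());
    }

    public static void main(String[] args) {
        requireSingleDoc(List.of("only-doc"));
        OldResponse old = adapt(new NewResponse(List.of("embedding")));
        assert old.results().equals(List.of("embedding"));
        boolean rejected = false;
        try {
            requireSingleDoc(List.of("a", "b"));
        } catch (IllegalArgumentException e) {
            rejected = true;
        }
        assert rejected;
    }
}
```

Adapting at the REST layer lets the old endpoint stay wire-compatible while all inference traffic flows through the one action that knows how to divide work among allocations.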