فهرست منبع

[ML] add new cache_size parameter to trained_model deployments API (#88450)

With: https://github.com/elastic/ml-cpp/pull/2305 we now support caching pytorch inference responses per node per model.

By default, the cache will be the same size has the model on disk size. This is because our current best estimate for memory used (for deploying) is 2*model_size + constant_overhead. 

This is due to the model having to be loaded in memory twice when serializing to the native process. 

But, once the model is in memory and accepting requests, its actual memory usage is reduced vs. what we have "reserved" for it within the node.

Consequently, having a cache layer that takes advantage of that unused (but reserved) memory is effectively free. When used in production, especially in search scenarios, caching inference results is critical for decreasing latency.
Benjamin Trent 3 سال پیش
والد
کامیت
afa28d49b4
28فایلهای تغییر یافته به همراه343 افزوده شده و 32 حذف شده
  1. 5 0
      docs/changelog/88450.yaml
  2. 4 0
      docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc
  3. 6 1
      docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc
  4. 5 0
      rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json
  5. 86 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  6. 24 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java
  7. 56 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java
  8. 3 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java
  9. 5 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java
  10. 6 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java
  11. 3 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  12. 3 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  13. 3 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
  14. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java
  15. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  16. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java
  17. 13 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java
  18. 8 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java
  19. 5 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
  20. 6 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java
  21. 8 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java
  22. 3 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java
  23. 9 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java
  24. 8 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java
  25. 3 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java
  26. 19 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java
  27. 8 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java
  28. 38 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

+ 5 - 0
docs/changelog/88450.yaml

@@ -0,0 +1,5 @@
+pr: 88450
+summary: Add new `cache_size` parameter to `trained_model` deployments API
+area: Machine Learning
+type: enhancement
+issues: []

+ 4 - 0
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -97,6 +97,10 @@ The detailed allocation status given the deployment configuration.
 (integer)
 The current number of nodes where the model is allocated.
 
+`cache_size`:::
+(<<byte-units,byte value>>)
+The inference cache size (in memory outside the JVM heap) per node for the model.
+
 `state`:::
 (string)
 The detailed allocation state related to the nodes.

+ 6 - 1
docs/reference/ml/trained-models/apis/start-trained-model-deployment.asciidoc

@@ -34,7 +34,7 @@ Increasing `threads_per_allocation` means more threads are used when
 an inference request is processed on a node. This can improve inference speed
 for certain models. It may also result in improvement to throughput.
 
-Increasing `number_of_allocations` means more threads are used to 
+Increasing `number_of_allocations` means more threads are used to
 process multiple inference requests in parallel resulting in throughput
 improvement. Each model allocation uses a number of threads defined by
 `threads_per_allocation`.
@@ -55,6 +55,11 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 [[start-trained-model-deployment-query-params]]
 == {api-query-parms-title}
 
+`cache_size`::
+(Optional, <<byte-units,byte value>>)
+The inference cache size (in memory outside the JVM heap) per node for the model.
+The default value is the same size as the `model_size_bytes`. To disable the cache, `0b` can be provided.
+
 `number_of_allocations`::
 (Optional, integer)
 The total number of allocations this model is assigned across {ml} nodes.

+ 5 - 0
rest-api-spec/src/main/resources/rest-api-spec/api/ml.start_trained_model_deployment.json

@@ -28,6 +28,11 @@
       ]
     },
     "params":{
+      "cache_size": {
+        "type": "string",
+        "description": "A byte-size value for configuring the inference cache size. For example, 20mb.",
+        "required": false
+      },
       "number_of_allocations":{
         "type":"int",
         "description": "The number of model allocations on each node where the model is deployed.",

+ 86 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -18,6 +18,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -34,8 +35,10 @@ import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.concurrent.TimeUnit;
 
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription;
 
 public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedModelAssignmentAction.Response> {
@@ -75,6 +78,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation", "inference_threads");
         public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations", "model_threads");
         public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
+        public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE;
 
         public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
 
@@ -85,6 +89,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             PARSER.declareInt(Request::setThreadsPerAllocation, THREADS_PER_ALLOCATION);
             PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS);
             PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
+            PARSER.declareField(
+                Request::setCacheSize,
+                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
+                CACHE_SIZE,
+                ObjectParser.ValueType.VALUE
+            );
         }
 
         public static Request parseRequest(String modelId, XContentParser parser) {
@@ -102,6 +112,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         private String modelId;
         private TimeValue timeout = DEFAULT_TIMEOUT;
         private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
+        private ByteSizeValue cacheSize;
         private int numberOfAllocations = 1;
         private int threadsPerAllocation = 1;
         private int queueCapacity = 1024;
@@ -120,6 +131,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             numberOfAllocations = in.readVInt();
             threadsPerAllocation = in.readVInt();
             queueCapacity = in.readVInt();
+            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
+            }
         }
 
         public final void setModelId(String modelId) {
@@ -171,6 +185,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             this.queueCapacity = queueCapacity;
         }
 
+        public ByteSizeValue getCacheSize() {
+            return cacheSize;
+        }
+
+        public void setCacheSize(ByteSizeValue cacheSize) {
+            this.cacheSize = cacheSize;
+        }
+
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
@@ -180,6 +202,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             out.writeVInt(numberOfAllocations);
             out.writeVInt(threadsPerAllocation);
             out.writeVInt(queueCapacity);
+            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+                out.writeOptionalWriteable(cacheSize);
+            }
         }
 
         @Override
@@ -191,6 +216,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
             builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
             builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+            if (cacheSize != null) {
+                builder.field(CACHE_SIZE.getPreferredName(), cacheSize);
+            }
             builder.endObject();
             return builder;
         }
@@ -229,7 +257,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
         @Override
         public int hashCode() {
-            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity);
+            return Objects.hash(modelId, timeout, waitForState, numberOfAllocations, threadsPerAllocation, queueCapacity, cacheSize);
         }
 
         @Override
@@ -244,6 +272,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return Objects.equals(modelId, other.modelId)
                 && Objects.equals(timeout, other.timeout)
                 && Objects.equals(waitForState, other.waitForState)
+                && Objects.equals(cacheSize, other.cacheSize)
                 && numberOfAllocations == other.numberOfAllocations
                 && threadsPerAllocation == other.threadsPerAllocation
                 && queueCapacity == other.queueCapacity;
@@ -273,11 +302,21 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         // threads_per_allocation was previously named inference_threads
         public static final ParseField LEGACY_INFERENCE_THREADS = new ParseField("inference_threads");
         public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
+        public static final ParseField CACHE_SIZE = new ParseField("cache_size");
 
         private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
             "trained_model_deployment_params",
             true,
-            a -> new TaskParams((String) a[0], (Long) a[1], (Integer) a[2], (Integer) a[3], (int) a[4], (Integer) a[5], (Integer) a[6])
+            a -> new TaskParams(
+                (String) a[0],
+                (Long) a[1],
+                (Integer) a[2],
+                (Integer) a[3],
+                (int) a[4],
+                (ByteSizeValue) a[5],
+                (Integer) a[6],
+                (Integer) a[7]
+            )
         );
 
         static {
@@ -286,6 +325,12 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_OF_ALLOCATIONS);
             PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION);
             PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
+            PARSER.declareField(
+                optionalConstructorArg(),
+                (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()),
+                CACHE_SIZE,
+                ObjectParser.ValueType.VALUE
+            );
             PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
             PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
         }
@@ -295,6 +340,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         }
 
         private final String modelId;
+        private final ByteSizeValue cacheSize;
         private final long modelBytes;
         // How many threads are used by the model during inference. Used to increase inference speed.
         private final int threadsPerAllocation;
@@ -308,6 +354,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             Integer numberOfAllocations,
             Integer threadsPerAllocation,
             int queueCapacity,
+            ByteSizeValue cacheSizeValue,
             Integer legacyModelThreads,
             Integer legacyInferenceThreads
         ) {
@@ -316,16 +363,25 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                 modelBytes,
                 threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
                 numberOfAllocations == null ? legacyModelThreads : numberOfAllocations,
-                queueCapacity
+                queueCapacity,
+                cacheSizeValue
             );
         }
 
-        public TaskParams(String modelId, long modelBytes, int threadsPerAllocation, int numberOfAllocations, int queueCapacity) {
+        public TaskParams(
+            String modelId,
+            long modelBytes,
+            int threadsPerAllocation,
+            int numberOfAllocations,
+            int queueCapacity,
+            @Nullable ByteSizeValue cacheSize
+        ) {
             this.modelId = Objects.requireNonNull(modelId);
             this.modelBytes = modelBytes;
             this.threadsPerAllocation = threadsPerAllocation;
             this.numberOfAllocations = numberOfAllocations;
             this.queueCapacity = queueCapacity;
+            this.cacheSize = cacheSize;
         }
 
         public TaskParams(StreamInput in) throws IOException {
@@ -334,6 +390,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             this.threadsPerAllocation = in.readVInt();
             this.numberOfAllocations = in.readVInt();
             this.queueCapacity = in.readVInt();
+            if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+                this.cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
+            } else {
+                this.cacheSize = null;
+            }
         }
 
         public String getModelId() {
@@ -341,6 +402,11 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         }
 
         public long estimateMemoryUsageBytes() {
+            // We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
+            // we need to take it into account when returning the estimate.
+            if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
+                return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes) + (cacheSize.getBytes() - modelBytes);
+            }
             return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
         }
 
@@ -355,6 +421,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             out.writeVInt(threadsPerAllocation);
             out.writeVInt(numberOfAllocations);
             out.writeVInt(queueCapacity);
+            if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+                out.writeOptionalWriteable(cacheSize);
+            }
         }
 
         @Override
@@ -365,13 +434,16 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation);
             builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations);
             builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
+            if (cacheSize != null) {
+                builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
+            }
             builder.endObject();
             return builder;
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity);
+            return Objects.hash(modelId, modelBytes, threadsPerAllocation, numberOfAllocations, queueCapacity, cacheSize);
         }
 
         @Override
@@ -384,6 +456,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
                 && modelBytes == other.modelBytes
                 && threadsPerAllocation == other.threadsPerAllocation
                 && numberOfAllocations == other.numberOfAllocations
+                && Objects.equals(cacheSize, other.cacheSize)
                 && queueCapacity == other.queueCapacity;
         }
 
@@ -408,6 +481,14 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return queueCapacity;
         }
 
+        public Optional<ByteSizeValue> getCacheSize() {
+            return Optional.ofNullable(cacheSize);
+        }
+
+        public long getCacheSizeBytes() {
+            return Optional.ofNullable(cacheSize).map(ByteSizeValue::getBytes).orElse(modelBytes);
+        }
+
         @Override
         public String toString() {
             return Strings.toString(this);

+ 24 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -355,6 +356,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
     private final Integer numberOfAllocations;
     @Nullable
     private final Integer queueCapacity;
+    @Nullable
+    private final ByteSizeValue cacheSize;
     private final Instant startTime;
     private final List<AssignmentStats.NodeStats> nodeStats;
 
@@ -363,6 +366,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         @Nullable Integer threadsPerAllocation,
         @Nullable Integer numberOfAllocations,
         @Nullable Integer queueCapacity,
+        @Nullable ByteSizeValue cacheSize,
         Instant startTime,
         List<AssignmentStats.NodeStats> nodeStats
     ) {
@@ -372,6 +376,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         this.queueCapacity = queueCapacity;
         this.startTime = Objects.requireNonNull(startTime);
         this.nodeStats = nodeStats;
+        this.cacheSize = cacheSize;
         this.state = null;
         this.reason = null;
     }
@@ -386,6 +391,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         state = in.readOptionalEnum(AssignmentState.class);
         reason = in.readOptionalString();
         allocationStatus = in.readOptionalWriteable(AllocationStatus::new);
+        if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
+            cacheSize = in.readOptionalWriteable(ByteSizeValue::new);
+        } else {
+            cacheSize = null;
+        }
     }
 
     public String getModelId() {
@@ -407,6 +417,11 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         return queueCapacity;
     }
 
+    @Nullable
+    public ByteSizeValue getCacheSize() {
+        return cacheSize;
+    }
+
     public Instant getStartTime() {
         return startTime;
     }
@@ -477,6 +492,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         if (allocationStatus != null) {
             builder.field("allocation_status", allocationStatus);
         }
+        if (cacheSize != null) {
+            builder.field("cache_size", cacheSize);
+        }
         builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
 
         int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
@@ -526,6 +544,9 @@ public class AssignmentStats implements ToXContentObject, Writeable {
         }
         out.writeOptionalString(reason);
         out.writeOptionalWriteable(allocationStatus);
+        if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
+            out.writeOptionalWriteable(cacheSize);
+        }
     }
 
     @Override
@@ -541,6 +562,7 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             && Objects.equals(state, that.state)
             && Objects.equals(reason, that.reason)
             && Objects.equals(allocationStatus, that.allocationStatus)
+            && Objects.equals(cacheSize, that.cacheSize)
             && Objects.equals(nodeStats, that.nodeStats);
     }
 
@@ -555,7 +577,8 @@ public class AssignmentStats implements ToXContentObject, Writeable {
             nodeStats,
             state,
             reason,
-            allocationStatus
+            allocationStatus,
+            cacheSize
         );
     }
 

+ 56 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -46,9 +46,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
     }
 
     private IngestStats randomIngestStats() {
-        List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10))
-            .limit(randomIntBetween(0, 10))
-            .collect(Collectors.toList());
+        List<String> pipelineIds = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 10)).toList();
         return new IngestStats(
             new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
             pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
@@ -115,6 +113,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                         stats.getDeploymentStats().getThreadsPerAllocation(),
                                         stats.getDeploymentStats().getNumberOfAllocations(),
                                         stats.getDeploymentStats().getQueueCapacity(),
+                                        null,
                                         stats.getDeploymentStats().getStartTime(),
                                         stats.getDeploymentStats()
                                             .getNodeStats()
@@ -167,6 +166,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                         stats.getDeploymentStats().getThreadsPerAllocation(),
                                         stats.getDeploymentStats().getNumberOfAllocations(),
                                         stats.getDeploymentStats().getQueueCapacity(),
+                                        null,
                                         stats.getDeploymentStats().getStartTime(),
                                         stats.getDeploymentStats()
                                             .getNodeStats()
@@ -199,6 +199,59 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                     RESULTS_FIELD
                 )
             );
+        } else if (version.before(Version.V_8_4_0)) {
+            return new Response(
+                new QueryPage<>(
+                    instance.getResources()
+                        .results()
+                        .stream()
+                        .map(
+                            stats -> new Response.TrainedModelStats(
+                                stats.getModelId(),
+                                stats.getModelSizeStats(),
+                                stats.getIngestStats(),
+                                stats.getPipelineCount(),
+                                stats.getInferenceStats(),
+                                stats.getDeploymentStats() == null
+                                    ? null
+                                    : new AssignmentStats(
+                                        stats.getDeploymentStats().getModelId(),
+                                        stats.getDeploymentStats().getThreadsPerAllocation(),
+                                        stats.getDeploymentStats().getNumberOfAllocations(),
+                                        stats.getDeploymentStats().getQueueCapacity(),
+                                        null,
+                                        stats.getDeploymentStats().getStartTime(),
+                                        stats.getDeploymentStats()
+                                            .getNodeStats()
+                                            .stream()
+                                            .map(
+                                                nodeStats -> new AssignmentStats.NodeStats(
+                                                    nodeStats.getNode(),
+                                                    nodeStats.getInferenceCount().orElse(null),
+                                                    nodeStats.getAvgInferenceTime().orElse(null),
+                                                    nodeStats.getLastAccess(),
+                                                    nodeStats.getPendingCount(),
+                                                    nodeStats.getErrorCount(),
+                                                    nodeStats.getRejectedExecutionCount(),
+                                                    nodeStats.getTimeoutCount(),
+                                                    nodeStats.getRoutingState(),
+                                                    nodeStats.getStartTime(),
+                                                    nodeStats.getThreadsPerAllocation(),
+                                                    nodeStats.getNumberOfAllocations(),
+                                                    nodeStats.getPeakThroughput(),
+                                                    nodeStats.getThroughputLastPeriod(),
+                                                    nodeStats.getAvgInferenceTimeLastPeriod()
+                                                )
+                                            )
+                                            .toList()
+                                    )
+                            )
+                        )
+                        .collect(Collectors.toList()),
+                    instance.getResources().count(),
+                    RESULTS_FIELD
+                )
+            );
         }
         return instance;
     }

+ 3 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.core.ml.action;
 
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.TaskParams;
@@ -37,7 +38,8 @@ public class StartTrainedModelDeploymentTaskParamsTests extends AbstractSerializ
             randomNonNegativeLong(),
             randomIntBetween(1, 8),
             randomIntBetween(1, 8),
-            randomIntBetween(1, 10000)
+            randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
         );
     }
 }

+ 5 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
 
@@ -47,6 +48,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)),
             Instant.now(),
             nodeStatsList
         );
@@ -91,6 +93,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
             Instant.now(),
             List.of(
                 AssignmentStats.NodeStats.forStartedState(
@@ -146,6 +149,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
             Instant.now(),
             List.of()
         );
@@ -163,6 +167,7 @@ public class AssignmentStatsTests extends AbstractWireSerializingTestCase<Assign
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
             Instant.now(),
             List.of(
                 AssignmentStats.NodeStats.forNotStartedState(

+ 6 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.core.ml.inference.assignment;
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xcontent.XContentParser;
@@ -23,7 +24,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
-import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
 import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
@@ -37,7 +37,7 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
 
     public static TrainedModelAssignment randomInstance() {
         TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams());
-        List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
+        List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList();
         for (String node : nodes) {
             builder.addRoutingEntry(node, RoutingInfoTests.randomInstance());
         }
@@ -267,12 +267,14 @@ public class TrainedModelAssignmentTests extends AbstractSerializingTestCase<Tra
     }
 
     private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int numberOfAllocations) {
+        long modelSize = randomNonNegativeLong();
         return new StartTrainedModelDeploymentAction.TaskParams(
             randomAlphaOfLength(10),
-            randomNonNegativeLong(),
+            modelSize,
             randomIntBetween(1, 8),
             numberOfAllocations,
-            randomIntBetween(1, 10000)
+            randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1))
         );
     }
 

+ 3 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -793,8 +793,9 @@ public class PyTorchModelIT extends ESRestTestCase {
 
     private void putModelDefinition(String modelId) throws IOException {
         Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
-        request.setJsonEntity("""
-            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL));
+        String body = """
+            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""".formatted(RAW_MODEL_SIZE, BASE_64_ENCODED_MODEL);
+        request.setJsonEntity(body);
         client().performRequest(request);
     }
 

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -237,6 +237,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                         stat.getThreadsPerAllocation(),
                         stat.getNumberOfAllocations(),
                         stat.getQueueCapacity(),
+                        stat.getCacheSize(),
                         stat.getStartTime(),
                         updatedNodeStats
                     )
@@ -267,7 +268,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
 
                 nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
 
-                updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, assignment.getStartTime(), nodeStats));
+                updatedAssignmentStats.add(new AssignmentStats(modelId, null, null, null, null, assignment.getStartTime(), nodeStats));
             }
         }
 
@@ -327,6 +328,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 task.getParams().getThreadsPerAllocation(),
                 assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(),
                 task.getParams().getQueueCapacity(),
+                task.getParams().getCacheSize().orElse(null),
                 TrainedModelAssignmentMetadata.fromState(clusterService.state()).getModelAssignment(task.getModelId()).getStartTime(),
                 nodeStats
             )

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -72,6 +72,7 @@ import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.Set;
 import java.util.function.Predicate;
@@ -229,7 +230,8 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
                         modelBytes,
                         request.getThreadsPerAllocation(),
                         request.getNumberOfAllocations(),
-                        request.getQueueCapacity()
+                        request.getQueueCapacity(),
+                        Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelBytes))
                     );
                     PersistentTasksCustomMetadata persistentTasks = clusterService.state()
                         .getMetadata()

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

@@ -357,7 +357,8 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
                                 trainedModelAssignment.getTaskParams().getModelBytes(),
                                 trainedModelAssignment.getTaskParams().getThreadsPerAllocation(),
                                 routingInfo.getCurrentAllocations(),
-                                trainedModelAssignment.getTaskParams().getQueueCapacity()
+                                trainedModelAssignment.getTaskParams().getQueueCapacity(),
+                                trainedModelAssignment.getTaskParams().getCacheSize().orElse(null)
                             )
                         );
                     }

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -78,7 +78,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
             params.getModelBytes(),
             numberOfAllocations,
             params.getThreadsPerAllocation(),
-            params.getQueueCapacity()
+            params.getQueueCapacity(),
+            null
         );
     }
 

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcessFactory.java

@@ -103,7 +103,8 @@ public class NativePyTorchProcessFactory implements PyTorchProcessFactory {
             nativeController,
             processPipes,
             task.getParams().getThreadsPerAllocation(),
-            task.getParams().getNumberOfAllocations()
+            task.getParams().getNumberOfAllocations(),
+            task.getParams().getCacheSizeBytes()
         );
         try {
             pyTorchBuilder.build();

+ 13 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilder.java

@@ -23,17 +23,26 @@ public class PyTorchBuilder {
     private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
     private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
     private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
+    private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
 
     private final NativeController nativeController;
     private final ProcessPipes processPipes;
     private final int threadsPerAllocation;
     private final int numberOfAllocations;
+    private final long cacheMemoryLimitBytes;
 
-    public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, int threadPerAllocation, int numberOfAllocations) {
+    public PyTorchBuilder(
+        NativeController nativeController,
+        ProcessPipes processPipes,
+        int threadPerAllocation,
+        int numberOfAllocations,
+        long cacheMemoryLimitBytes
+    ) {
         this.nativeController = Objects.requireNonNull(nativeController);
         this.processPipes = Objects.requireNonNull(processPipes);
         this.threadsPerAllocation = threadPerAllocation;
         this.numberOfAllocations = numberOfAllocations;
+        this.cacheMemoryLimitBytes = cacheMemoryLimitBytes;
     }
 
     public void build() throws IOException, InterruptedException {
@@ -51,6 +60,9 @@ public class PyTorchBuilder {
 
         command.add(NUM_THREADS_PER_ALLOCATION_ARG + threadsPerAllocation);
         command.add(NUM_ALLOCATIONS_ARG + numberOfAllocations);
+        if (cacheMemoryLimitBytes > 0) {
+            command.add(CACHE_MEMORY_LIMIT_BYTES_ARG + cacheMemoryLimitBytes);
+        }
 
         return command;
     }

+ 8 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.rest.inference;
 
 import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.RestApiVersion;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.BaseRestHandler;
@@ -23,6 +24,7 @@ import java.util.Collections;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.POST;
+import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.CACHE_SIZE;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
 import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.THREADS_PER_ALLOCATION;
@@ -84,6 +86,12 @@ public class RestStartTrainedModelDeploymentAction extends BaseRestHandler {
                 request::setThreadsPerAllocation
             );
             request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
+            if (restRequest.hasParam(CACHE_SIZE.getPreferredName())) {
+                request.setCacheSize(
+                    ByteSizeValue.parseBytesSizeValue(restRequest.param(CACHE_SIZE.getPreferredName()), CACHE_SIZE.getPreferredName())
+                );
+            }
+            request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
         }
 
         return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

+ 5 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.TestEnvironment;
@@ -345,7 +346,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                             ),
                             3,
                             null,
-                            new AssignmentStats("model_3", null, null, null, Instant.now(), List.of()).setState(AssignmentState.STOPPING)
+                            new AssignmentStats("model_3", null, null, null, null, Instant.now(), List.of()).setState(
+                                AssignmentState.STOPPING
+                            )
                         ),
                         new GetTrainedModelsStatsAction.Response.TrainedModelStats(
                             trainedModel4.getModelId(),
@@ -371,6 +374,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                 2,
                                 2,
                                 1000,
+                                ByteSizeValue.ofBytes(1000),
                                 Instant.now(),
                                 List.of(
                                     AssignmentStats.NodeStats.forStartedState(

+ 6 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.transport.TransportAddress;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
@@ -82,6 +83,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
             Instant.now(),
             nodeStatsList
         );
@@ -117,6 +119,7 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)),
             Instant.now(),
             nodeStatsList
         );
@@ -150,6 +153,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
     }
 
     private static TrainedModelAssignment createAssignment(String modelId) {
-        return TrainedModelAssignment.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build();
+        return TrainedModelAssignment.Builder.empty(
+            new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024))
+        ).build();
     }
 }

+ 8 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java

@@ -1407,7 +1407,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
         int numberOfAllocations,
         int threadsPerAllocation
     ) {
-        return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
+        return new StartTrainedModelDeploymentAction.TaskParams(
+            modelId,
+            modelSize,
+            threadsPerAllocation,
+            numberOfAllocations,
+            1024,
+            ByteSizeValue.ofBytes(modelSize)
+        );
     }
 
     private static NodesShutdownMetadata shutdownMetadata(String nodeId) {

+ 3 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.inference.assignment;
 
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -64,7 +65,8 @@ public class TrainedModelAssignmentMetadataTests extends AbstractSerializingTest
             randomNonNegativeLong(),
             randomIntBetween(1, 8),
             randomIntBetween(1, 8),
-            randomIntBetween(1, 10000)
+            randomIntBetween(1, 10000),
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
         );
     }
 

+ 9 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.indices.TestIndexNameExpressionResolver;
 import org.elasticsearch.license.XPackLicenseState;
@@ -624,7 +625,14 @@ public class TrainedModelAssignmentNodeServiceTests extends ESTestCase {
     }
 
     private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
-        return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024);
+        return new StartTrainedModelDeploymentAction.TaskParams(
+            modelId,
+            randomNonNegativeLong(),
+            1,
+            1,
+            1024,
+            randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong())
+        );
     }
 
     private TrainedModelAssignmentNodeService createService() {

+ 8 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

@@ -462,7 +462,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         int numberOfAllocations,
         int threadsPerAllocation
     ) {
-        return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, threadsPerAllocation, numberOfAllocations, 1024);
+        return new StartTrainedModelDeploymentAction.TaskParams(
+            modelId,
+            modelSize,
+            threadsPerAllocation,
+            numberOfAllocations,
+            1024,
+            ByteSizeValue.ofBytes(modelSize)
+        );
     }
 
     private static DiscoveryNode buildNode(String name, long nativeMemory, int allocatedProcessors) {

+ 3 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.inference.deployment;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.tasks.TaskId;
@@ -53,7 +54,8 @@ public class TrainedModelDeploymentTaskTests extends ESTestCase {
                 randomLongBetween(1, Long.MAX_VALUE),
                 randomInt(5),
                 randomInt(5),
-                randomInt(5)
+                randomInt(5),
+                randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE))
             ),
             nodeService,
             licenseState,

+ 19 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java

@@ -44,7 +44,25 @@ public class PyTorchBuilderTests extends ESTestCase {
     }
 
     public void testBuild() throws IOException, InterruptedException {
-        new PyTorchBuilder(nativeController, processPipes, 2, 4).build();
+        new PyTorchBuilder(nativeController, processPipes, 2, 4, 12).build();
+
+        verify(nativeController).startProcess(commandCaptor.capture());
+
+        assertThat(
+            commandCaptor.getValue(),
+            contains(
+                "./pytorch_inference",
+                "--validElasticLicenseKeyConfirmed=true",
+                "--numThreadsPerAllocation=2",
+                "--numAllocations=4",
+                "--cacheMemorylimitBytes=12",
+                PROCESS_PIPES_ARG
+            )
+        );
+    }
+
+    public void testBuildWithNoCache() throws IOException, InterruptedException {
+        new PyTorchBuilder(nativeController, processPipes, 2, 4, 0).build();
 
         verify(nativeController).startProcess(commandCaptor.capture());
 

+ 8 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java

@@ -126,7 +126,14 @@ public class NodeLoadDetectorTests extends ESTestCase {
                             .addNewAssignment(
                                 "model1",
                                 TrainedModelAssignment.Builder.empty(
-                                    new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT, 1, 1, 1024)
+                                    new StartTrainedModelDeploymentAction.TaskParams(
+                                        "model1",
+                                        MODEL_MEMORY_REQUIREMENT,
+                                        1,
+                                        1,
+                                        1024,
+                                        ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT)
+                                    )
                                 )
                                     .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, ""))
                                     .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test"))

تفاوت فایلی نمایش داده نمی شود زیرا این فایل بسیار بزرگ است
+ 38 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml


برخی فایل ها در این مقایسه diff نمایش داده نمی شوند زیرا تعداد فایل ها بسیار زیاد است