ソースを参照

[ML] Report thread settings per node for trained model deployments (#81723)

When a trained model deployment is started the user may request specific
number of threads for inference and/or parallel forwarding. However,
when the pytorch process starts these settings are adjusted so that
the actual number of threads doesn't exceed the available number of
threads. This commit changes the result format so that other type of
results may be read from the process than just inference results.
We then introduce a result for reading the thread settings from the
process and we report it in the response of the trained model _stats API.

Closes #81149
Dimitris Athanasiou 3 年 前
コミット
3851f3f75b
28 ファイル変更416 行追加237 行削除
  1. 60 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java
  2. 47 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java
  3. 14 10
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java
  4. 2 0
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  5. 7 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  6. 16 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  7. 8 30
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java
  8. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
  9. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
  10. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java
  11. 6 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java
  12. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java
  13. 6 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java
  14. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java
  15. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java
  16. 35 21
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
  17. 8 36
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
  18. 50 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java
  19. 40 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java
  20. 6 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
  21. 5 81
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java
  22. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java
  23. 3 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
  24. 4 4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
  25. 3 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
  26. 15 11
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
  27. 34 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
  28. 35 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java

+ 60 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.core.ml.inference.allocation;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -33,6 +34,8 @@ public class AllocationStats implements ToXContentObject, Writeable {
         private final Integer pendingCount;
         private final RoutingStateAndReason routingState;
         private final Instant startTime;
+        private final Integer inferenceThreads;
+        private final Integer modelThreads;
 
         public static AllocationStats.NodeStats forStartedState(
             DiscoveryNode node,
@@ -40,7 +43,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
             Double avgInferenceTime,
             int pendingCount,
             Instant lastAccess,
-            Instant startTime
+            Instant startTime,
+            Integer inferenceThreads,
+            Integer modelThreads
         ) {
             return new AllocationStats.NodeStats(
                 node,
@@ -49,22 +54,26 @@ public class AllocationStats implements ToXContentObject, Writeable {
                 lastAccess,
                 pendingCount,
                 new RoutingStateAndReason(RoutingState.STARTED, null),
-                Objects.requireNonNull(startTime)
+                Objects.requireNonNull(startTime),
+                inferenceThreads,
+                modelThreads
             );
         }
 
         public static AllocationStats.NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
-            return new AllocationStats.NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null);
+            return new AllocationStats.NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null, null, null);
         }
 
-        private NodeStats(
+        public NodeStats(
             DiscoveryNode node,
             Long inferenceCount,
             Double avgInferenceTime,
             Instant lastAccess,
             Integer pendingCount,
             RoutingStateAndReason routingState,
-            @Nullable Instant startTime
+            @Nullable Instant startTime,
+            @Nullable Integer inferenceThreads,
+            @Nullable Integer modelThreads
         ) {
             this.node = node;
             this.inferenceCount = inferenceCount;
@@ -73,6 +82,8 @@ public class AllocationStats implements ToXContentObject, Writeable {
             this.pendingCount = pendingCount;
             this.routingState = routingState;
             this.startTime = startTime;
+            this.inferenceThreads = inferenceThreads;
+            this.modelThreads = modelThreads;
 
             // if lastAccess time is null there have been no inferences
             assert this.lastAccess != null || (inferenceCount == null || inferenceCount == 0);
@@ -86,6 +97,14 @@ public class AllocationStats implements ToXContentObject, Writeable {
             this.pendingCount = in.readOptionalVInt();
             this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
             this.startTime = in.readOptionalInstant();
+
+            if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
+                this.inferenceThreads = in.readOptionalVInt();
+                this.modelThreads = in.readOptionalVInt();
+            } else {
+                this.inferenceThreads = null;
+                this.modelThreads = null;
+            }
         }
 
         public DiscoveryNode getNode() {
@@ -104,6 +123,18 @@ public class AllocationStats implements ToXContentObject, Writeable {
             return Optional.ofNullable(avgInferenceTime);
         }
 
+        public Instant getLastAccess() {
+            return lastAccess;
+        }
+
+        public Integer getPendingCount() {
+            return pendingCount;
+        }
+
+        public Instant getStartTime() {
+            return startTime;
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -128,6 +159,12 @@ public class AllocationStats implements ToXContentObject, Writeable {
             if (startTime != null) {
                 builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
             }
+            if (inferenceThreads != null) {
+                builder.field("inference_threads", inferenceThreads);
+            }
+            if (modelThreads != null) {
+                builder.field("model_threads", modelThreads);
+            }
             builder.endObject();
             return builder;
         }
@@ -141,6 +178,10 @@ public class AllocationStats implements ToXContentObject, Writeable {
             out.writeOptionalVInt(pendingCount);
             out.writeOptionalWriteable(routingState);
             out.writeOptionalInstant(startTime);
+            if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
+                out.writeOptionalVInt(inferenceThreads);
+                out.writeOptionalVInt(modelThreads);
+            }
         }
 
         @Override
@@ -154,12 +195,24 @@ public class AllocationStats implements ToXContentObject, Writeable {
                 && Objects.equals(lastAccess, that.lastAccess)
                 && Objects.equals(pendingCount, that.pendingCount)
                 && Objects.equals(routingState, that.routingState)
-                && Objects.equals(startTime, that.startTime);
+                && Objects.equals(startTime, that.startTime)
+                && Objects.equals(inferenceThreads, that.inferenceThreads)
+                && Objects.equals(modelThreads, that.modelThreads);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(node, inferenceCount, avgInferenceTime, lastAccess, pendingCount, routingState, startTime);
+            return Objects.hash(
+                node,
+                inferenceCount,
+                avgInferenceTime,
+                lastAccess,
+                pendingCount,
+                routingState,
+                startTime,
+                inferenceThreads,
+                modelThreads
+            );
         }
     }
 

+ 47 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatsTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;
 
@@ -91,6 +92,52 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                     RESULTS_FIELD
                 )
             );
+        } else if (version.before(Version.V_8_1_0)) {
+            return new Response(
+                new QueryPage<>(
+                    instance.getResources()
+                        .results()
+                        .stream()
+                        .map(
+                            stats -> new Response.TrainedModelStats(
+                                stats.getModelId(),
+                                stats.getIngestStats(),
+                                stats.getPipelineCount(),
+                                stats.getInferenceStats(),
+                                stats.getDeploymentStats() == null
+                                    ? null
+                                    : new AllocationStats(
+                                        stats.getDeploymentStats().getModelId(),
+                                        stats.getDeploymentStats().getModelSize(),
+                                        stats.getDeploymentStats().getInferenceThreads(),
+                                        stats.getDeploymentStats().getModelThreads(),
+                                        stats.getDeploymentStats().getQueueCapacity(),
+                                        stats.getDeploymentStats().getStartTime(),
+                                        stats.getDeploymentStats()
+                                            .getNodeStats()
+                                            .stream()
+                                            .map(
+                                                nodeStats -> new AllocationStats.NodeStats(
+                                                    nodeStats.getNode(),
+                                                    nodeStats.getInferenceCount().orElse(null),
+                                                    nodeStats.getAvgInferenceTime().orElse(null),
+                                                    nodeStats.getLastAccess(),
+                                                    nodeStats.getPendingCount(),
+                                                    nodeStats.getRoutingState(),
+                                                    nodeStats.getStartTime(),
+                                                    null,
+                                                    null
+                                                )
+                                            )
+                                            .toList()
+                                    )
+                            )
+                        )
+                        .collect(Collectors.toList()),
+                    instance.getResources().count(),
+                    RESULTS_FIELD
+                )
+            );
         }
         return instance;
     }

+ 14 - 10
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java

@@ -28,16 +28,7 @@ public class AllocationStatsTests extends AbstractWireSerializingTestCase<Alloca
         for (int i = 0; i < numNodes; i++) {
             var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
             if (randomBoolean()) {
-                nodeStatsList.add(
-                    AllocationStats.NodeStats.forStartedState(
-                        node,
-                        randomNonNegativeLong(),
-                        randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
-                        randomIntBetween(0, 100),
-                        Instant.now(),
-                        Instant.now()
-                    )
-                );
+                nodeStatsList.add(randomNodeStats(node));
             } else {
                 nodeStatsList.add(
                     AllocationStats.NodeStats.forNotStartedState(
@@ -62,6 +53,19 @@ public class AllocationStatsTests extends AbstractWireSerializingTestCase<Alloca
         );
     }
 
+    public static AllocationStats.NodeStats randomNodeStats(DiscoveryNode node) {
+        return AllocationStats.NodeStats.forStartedState(
+            node,
+            randomNonNegativeLong(),
+            randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
+            randomIntBetween(0, 100),
+            Instant.now(),
+            Instant.now(),
+            randomIntBetween(1, 16),
+            randomIntBetween(1, 16)
+        );
+    }
+
     @Override
     protected Writeable.Reader<AllocationStats> instanceReader() {
         return AllocationStats::new;

+ 2 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import org.apache.http.util.EntityUtils;
+import org.apache.lucene.util.LuceneTestCase.AwaitsFix;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
@@ -73,6 +74,7 @@ import static org.hamcrest.Matchers.nullValue;
  * torch.jit.save(traced_model, "simplemodel.pt")
  * ## End Python
  */
+@AwaitsFix(bugUrl = "until https://github.com/elastic/ml-cpp/pull/2159 is merged")
 public class PyTorchModelIT extends ESRestTestCase {
 
     private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue(

+ 7 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -298,12 +298,14 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
             nodeStats.add(
                 AllocationStats.NodeStats.forStartedState(
                     clusterService.localNode(),
-                    stats.get().getTimingStats().getCount(),
+                    stats.get().timingStats().getCount(),
                     // avoid reporting the average time as 0 if count < 1
-                    (stats.get().getTimingStats().getCount() > 0) ? stats.get().getTimingStats().getAverage() : null,
-                    stats.get().getPendingCount(),
-                    stats.get().getLastUsed(),
-                    stats.get().getStartTime()
+                    (stats.get().timingStats().getCount() > 0) ? stats.get().timingStats().getAverage() : null,
+                    stats.get().pendingCount(),
+                    stats.get().lastUsed(),
+                    stats.get().startTime(),
+                    stats.get().inferenceThreads(),
+                    stats.get().modelThreads()
                 )
             );
         } else {

+ 16 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -50,6 +50,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 
 import java.io.IOException;
@@ -108,7 +109,9 @@ public class DeploymentManager {
                     processContext.startTime,
                     processContext.getResultProcessor().getTimingStats(),
                     processContext.getResultProcessor().getLastUsed(),
-                    processContext.executorService.queueSize() + processContext.getResultProcessor().numberOfPendingResults()
+                    processContext.executorService.queueSize() + processContext.getResultProcessor().numberOfPendingResults(),
+                    processContext.inferenceThreads,
+                    processContext.modelThreads
                 )
             );
     }
@@ -366,8 +369,8 @@ public class DeploymentManager {
                     .registerRequest(
                         requestIdStr,
                         ActionListener.wrap(
-                            pyTorchResult -> processResult(
-                                pyTorchResult,
+                            inferenceResult -> processResult(
+                                inferenceResult,
                                 processContext,
                                 request.tokenization,
                                 processor.getResultProcessor((NlpConfig) config),
@@ -401,14 +404,14 @@ public class DeploymentManager {
         }
 
         private void processResult(
-            PyTorchResult pyTorchResult,
+            PyTorchInferenceResult inferenceResult,
             ProcessContext context,
             TokenizationResult tokenization,
             NlpTask.ResultProcessor inferenceResultsProcessor,
             ActionListener<InferenceResults> resultsListener
         ) {
-            if (pyTorchResult.isError()) {
-                resultsListener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(), RestStatus.INTERNAL_SERVER_ERROR));
+            if (inferenceResult.isError()) {
+                resultsListener.onFailure(new ElasticsearchStatusException(inferenceResult.getError(), RestStatus.INTERNAL_SERVER_ERROR));
                 return;
             }
 
@@ -424,7 +427,7 @@ public class DeploymentManager {
                 );
                 return;
             }
-            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
+            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, inferenceResult);
             logger.debug(() -> new ParameterizedMessage("[{}] processed result for request [{}]", context.task.getModelId(), requestId));
             resultsListener.onResponse(results);
         }
@@ -440,10 +443,15 @@ public class DeploymentManager {
         private final PyTorchStateStreamer stateStreamer;
         private final ProcessWorkerExecutorService executorService;
         private volatile Instant startTime;
+        private volatile Integer inferenceThreads;
+        private volatile Integer modelThreads;
 
         ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
             this.task = Objects.requireNonNull(task);
-            resultProcessor = new PyTorchResultProcessor(task.getModelId());
+            resultProcessor = new PyTorchResultProcessor(task.getModelId(), threadSettings -> {
+                this.inferenceThreads = threadSettings.inferenceThreads();
+                this.modelThreads = threadSettings.modelThreads();
+            });
             this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
             this.executorService = new ProcessWorkerExecutorService(
                 threadPool.getThreadContext(),

+ 8 - 30
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -10,33 +10,11 @@ package org.elasticsearch.xpack.ml.inference.deployment;
 import java.time.Instant;
 import java.util.LongSummaryStatistics;
 
-public class ModelStats {
-
-    private final Instant startTime;
-    private final LongSummaryStatistics timingStats;
-    private final Instant lastUsed;
-    private final int pendingCount;
-
-    ModelStats(Instant startTime, LongSummaryStatistics timingStats, Instant lastUsed, int pendingCount) {
-        this.startTime = startTime;
-        this.timingStats = timingStats;
-        this.lastUsed = lastUsed;
-        this.pendingCount = pendingCount;
-    }
-
-    public Instant getStartTime() {
-        return startTime;
-    }
-
-    public LongSummaryStatistics getTimingStats() {
-        return timingStats;
-    }
-
-    public Instant getLastUsed() {
-        return lastUsed;
-    }
-
-    public int getPendingCount() {
-        return pendingCount;
-    }
-}
+public record ModelStats(
+    Instant startTime,
+    LongSummaryStatistics timingStats,
+    Instant lastUsed,
+    int pendingCount,
+    Integer inferenceThreads,
+    Integer modelThreads
+) {}

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

@@ -14,9 +14,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -80,7 +80,7 @@ public class FillMaskProcessor implements NlpTask.Processor {
 
     static InferenceResults processResult(
         TokenizationResult tokenization,
-        PyTorchResult pyTorchResult,
+        PyTorchInferenceResult pyTorchResult,
         NlpTokenizer tokenizer,
         int numResults,
         String resultsField

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

@@ -15,10 +15,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -193,7 +193,7 @@ public class NerProcessor implements NlpTask.Processor {
         }
 
         @Override
-        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
             if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
                 return new WarningInferenceResults("no valid tokens to build result");
             }

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NlpTask.java

@@ -16,9 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.io.IOException;
 import java.util.List;
@@ -102,7 +102,7 @@ public class NlpTask {
     }
 
     public interface ResultProcessor {
-        InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult);
+        InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult);
     }
 
     public interface Processor {

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/PassThroughProcessor.java

@@ -11,9 +11,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.List;
 import java.util.Optional;
@@ -49,7 +49,11 @@ public class PassThroughProcessor implements NlpTask.Processor {
         return (tokenization, pyTorchResult) -> processResult(tokenization, pyTorchResult, config.getResultsField());
     }
 
-    private static InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult, String resultsField) {
+    private static InferenceResults processResult(
+        TokenizationResult tokenization,
+        PyTorchInferenceResult pyTorchResult,
+        String resultsField
+    ) {
         // TODO - process all results in the batch
         return new PyTorchPassThroughResults(
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

@@ -13,9 +13,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.Arrays;
 import java.util.Comparator;
@@ -81,7 +81,7 @@ public class TextClassificationProcessor implements NlpTask.Processor {
 
     static InferenceResults processResult(
         TokenizationResult tokenization,
-        PyTorchResult pyTorchResult,
+        PyTorchInferenceResult pyTorchResult,
         int numTopClasses,
         List<String> labels,
         String resultsField

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextEmbeddingProcessor.java

@@ -11,9 +11,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.List;
 import java.util.Optional;
@@ -46,7 +46,11 @@ public class TextEmbeddingProcessor implements NlpTask.Processor {
         return (tokenization, pyTorchResult) -> processResult(tokenization, pyTorchResult, config.getResultsField());
     }
 
-    private static InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult, String resultsField) {
+    private static InferenceResults processResult(
+        TokenizationResult tokenization,
+        PyTorchInferenceResult pyTorchResult,
+        String resultsField
+    ) {
         // TODO - process all results in the batch
         return new TextEmbeddingResults(
             Optional.ofNullable(resultsField).orElse(DEFAULT_RESULTS_FIELD),

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

@@ -16,9 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -146,7 +146,7 @@ public class ZeroShotClassificationProcessor implements NlpTask.Processor {
         }
 
         @Override
-        public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
+        public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
             if (pyTorchResult.getInferenceResult().length < 1) {
                 return new WarningInferenceResults("Zero shot classification result has no data");
             }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java

@@ -10,7 +10,7 @@ package org.elasticsearch.xpack.ml.inference.pytorch.process;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
 import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;

+ 35 - 21
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -11,7 +11,9 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
 
 import java.time.Instant;
 import java.util.Iterator;
@@ -19,6 +21,7 @@ import java.util.LongSummaryStatistics;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.function.Consumer;
 
 public class PyTorchResultProcessor {
 
@@ -29,14 +32,16 @@ public class PyTorchResultProcessor {
     private final String deploymentId;
     private volatile boolean isStopping;
     private final LongSummaryStatistics timingStats;
+    private final Consumer<ThreadSettings> threadSettingsConsumer;
     private Instant lastUsed;
 
-    public PyTorchResultProcessor(String deploymentId) {
+    public PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> threadSettingsConsumer) {
         this.deploymentId = Objects.requireNonNull(deploymentId);
         this.timingStats = new LongSummaryStatistics();
+        this.threadSettingsConsumer = Objects.requireNonNull(threadSettingsConsumer);
     }
 
-    public void registerRequest(String requestId, ActionListener<PyTorchResult> listener) {
+    public void registerRequest(String requestId, ActionListener<PyTorchInferenceResult> listener) {
         pendingResults.computeIfAbsent(requestId, k -> new PendingResult(listener));
     }
 
@@ -55,13 +60,13 @@ public class PyTorchResultProcessor {
             Iterator<PyTorchResult> iterator = process.readResults();
             while (iterator.hasNext()) {
                 PyTorchResult result = iterator.next();
-                logger.trace(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, result.getRequestId()));
-                processResult(result);
-                PendingResult pendingResult = pendingResults.remove(result.getRequestId());
-                if (pendingResult == null) {
-                    logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
-                } else {
-                    pendingResult.listener.onResponse(result);
+                PyTorchInferenceResult inferenceResult = result.inferenceResult();
+                if (inferenceResult != null) {
+                    processInferenceResult(inferenceResult);
+                }
+                ThreadSettings threadSettings = result.threadSettings();
+                if (threadSettings != null) {
+                    threadSettingsConsumer.accept(threadSettings);
                 }
             }
         } catch (Exception e) {
@@ -71,7 +76,7 @@ public class PyTorchResultProcessor {
             }
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(
+                    new PyTorchInferenceResult(
                         id,
                         null,
                         null,
@@ -85,7 +90,7 @@ public class PyTorchResultProcessor {
         } finally {
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(id, null, null, "inference canceled as process is stopping")
+                    new PyTorchInferenceResult(id, null, null, "inference canceled as process is stopping")
                 )
             );
             pendingResults.clear();
@@ -93,15 +98,24 @@ public class PyTorchResultProcessor {
         logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", deploymentId));
     }
 
-    public synchronized LongSummaryStatistics getTimingStats() {
-        return new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum());
+    private void processInferenceResult(PyTorchInferenceResult inferenceResult) {
+        logger.trace(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, inferenceResult.getRequestId()));
+        if (inferenceResult.isError() == false) {
+            synchronized (this) {
+                timingStats.accept(inferenceResult.getTimeMs());
+                lastUsed = Instant.now();
+            }
+        }
+        PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
+        if (pendingResult == null) {
+            logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, inferenceResult.getRequestId()));
+        } else {
+            pendingResult.listener.onResponse(inferenceResult);
+        }
     }
 
-    private synchronized void processResult(PyTorchResult result) {
-        if (result.isError() == false) {
-            timingStats.accept(result.getTimeMs());
-            lastUsed = Instant.now();
-        }
+    public synchronized LongSummaryStatistics getTimingStats() {
+        return new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum());
     }
 
     public synchronized Instant getLastUsed() {
@@ -117,9 +131,9 @@ public class PyTorchResultProcessor {
     }
 
     public static class PendingResult {
-        public final ActionListener<PyTorchResult> listener;
+        public final ActionListener<PyTorchInferenceResult> listener;
 
-        public PendingResult(ActionListener<PyTorchResult> listener) {
+        public PendingResult(ActionListener<PyTorchInferenceResult> listener) {
             this.listener = Objects.requireNonNull(listener);
         }
     }

+ 8 - 36
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResult.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java

@@ -5,11 +5,8 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.deployment;
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
 
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ObjectParser;
@@ -28,16 +25,16 @@ import java.util.Objects;
  * If the error field is not null the instance is an error result
  * so the inference and time_ms fields will be null.
  */
-public class PyTorchResult implements ToXContentObject, Writeable {
+public class PyTorchInferenceResult implements ToXContentObject {
 
     private static final ParseField REQUEST_ID = new ParseField("request_id");
     private static final ParseField INFERENCE = new ParseField("inference");
     private static final ParseField ERROR = new ParseField("error");
     private static final ParseField TIME_MS = new ParseField("time_ms");
 
-    public static final ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
-        "pytorch_result",
-        a -> new PyTorchResult((String) a[0], (double[][][]) a[1], (Long) a[2], (String) a[3])
+    public static final ConstructingObjectParser<PyTorchInferenceResult, Void> PARSER = new ConstructingObjectParser<>(
+        "pytorch_inference_result",
+        a -> new PyTorchInferenceResult((String) a[0], (double[][][]) a[1], (Long) a[2], (String) a[3])
     );
 
     static {
@@ -52,7 +49,7 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), ERROR);
     }
 
-    public static PyTorchResult fromXContent(XContentParser parser) throws IOException {
+    public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException {
         return PARSER.parse(parser, null);
     }
 
@@ -61,25 +58,13 @@ public class PyTorchResult implements ToXContentObject, Writeable {
     private final Long timeMs;
     private final String error;
 
-    public PyTorchResult(String requestId, @Nullable double[][][] inference, @Nullable Long timeMs, @Nullable String error) {
+    public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, @Nullable Long timeMs, @Nullable String error) {
         this.requestId = Objects.requireNonNull(requestId);
         this.inference = inference;
         this.timeMs = timeMs;
         this.error = error;
     }
 
-    public PyTorchResult(StreamInput in) throws IOException {
-        requestId = in.readString();
-        boolean hasInference = in.readBoolean();
-        if (hasInference) {
-            inference = in.readArray(in2 -> in2.readArray(StreamInput::readDoubleArray, double[][]::new), double[][][]::new);
-        } else {
-            inference = null;
-        }
-        timeMs = in.readOptionalLong();
-        error = in.readOptionalString();
-    }
-
     public String getRequestId() {
         return requestId;
     }
@@ -129,19 +114,6 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         return builder;
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeString(requestId);
-        if (inference == null) {
-            out.writeBoolean(false);
-        } else {
-            out.writeBoolean(true);
-            out.writeArray((out2, arr) -> out2.writeArray(StreamOutput::writeDoubleArray, arr), inference);
-        }
-        out.writeOptionalLong(timeMs);
-        out.writeOptionalString(error);
-    }
-
     @Override
     public int hashCode() {
         return Objects.hash(requestId, Arrays.deepHashCode(inference), error);
@@ -152,7 +124,7 @@ public class PyTorchResult implements ToXContentObject, Writeable {
         if (this == other) return true;
         if (other == null || getClass() != other.getClass()) return false;
 
-        PyTorchResult that = (PyTorchResult) other;
+        PyTorchInferenceResult that = (PyTorchInferenceResult) other;
         return Objects.equals(requestId, that.requestId)
             && Arrays.deepEquals(inference, that.inference)
             && Objects.equals(error, that.error);

+ 50 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+/**
+ * The top level object capturing output from the pytorch process
+ */
+public record PyTorchResult(@Nullable PyTorchInferenceResult inferenceResult, @Nullable ThreadSettings threadSettings)
+    implements
+        ToXContentObject {
+
+    private static final ParseField RESULT = new ParseField("result");
+    private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings");
+
+    public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
+        "pytorch_result",
+        a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1])
+    );
+
+    static {
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT);
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (inferenceResult != null) {
+            builder.field(RESULT.getPreferredName(), inferenceResult);
+        }
+        if (threadSettings != null) {
+            builder.field(THREAD_SETTINGS.getPreferredName(), threadSettings);
+        }
+        builder.endObject();
+        return builder;
+    }
+}

+ 40 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+public record ThreadSettings(int inferenceThreads, int modelThreads) implements ToXContentObject {
+
+    private static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
+    private static final ParseField MODEL_THREADS = new ParseField("model_threads");
+
+    public static ConstructingObjectParser<ThreadSettings, Void> PARSER = new ConstructingObjectParser<>(
+        "thread_settings",
+        a -> new ThreadSettings((int) a[0], (int) a[1])
+    );
+
+    static {
+        PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
+        PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
+        builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
+        builder.endObject();
+        return builder;
+    }
+}

+ 6 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -357,7 +357,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                 42.0,
                                 0,
                                 Instant.now(),
-                                Instant.now()
+                                Instant.now(),
+                                randomIntBetween(1, 16),
+                                randomIntBetween(1, 16)
                             ),
                             AllocationStats.NodeStats.forStartedState(
                                 new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
@@ -365,7 +367,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                 50.0,
                                 0,
                                 Instant.now(),
-                                Instant.now()
+                                Instant.now(),
+                                randomIntBetween(1, 16),
+                                randomIntBetween(1, 16)
                             )
                         )
                     ).setState(AllocationState.STARTED).setAllocationStatus(new AllocationStatus(2, 2))

+ 5 - 81
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

@@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatsTests;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
@@ -26,7 +27,6 @@ import java.net.UnknownHostException;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -75,26 +75,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         DiscoveryNodes nodes = buildNodes("node1", "node2", "node3");
 
         List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        nodeStatsList.add(
-            AllocationStats.NodeStats.forStartedState(
-                nodes.get("node1"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(1, 100),
-                Instant.now(),
-                Instant.now()
-            )
-        );
-        nodeStatsList.add(
-            AllocationStats.NodeStats.forStartedState(
-                nodes.get("node2"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(1, 100),
-                Instant.now(),
-                Instant.now()
-            )
-        );
+        nodeStatsList.add(AllocationStatsTests.randomNodeStats(nodes.get("node1")));
+        nodeStatsList.add(AllocationStatsTests.randomNodeStats(nodes.get("node2")));
 
         var model1 = new AllocationStats(
             "model1",
@@ -129,26 +111,8 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         DiscoveryNodes nodes = buildNodes("node1", "node2");
 
         List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        nodeStatsList.add(
-            AllocationStats.NodeStats.forStartedState(
-                nodes.get("node1"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(1, 100),
-                Instant.now(),
-                Instant.now()
-            )
-        );
-        nodeStatsList.add(
-            AllocationStats.NodeStats.forStartedState(
-                nodes.get("node2"),
-                randomNonNegativeLong(),
-                randomDoubleBetween(0.0, 100.0, true),
-                randomIntBetween(1, 100),
-                Instant.now(),
-                Instant.now()
-            )
-        );
+        nodeStatsList.add(AllocationStatsTests.randomNodeStats(nodes.get("node1")));
+        nodeStatsList.add(AllocationStatsTests.randomNodeStats(nodes.get("node2")));
 
         var model1 = new AllocationStats(
             "model1",
@@ -188,46 +152,6 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
         return builder.build();
     }
 
-    private AllocationStats randomDeploymentStats() {
-        List<AllocationStats.NodeStats> nodeStatsList = new ArrayList<>();
-        int numNodes = randomIntBetween(1, 4);
-        for (int i = 0; i < numNodes; i++) {
-            var node = new DiscoveryNode("node_" + i, new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Version.CURRENT);
-            if (randomBoolean()) {
-                nodeStatsList.add(
-                    AllocationStats.NodeStats.forStartedState(
-                        node,
-                        randomNonNegativeLong(),
-                        randomDoubleBetween(0.0, 100.0, true),
-                        randomIntBetween(1, 100),
-                        Instant.now(),
-                        Instant.now()
-                    )
-                );
-            } else {
-                nodeStatsList.add(
-                    AllocationStats.NodeStats.forNotStartedState(
-                        node,
-                        randomFrom(RoutingState.values()),
-                        randomBoolean() ? null : "a good reason"
-                    )
-                );
-            }
-        }
-
-        nodeStatsList.sort(Comparator.comparing(n -> n.getNode().getId()));
-
-        return new AllocationStats(
-            randomAlphaOfLength(5),
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 8),
-            randomBoolean() ? null : randomIntBetween(1, 10000),
-            Instant.now(),
-            nodeStatsList
-        );
-    }
-
     private static TrainedModelAllocation createAllocation(String modelId) {
         return TrainedModelAllocation.Builder.empty(new StartTrainedModelDeploymentAction.TaskParams(modelId, 1024, 1, 1, 1)).build();
     }

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -47,7 +47,7 @@ public class DeploymentManagerTests extends ESTestCase {
     }
 
     public void testInferListenerOnlyCalledOnce() {
-        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1");
+        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
         DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
         when(processContext.getResultProcessor()).thenReturn(resultProcessor);
 

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

@@ -14,11 +14,11 @@ import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.Arrays;
 import java.util.Collections;
@@ -65,7 +65,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
             tokenization,
-            new PyTorchResult("1", scores, 0L, null),
+            new PyTorchInferenceResult("1", scores, 0L, null),
             tokenizer,
             4,
             resultsField
@@ -89,7 +89,7 @@ public class FillMaskProcessorTests extends ESTestCase {
         TokenizationResult tokenization = new TokenizationResult(Collections.emptyList());
         tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {});
 
-        PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
         assertThat(
             FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)),
             instanceOf(WarningInferenceResults.class)

+ 4 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

@@ -15,11 +15,11 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -92,7 +92,7 @@ public class NerProcessorTests extends ESTestCase {
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
         TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), "");
         assertThat(
-            processor.processResult(tokenization, new PyTorchResult("test", null, 0L, null)),
+            processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null)),
             instanceOf(WarningInferenceResults.class)
         );
     }
@@ -114,7 +114,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
                 { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -151,7 +151,7 @@ public class NerProcessorTests extends ESTestCase {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchResult("1", scores, 1L, null));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

@@ -16,8 +16,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
-import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null);
             InferenceResults inferenceResults = TextClassificationProcessor.processResult(
                 null,
                 torchResult,
@@ -44,7 +44,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
             assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
         }
         {
-            PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
             InferenceResults inferenceResults = TextClassificationProcessor.processResult(
                 null,
                 torchResult,

+ 15 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/PyTorchResultTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java

@@ -5,31 +5,35 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.inference.deployment;
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
 
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 
-public class PyTorchResultTests extends AbstractSerializingTestCase<PyTorchResult> {
+public class PyTorchInferenceResultTests extends AbstractXContentTestCase<PyTorchInferenceResult> {
+
     @Override
-    protected PyTorchResult doParseInstance(XContentParser parser) throws IOException {
-        return PyTorchResult.fromXContent(parser);
+    protected PyTorchInferenceResult doParseInstance(XContentParser parser) throws IOException {
+        return PyTorchInferenceResult.fromXContent(parser);
     }
 
     @Override
-    protected Writeable.Reader<PyTorchResult> instanceReader() {
-        return PyTorchResult::new;
+    protected boolean supportsUnknownFields() {
+        return false;
     }
 
     @Override
-    protected PyTorchResult createTestInstance() {
+    protected PyTorchInferenceResult createTestInstance() {
+        return createRandom();
+    }
+
+    public static PyTorchInferenceResult createRandom() {
         boolean createError = randomBoolean();
         String id = randomAlphaOfLength(6);
         if (createError) {
-            return new PyTorchResult(id, null, null, "This is an error message");
+            return new PyTorchInferenceResult(id, null, null, "This is an error message");
         } else {
             int rows = randomIntBetween(1, 10);
             int columns = randomIntBetween(1, 10);
@@ -42,7 +46,7 @@ public class PyTorchResultTests extends AbstractSerializingTestCase<PyTorchResul
                     }
                 }
             }
-            return new PyTorchResult(id, arr, randomLong(), null);
+            return new PyTorchInferenceResult(id, arr, randomLong(), null);
         }
     }
 }

+ 34 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class PyTorchResultTests extends AbstractXContentTestCase<PyTorchResult> {
+
+    @Override
+    protected PyTorchResult createTestInstance() {
+        return new PyTorchResult(
+            randomBoolean() ? null : PyTorchInferenceResultTests.createRandom(),
+            randomBoolean() ? null : ThreadSettingsTests.createRandom()
+        );
+    }
+
+    @Override
+    protected PyTorchResult doParseInstance(XContentParser parser) throws IOException {
+        return PyTorchResult.PARSER.parse(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}

+ 35 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class ThreadSettingsTests extends AbstractXContentTestCase<ThreadSettings> {
+
+    public static ThreadSettings createRandom() {
+        return new ThreadSettings(randomIntBetween(1, Integer.MAX_VALUE), randomIntBetween(1, Integer.MAX_VALUE));
+    }
+
+    @Override
+    protected ThreadSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ThreadSettings doParseInstance(XContentParser parser) throws IOException {
+        return ThreadSettings.PARSER.parse(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}