Browse Source

[ML] Add error counts to trained model stats (#82705)

Adds inference_count, timeout_count, rejected_execution_count
and error_count fields to trained model stats.
David Kyle 3 years ago
parent
commit
c1fbf87de8

+ 5 - 0
docs/changelog/82705.yaml

@@ -0,0 +1,5 @@
+pr: 82705
+summary: Add error counts to trained model stats
+area: Machine Learning
+type: enhancement
+issues: []

+ 75 - 3
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -112,10 +112,26 @@ The detailed allocation state related to the nodes.
 The desired number of nodes for model allocation.
 ======
 
+`error_count`:::
+(integer)
+The sum of `error_count` for all nodes in the deployment.
+
+`inference_count`:::
+(integer)
+The sum of `inference_count` for all nodes in the deployment.
+
+`inference_threads`:::
+(integer)
+The number of threads used by the inference process.
+
 `model_id`:::
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
+`model_threads`:::
+(integer)
+The number of threads used when sending inference requests to the model.
+
 `nodes`:::
 (array of objects)
 The deployment stats for each node that currently has the model allocated.
@@ -127,14 +143,30 @@ The deployment stats for each node that currently has the model allocated.
 (double)
 The average time for each inference call to complete on this node.
 
+`error_count`:::
+(integer)
+The number of errors when evaluating the trained model.
+
 `inference_count`:::
 (integer)
 The total number of inference calls made against this node for this model.
 
+`inference_threads`:::
+(integer)
+The number of threads used by the inference process.
+This value is limited by the number of hardware threads on the node;
+it might therefore differ from the `inference_threads` value in the <<start-trained-model-deployment>> API.
+
 `last_access`:::
 (long)
 The epoch time stamp of the last inference call for the model on this node.
 
+`model_threads`:::
+(integer)
+The number of threads used when sending inference requests to the model.
+This value is limited by the number of hardware threads on the node;
+it might therefore differ from the `model_threads` value in the <<start-trained-model-deployment>> API.
+
 `node`:::
 (object)
 Information pertaining to the node.
@@ -162,14 +194,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-id]
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=node-transport-address]
 ========
 
-`reason`:::
-(string)
-The reason for the current state. Usually only populated when the `routing_state` is `failed`.
+`number_of_pending_requests`:::
+(integer)
+The number of inference requests queued to be processed.
 
 `routing_state`:::
 (object)
 The current routing state and reason for the current routing state for this allocation.
 +
+.Properties of routing_state
+[%collapsible%open]
+========
+`reason`:::
+(string)
+The reason for the current state. Usually only populated when the `routing_state` is `failed`.
+
+`routing_state`:::
+(string)
+The current routing state.
 --
 * `starting`: The model is attempting to allocate on this node; inference calls are not yet accepted.
 * `started`: The model is allocated and ready to accept inference requests.
@@ -177,13 +219,34 @@ The current routing state and reason for the current routing state for this allo
 * `stopped`: The model is fully deallocated from this node.
 * `failed`: The allocation attempt failed, see `reason` field for the potential cause.
 --
+========
+
+`rejected_execution_count`:::
+(integer)
+The number of inference requests that were not processed because the
+queue was full.
 
 `start_time`:::
 (long)
 The epoch timestamp when the allocation started.
 
+`timeout_count`:::
+(integer)
+The number of inference requests that timed out before being processed.
 ======
 
+`rejected_execution_count`:::
+(integer)
+The sum of `rejected_execution_count` for all nodes in the deployment.
+Individual nodes reject an inference request if the inference queue is full.
+The queue size is controlled by the `queue_capacity` setting in the
+<<start-trained-model-deployment>> API.
+
+`reason`:::
+(string)
+The reason for the current deployment state.
+Usually only populated when the model is not deployed to a node.
+
 `start_time`:::
 (long)
 The epoch timestamp when the deployment started.
@@ -198,6 +261,15 @@ The overall state of the deployment. The values may be:
 * `stopping`: The deployment is preparing to stop and deallocate the model from the relevant nodes.
 --
 
+`timeout_count`:::
+(integer)
+The sum of `timeout_count` for all nodes in the deployment.
+
+`queue_capacity`:::
+(integer)
+The number of inference requests that may be queued before new requests are
+rejected.
+
 =====
 
 `inference_stats`:::

+ 89 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java

@@ -31,6 +31,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
         private final Double avgInferenceTime;
         private final Instant lastAccess;
         private final Integer pendingCount;
+        private final int errorCount;
+        private final int rejectedExecutionCount;
+        private final int timeoutCount;
         private final RoutingStateAndReason routingState;
         private final Instant startTime;
         private final Integer inferenceThreads;
@@ -41,6 +44,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
             long inferenceCount,
             Double avgInferenceTime,
             int pendingCount,
+            int errorCount,
+            int rejectedExecutionCount,
+            int timeoutCount,
             Instant lastAccess,
             Instant startTime,
             Integer inferenceThreads,
@@ -52,6 +58,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
                 avgInferenceTime,
                 lastAccess,
                 pendingCount,
+                errorCount,
+                rejectedExecutionCount,
+                timeoutCount,
                 new RoutingStateAndReason(RoutingState.STARTED, null),
                 Objects.requireNonNull(startTime),
                 inferenceThreads,
@@ -60,7 +69,20 @@ public class AllocationStats implements ToXContentObject, Writeable {
         }
 
         public static AllocationStats.NodeStats forNotStartedState(DiscoveryNode node, RoutingState state, String reason) {
-            return new AllocationStats.NodeStats(node, null, null, null, null, new RoutingStateAndReason(state, reason), null, null, null);
+            return new AllocationStats.NodeStats(
+                node,
+                null,
+                null,
+                null,
+                null,
+                0,
+                0,
+                0,
+                new RoutingStateAndReason(state, reason),
+                null,
+                null,
+                null
+            );
         }
 
         public NodeStats(
@@ -69,6 +91,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
             Double avgInferenceTime,
             Instant lastAccess,
             Integer pendingCount,
+            int errorCount,
+            int rejectedExecutionCount,
+            int timeoutCount,
             RoutingStateAndReason routingState,
             @Nullable Instant startTime,
             @Nullable Integer inferenceThreads,
@@ -79,6 +104,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
             this.avgInferenceTime = avgInferenceTime;
             this.lastAccess = lastAccess;
             this.pendingCount = pendingCount;
+            this.errorCount = errorCount;
+            this.rejectedExecutionCount = rejectedExecutionCount;
+            this.timeoutCount = timeoutCount;
             this.routingState = routingState;
             this.startTime = startTime;
             this.inferenceThreads = inferenceThreads;
@@ -96,13 +124,18 @@ public class AllocationStats implements ToXContentObject, Writeable {
             this.pendingCount = in.readOptionalVInt();
             this.routingState = in.readOptionalWriteable(RoutingStateAndReason::new);
             this.startTime = in.readOptionalInstant();
-
             if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
                 this.inferenceThreads = in.readOptionalVInt();
                 this.modelThreads = in.readOptionalVInt();
+                this.errorCount = in.readVInt();
+                this.rejectedExecutionCount = in.readVInt();
+                this.timeoutCount = in.readVInt();
             } else {
                 this.inferenceThreads = null;
                 this.modelThreads = null;
+                this.errorCount = 0;
+                this.rejectedExecutionCount = 0;
+                this.timeoutCount = 0;
             }
         }
 
@@ -130,6 +163,18 @@ public class AllocationStats implements ToXContentObject, Writeable {
             return pendingCount;
         }
 
+        public int getErrorCount() {
+            return errorCount;
+        }
+
+        public int getRejectedExecutionCount() {
+            return rejectedExecutionCount;
+        }
+
+        public int getTimeoutCount() {
+            return timeoutCount;
+        }
+
         public Instant getStartTime() {
             return startTime;
         }
@@ -146,7 +191,8 @@ public class AllocationStats implements ToXContentObject, Writeable {
             if (inferenceCount != null) {
                 builder.field("inference_count", inferenceCount);
             }
-            if (avgInferenceTime != null) {
+            // avoid reporting the average time as 0 if count < 1
+            if (avgInferenceTime != null && (inferenceCount != null && inferenceCount > 0)) {
                 builder.field("average_inference_time_ms", avgInferenceTime);
             }
             if (lastAccess != null) {
@@ -155,6 +201,15 @@ public class AllocationStats implements ToXContentObject, Writeable {
             if (pendingCount != null) {
                 builder.field("number_of_pending_requests", pendingCount);
             }
+            if (errorCount > 0) {
+                builder.field("error_count", errorCount);
+            }
+            if (rejectedExecutionCount > 0) {
+                builder.field("rejected_execution_count", rejectedExecutionCount);
+            }
+            if (timeoutCount > 0) {
+                builder.field("timeout_count", timeoutCount);
+            }
             if (startTime != null) {
                 builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
             }
@@ -180,6 +235,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
             if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
                 out.writeOptionalVInt(inferenceThreads);
                 out.writeOptionalVInt(modelThreads);
+                out.writeVInt(errorCount);
+                out.writeVInt(rejectedExecutionCount);
+                out.writeVInt(timeoutCount);
             }
         }
 
@@ -193,6 +251,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
                 && Objects.equals(node, that.node)
                 && Objects.equals(lastAccess, that.lastAccess)
                 && Objects.equals(pendingCount, that.pendingCount)
+                && Objects.equals(errorCount, that.errorCount)
+                && Objects.equals(rejectedExecutionCount, that.rejectedExecutionCount)
+                && Objects.equals(timeoutCount, that.timeoutCount)
                 && Objects.equals(routingState, that.routingState)
                 && Objects.equals(startTime, that.startTime)
                 && Objects.equals(inferenceThreads, that.inferenceThreads)
@@ -207,6 +268,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
                 avgInferenceTime,
                 lastAccess,
                 pendingCount,
+                errorCount,
+                rejectedExecutionCount,
+                timeoutCount,
                 routingState,
                 startTime,
                 inferenceThreads,
@@ -331,6 +395,28 @@ public class AllocationStats implements ToXContentObject, Writeable {
             builder.field("allocation_status", allocationStatus);
         }
         builder.timeField("start_time", "start_time_string", startTime.toEpochMilli());
+
+        int totalErrorCount = nodeStats.stream().mapToInt(NodeStats::getErrorCount).sum();
+        int totalRejectedExecutionCount = nodeStats.stream().mapToInt(NodeStats::getRejectedExecutionCount).sum();
+        int totalTimeoutCount = nodeStats.stream().mapToInt(NodeStats::getTimeoutCount).sum();
+        long totalInferenceCount = nodeStats.stream()
+            .filter(n -> n.getInferenceCount().isPresent())
+            .mapToLong(n -> n.getInferenceCount().get())
+            .sum();
+
+        if (totalErrorCount > 0) {
+            builder.field("error_count", totalErrorCount);
+        }
+        if (totalRejectedExecutionCount > 0) {
+            builder.field("rejected_execution_count", totalRejectedExecutionCount);
+        }
+        if (totalTimeoutCount > 0) {
+            builder.field("timeout_count", totalTimeoutCount);
+        }
+        if (totalInferenceCount > 0) {
+            builder.field("inference_count", totalInferenceCount);
+        }
+
         builder.startArray("nodes");
         for (AllocationStats.NodeStats nodeStat : nodeStats) {
             nodeStat.toXContent(builder, params);

+ 3 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -126,6 +126,9 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                                     nodeStats.getAvgInferenceTime().orElse(null),
                                                     nodeStats.getLastAccess(),
                                                     nodeStats.getPendingCount(),
+                                                    0,
+                                                    0,
+                                                    0,
                                                     nodeStats.getRoutingState(),
                                                     nodeStats.getStartTime(),
                                                     null,

+ 3 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java

@@ -57,6 +57,9 @@ public class AllocationStatsTests extends AbstractWireSerializingTestCase<Alloca
             randomNonNegativeLong(),
             randomBoolean() ? randomDoubleBetween(0.0, 100.0, true) : null,
             randomIntBetween(0, 100),
+            randomIntBetween(0, 100),
+            randomIntBetween(0, 100),
+            randomIntBetween(0, 100),
             Instant.now(),
             Instant.now(),
             randomIntBetween(1, 16),

+ 4 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -297,9 +297,11 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 AllocationStats.NodeStats.forStartedState(
                     clusterService.localNode(),
                     stats.get().timingStats().getCount(),
-                    // avoid reporting the average time as 0 if count < 1
-                    (stats.get().timingStats().getCount() > 0) ? stats.get().timingStats().getAverage() : null,
+                    stats.get().timingStats().getAverage(),
                     stats.get().pendingCount(),
+                    stats.get().errorCount(),
+                    stats.get().rejectedExecutionCount(),
+                    stats.get().timeoutCount(),
                     stats.get().lastUsed(),
                     stats.get().startTime(),
                     stats.get().inferenceThreads(),

+ 43 - 13
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -18,6 +18,7 @@ import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
@@ -63,6 +64,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
 
@@ -101,25 +103,32 @@ public class DeploymentManager {
     }
 
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
-        return Optional.ofNullable(processContextByAllocation.get(task.getId()))
-            .map(
-                processContext -> new ModelStats(
-                    processContext.startTime,
-                    processContext.getResultProcessor().getTimingStats(),
-                    processContext.getResultProcessor().getLastUsed(),
-                    processContext.executorService.queueSize() + processContext.getResultProcessor().numberOfPendingResults(),
-                    processContext.inferenceThreads,
-                    processContext.modelThreads
-                )
+        return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
+            var stats = processContext.getResultProcessor().getResultStats();
+            return new ModelStats(
+                processContext.startTime,
+                stats.timingStats(),
+                stats.lastUsed(),
+                processContext.executorService.queueSize() + stats.numberOfPendingResults(),
+                stats.errorCount(),
+                processContext.rejectedExecutionCount.intValue(),
+                processContext.timeoutCount.intValue(),
+                processContext.inferenceThreads,
+                processContext.modelThreads
             );
+        });
+    }
+
+    // function exposed for testing
+    ProcessContext addProcessContext(Long id, ProcessContext processContext) {
+        return processContextByAllocation.putIfAbsent(id, processContext);
     }
 
     private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
         logger.info("[{}] Starting model deployment", task.getModelId());
 
         ProcessContext processContext = new ProcessContext(task, executorServiceForProcess);
-
-        if (processContextByAllocation.putIfAbsent(task.getId(), processContext) != null) {
+        if (addProcessContext(task.getId(), processContext) != null) {
             finalListener.onFailure(
                 ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getModelId())
             );
@@ -259,7 +268,10 @@ public class DeploymentManager {
             listener
         );
         try {
-            processContext.executorService.execute(inferenceAction);
+            processContext.getExecutorService().execute(inferenceAction);
+        } catch (EsRejectedExecutionException e) {
+            processContext.getRejectedExecutionCount().incrementAndGet();
+            inferenceAction.onFailure(e);
         } catch (Exception e) {
             inferenceAction.onFailure(e);
         }
@@ -302,6 +314,7 @@ public class DeploymentManager {
 
         void onTimeout() {
             if (notified.compareAndSet(false, true)) {
+                processContext.getTimeoutCount().incrementAndGet();
                 processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(requestId));
                 listener.onFailure(
                     new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.REQUEST_TIMEOUT, timeout)
@@ -435,6 +448,8 @@ public class DeploymentManager {
         private volatile Instant startTime;
         private volatile Integer inferenceThreads;
         private volatile Integer modelThreads;
+        private AtomicInteger rejectedExecutionCount = new AtomicInteger();
+        private AtomicInteger timeoutCount = new AtomicInteger();
 
         ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
             this.task = Objects.requireNonNull(task);
@@ -492,5 +507,20 @@ public class DeploymentManager {
                 throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]");
             }
         }
+
+        // accessor used for mocking in tests
+        AtomicInteger getTimeoutCount() {
+            return timeoutCount;
+        }
+
+        // accessor used for mocking in tests
+        ExecutorService getExecutorService() {
+            return executorService;
+        }
+
+        // accessor used for mocking in tests
+        AtomicInteger getRejectedExecutionCount() {
+            return rejectedExecutionCount;
+        }
     }
 }

+ 3 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ModelStats.java

@@ -15,6 +15,9 @@ public record ModelStats(
     LongSummaryStatistics timingStats,
     Instant lastUsed,
     int pendingCount,
+    int errorCount,
+    int rejectedExecutionCount,
+    int timeoutCount,
     Integer inferenceThreads,
     Integer modelThreads
 ) {}

+ 18 - 14
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -25,6 +25,8 @@ import java.util.function.Consumer;
 
 public class PyTorchResultProcessor {
 
+    public record ResultStats(LongSummaryStatistics timingStats, int errorCount, int numberOfPendingResults, Instant lastUsed) {}
+
     private static final Logger logger = LogManager.getLogger(PyTorchResultProcessor.class);
 
     private final ConcurrentMap<String, PendingResult> pendingResults = new ConcurrentHashMap<>();
@@ -33,6 +35,7 @@ public class PyTorchResultProcessor {
     private volatile boolean isStopping;
     private final LongSummaryStatistics timingStats;
     private final Consumer<ThreadSettings> threadSettingsConsumer;
+    private int errorCount;
     private Instant lastUsed;
 
     public PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> threadSettingsConsumer) {
@@ -100,12 +103,7 @@ public class PyTorchResultProcessor {
 
     private void processInferenceResult(PyTorchInferenceResult inferenceResult) {
         logger.trace(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, inferenceResult.getRequestId()));
-        if (inferenceResult.isError() == false) {
-            synchronized (this) {
-                timingStats.accept(inferenceResult.getTimeMs());
-                lastUsed = Instant.now();
-            }
-        }
+        processResult(inferenceResult);
         PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
         if (pendingResult == null) {
             logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, inferenceResult.getRequestId()));
@@ -114,16 +112,22 @@ public class PyTorchResultProcessor {
         }
     }
 
-    public synchronized LongSummaryStatistics getTimingStats() {
-        return new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum());
+    public synchronized ResultStats getResultStats() {
+        return new ResultStats(
+            new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum()),
+            errorCount,
+            pendingResults.size(),
+            lastUsed
+        );
     }
 
-    public synchronized Instant getLastUsed() {
-        return lastUsed;
-    }
-
-    public int numberOfPendingResults() {
-        return pendingResults.size();
+    private synchronized void processResult(PyTorchInferenceResult result) {
+        if (result.isError() == false) {
+            timingStats.accept(result.getTimeMs());
+            lastUsed = Instant.now();
+        } else {
+            errorCount++;
+        }
     }
 
     public void stop() {

+ 1 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java

@@ -9,13 +9,11 @@ package org.elasticsearch.xpack.ml.job.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.core.SuppressForbidden;
-import org.elasticsearch.rest.RestStatus;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -104,11 +102,7 @@ public class ProcessWorkerExecutorService extends AbstractExecutorService {
 
         boolean added = queue.offer(contextHolder.preserveContext(command));
         if (added == false) {
-            throw new ElasticsearchStatusException(
-                processName + " queue is full. Unable to execute command",
-                RestStatus.TOO_MANY_REQUESTS,
-                processName
-            );
+            throw new EsRejectedExecutionException(processName + " queue is full. Unable to execute command", false);
         }
     }
 

+ 7 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -378,6 +378,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         5,
                                         42.0,
                                         0,
+                                        1,
+                                        2,
+                                        3,
                                         Instant.now(),
                                         Instant.now(),
                                         randomIntBetween(1, 16),
@@ -388,6 +391,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                                         4,
                                         50.0,
                                         0,
+                                        1,
+                                        2,
+                                        3,
                                         Instant.now(),
                                         Instant.now(),
                                         randomIntBetween(1, 16),
@@ -485,6 +491,7 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             assertThat(source.getValue("jobs.opened.forecasts.total"), equalTo(11));
             assertThat(source.getValue("jobs.opened.forecasts.forecasted_jobs"), equalTo(2));
 
+            // TODO: should error_count also be asserted here?
             assertThat(source.getValue("inference.trained_models._all.count"), equalTo(4));
             assertThat(source.getValue("inference.trained_models.model_size_bytes.min"), equalTo(100.0));
             assertThat(source.getValue("inference.trained_models.model_size_bytes.max"), equalTo(300.0));

+ 50 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -8,23 +8,33 @@
 package org.elasticsearch.xpack.ml.inference.deployment;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.junit.After;
 import org.junit.Before;
 
 import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.elasticsearch.xpack.ml.MachineLearning.JOB_COMMS_THREAD_POOL_NAME;
 import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -61,9 +71,11 @@ public class DeploymentManagerTests extends ESTestCase {
     }
 
     public void testInferListenerOnlyCalledOnce() {
-        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
         DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
+        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
         when(processContext.getResultProcessor()).thenReturn(resultProcessor);
+        AtomicInteger timeoutCount = new AtomicInteger();
+        when(processContext.getTimeoutCount()).thenReturn(timeoutCount);
 
         ListenerCounter listener = new ListenerCounter();
         DeploymentManager.InferenceAction action = new DeploymentManager.InferenceAction(
@@ -105,6 +117,7 @@ public class DeploymentManagerTests extends ESTestCase {
         }
         assertThat(listener.failureCounts, equalTo(1));
         assertThat(listener.responseCounts, equalTo(1));
+        assertThat(timeoutCount.intValue(), equalTo(1));
 
         action = new DeploymentManager.InferenceAction(
             "test-model",
@@ -127,6 +140,42 @@ public class DeploymentManagerTests extends ESTestCase {
         assertThat(listener.responseCounts, equalTo(1));
     }
 
+    public void testRejectedExecution() {
+        TrainedModelDeploymentTask task = mock(TrainedModelDeploymentTask.class);
+        Long taskId = 1L;
+        when(task.getId()).thenReturn(taskId);
+        when(task.isStopped()).thenReturn(Boolean.FALSE);
+
+        DeploymentManager deploymentManager = new DeploymentManager(
+            mock(Client.class),
+            mock(NamedXContentRegistry.class),
+            tp,
+            mock(PyTorchProcessFactory.class)
+        );
+
+        ExecutorService executorService = mock(ExecutorService.class);
+        doThrow(new EsRejectedExecutionException("mock executor rejection")).when(executorService).execute(any(Runnable.class));
+
+        AtomicInteger rejectedCount = new AtomicInteger();
+
+        DeploymentManager.ProcessContext context = mock(DeploymentManager.ProcessContext.class);
+        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1", threadSettings -> {});
+        when(context.getResultProcessor()).thenReturn(resultProcessor);
+        when(context.getExecutorService()).thenReturn(executorService);
+        when(context.getRejectedExecutionCount()).thenReturn(rejectedCount);
+
+        deploymentManager.addProcessContext(taskId, context);
+        deploymentManager.infer(
+            task,
+            mock(InferenceConfig.class),
+            Map.of(),
+            TimeValue.timeValueMinutes(1),
+            ActionListener.wrap(result -> fail("unexpected success"), e -> assertThat(e, instanceOf(EsRejectedExecutionException.class)))
+        );
+
+        assertThat(rejectedCount.intValue(), equalTo(1));
+    }
+
     private static class ListenerCounter implements ActionListener<InferenceResults> {
         private int responseCounts;
         private int failureCounts;