
[ML] handle timeouts in native inference correctly (#78757)

Native inference was not timing out on writes to the process; consequently, if the named pipe queue was full, the write action would simply block.

This commit does the following:
 
- Spawns a scheduled future for timing out the native inference, so the action can time out both while writing to the process and while waiting for the result (a framework-free sketch of the pattern follows this list)
- Removes the task-level timeout, so we rely on either the node communication timing out OR our native action timeout
- Bubbles up 429 on timeout to the inference processor
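
A framework-free sketch of the timeout pattern, using plain java.util.concurrent rather than the Elasticsearch ThreadPool/ActionListener types in the actual change (all names below are hypothetical):

    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.ScheduledFuture;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicBoolean;
    import java.util.function.Consumer;

    // Schedule the timeout before doing any potentially blocking work, and use
    // an AtomicBoolean so whichever of {result, failure, timeout} happens first
    // is the only outcome the caller ever sees.
    final class TimedInferenceSketch {
        private final AtomicBoolean notified = new AtomicBoolean();

        void infer(ScheduledExecutorService scheduler, long timeoutMillis,
                   Consumer<String> onResult, Consumer<Exception> onFailure) {
            ScheduledFuture<?> timeoutHandle = scheduler.schedule(() -> {
                if (notified.compareAndSet(false, true)) {
                    onFailure.accept(new Exception("timeout waiting for inference result"));
                }
            }, timeoutMillis, TimeUnit.MILLISECONDS);

            try {
                // Stand-in for writing to the named pipe and waiting for the
                // native process; this is the call that used to block forever.
                String result = writeRequestAndAwaitResult();
                timeoutHandle.cancel(false);
                if (notified.compareAndSet(false, true)) {
                    onResult.accept(result);
                }
            } catch (Exception e) {
                timeoutHandle.cancel(false);
                if (notified.compareAndSet(false, true)) {
                    onFailure.accept(e);
                }
            }
        }

        private String writeRequestAndAwaitResult() {
            return "ok"; // placeholder for the real pipe round trip
        }
    }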
Benjamin Trent 4 years ago
Parent
Commit
7c45032530

+ 34 - 21
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentAction.java

@@ -14,13 +14,14 @@ import org.elasticsearch.action.support.tasks.BaseTasksResponse;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
@@ -46,6 +47,13 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
         super(NAME, InferTrainedModelDeploymentAction.Response::new);
     }
 
+    /**
+     * Request for inference against the deployment.
+     *
+     * The task gets routed to a node that indicates its local model allocation is started
+     *
+     * For indicating timeout, the caller should call `setInferenceTimeout` and not the base class `setTimeout` method
+     */
     public static class Request extends BaseTasksRequest<Request> {
 
         public static final ParseField DEPLOYMENT_ID = new ParseField("deployment_id");
@@ -59,7 +67,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
         static {
             PARSER.declareString(Request.Builder::setDeploymentId, DEPLOYMENT_ID);
             PARSER.declareObjectArray(Request.Builder::setDocs, (p, c) -> p.mapOrdered(), DOCS);
-            PARSER.declareString(Request.Builder::setTimeout, TIMEOUT);
+            PARSER.declareString(Request.Builder::setInferenceTimeout, TIMEOUT);
             PARSER.declareNamedObject(
                 Request.Builder::setUpdate,
                 ((p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, c)),
@@ -67,22 +75,24 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             );
         }
 
-        public static Request parseRequest(String deploymentId, XContentParser parser) {
+        public static Request.Builder parseRequest(String deploymentId, XContentParser parser) {
             Request.Builder builder = PARSER.apply(parser, null);
             if (deploymentId != null) {
                 builder.setDeploymentId(deploymentId);
             }
-            return builder.build();
+            return builder;
         }
 
         private final String deploymentId;
         private final List<Map<String, Object>> docs;
         private final InferenceConfigUpdate update;
+        private final TimeValue inferenceTimeout;
 
-        public Request(String deploymentId, InferenceConfigUpdate update, List<Map<String, Object>> docs) {
+        public Request(String deploymentId, InferenceConfigUpdate update, List<Map<String, Object>> docs, TimeValue inferenceTimeout) {
             this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, DEPLOYMENT_ID);
             this.docs = ExceptionsHelper.requireNonNull(Collections.unmodifiableList(docs), DOCS);
             this.update = update;
+            this.inferenceTimeout = inferenceTimeout;
         }
 
         public Request(StreamInput in) throws IOException {
@@ -90,6 +100,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             deploymentId = in.readString();
             docs = Collections.unmodifiableList(in.readList(StreamInput::readMap));
             update = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
+            inferenceTimeout = in.readOptionalTimeValue();
         }
 
         public String getDeploymentId() {
@@ -104,13 +115,18 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return Optional.ofNullable(update).orElse(new EmptyConfigUpdate());
         }
 
+        public TimeValue getInferenceTimeout() {
+            return inferenceTimeout == null ? DEFAULT_TIMEOUT : inferenceTimeout;
+        }
+
+        /**
+         * This is always null as we want the inference call to handle the timeout, not the tasks framework
+         * @return null
+         */
         @Override
+        @Nullable
         public TimeValue getTimeout() {
-            TimeValue tv = super.getTimeout();
-            if (tv == null) {
-                return DEFAULT_TIMEOUT;
-            }
-            return tv;
+            return null;
         }
 
         @Override
@@ -139,6 +155,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             out.writeString(deploymentId);
             out.writeCollection(docs, StreamOutput::writeMap);
             out.writeOptionalNamedWriteable(update);
+            out.writeOptionalTimeValue(inferenceTimeout);
         }
 
         @Override
@@ -154,12 +171,12 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
             return Objects.equals(deploymentId, that.deploymentId)
                 && Objects.equals(docs, that.docs)
                 && Objects.equals(update, that.update)
-                && Objects.equals(getTimeout(), that.getTimeout());
+                && Objects.equals(inferenceTimeout, that.inferenceTimeout);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(deploymentId, update, docs, getTimeout());
+            return Objects.hash(deploymentId, update, docs, inferenceTimeout);
         }
 
         public static class Builder {
@@ -181,7 +198,7 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 return this;
             }
 
-            public Builder setTimeout(TimeValue timeout) {
+            public Builder setInferenceTimeout(TimeValue timeout) {
                 this.timeout = timeout;
                 return this;
             }
@@ -191,16 +208,12 @@ public class InferTrainedModelDeploymentAction extends ActionType<InferTrainedMo
                 return this;
             }
 
-            private Builder setTimeout(String timeout) {
-                return setTimeout(TimeValue.parseTimeValue(timeout, TIMEOUT.getPreferredName()));
+            private Builder setInferenceTimeout(String timeout) {
+                return setInferenceTimeout(TimeValue.parseTimeValue(timeout, TIMEOUT.getPreferredName()));
             }
 
             public Request build() {
-                Request request = new Request(deploymentId, update, docs);
-                if (timeout != null) {
-                    request.setTimeout(timeout);
-                }
-                return request;
+                return new Request(deploymentId, update, docs, timeout);
             }
         }
     }
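
With this change the inference timeout travels in the request body instead of the tasks framework, and getTimeout() deliberately returns null so the tasks framework never applies its own timeout. A caller-side sketch (Request and Builder come from the diff above; the deployment id and docs payload are illustrative):

    InferTrainedModelDeploymentAction.Request.Builder builder =
        new InferTrainedModelDeploymentAction.Request.Builder();
    builder.setDeploymentId("my-deployment");
    builder.setDocs(List.of(Map.of("input", "some text")));
    builder.setInferenceTimeout(TimeValue.timeValueSeconds(10));
    InferTrainedModelDeploymentAction.Request request = builder.build();

    assert request.getTimeout() == null;  // tasks framework timeout is disabled
    assert request.getInferenceTimeout().equals(TimeValue.timeValueSeconds(10));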

+ 9 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferTrainedModelDeploymentRequestsTests.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.core.ml.action;
 
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
@@ -37,15 +38,12 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerial
         List<Map<String, Object>> docs = randomList(5, () -> randomMap(1, 3,
             () -> Tuple.tuple(randomAlphaOfLength(7), randomAlphaOfLength(7))));
 
-        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
+        return new InferTrainedModelDeploymentAction.Request(
             randomAlphaOfLength(4),
             randomBoolean() ? null : randomInferenceConfigUpdate(),
-            docs
+            docs,
+            randomBoolean() ? null : TimeValue.parseTimeValue(randomTimeValue(), "timeout")
         );
-        if (randomBoolean()) {
-            request.setTimeout(randomTimeValue());
-        }
-        return request;
     }
 
     @Override
@@ -56,6 +54,10 @@ public class InferTrainedModelDeploymentRequestsTests extends AbstractWireSerial
     }
 
     public void testTimeoutNotNull() {
-        assertNotNull(createTestInstance().getTimeout());
+        assertNotNull(createTestInstance().getInferenceTimeout());
+    }
+
+    public void testTimeoutNull() {
+        assertNull(createTestInstance().getTimeout());
     }
 }

+ 20 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.ESRestTestCase;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
@@ -185,6 +186,17 @@ public class PyTorchModelIT extends ESRestTestCase {
         stopDeployment(modelId);
     }
 
+    public void testEvaluateWithMinimalTimeout() throws IOException {
+        String modelId = "test_evaluate_timeout";
+        createTrainedModel(modelId);
+        putModelDefinition(modelId);
+        putVocabulary(List.of("these", "are", "my", "words"), modelId);
+        startDeployment(modelId);
+        ResponseException ex = expectThrows(ResponseException.class, () -> infer("my words", modelId, TimeValue.ZERO));
+        assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));
+        stopDeployment(modelId);
+    }
+
     public void testDeleteFailureDueToDeployment() throws IOException {
         String modelId = "test_deployed_model_delete";
         createTrainedModel(modelId);
@@ -449,6 +461,14 @@ public class PyTorchModelIT extends ESRestTestCase {
         return client().performRequest(request);
     }
 
+    private Response infer(String input, String modelId, TimeValue timeout) throws IOException {
+        Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer?timeout=" + timeout.toString());
+        request.setJsonEntity("{  " +
+            "\"docs\": [{\"input\":\"" + input + "\"}]\n" +
+            "}");
+        return client().performRequest(request);
+    }
+
     private Response infer(String input, String modelId) throws IOException {
         Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_infer");
         request.setJsonEntity("{  " +
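
The new integration test above exercises this end to end; the same call with the low-level REST client looks roughly like the following (client setup, model id, and the 1ms value are illustrative):

    // A deliberately tiny timeout should surface as HTTP 429 once the native
    // inference times out instead of blocking on the named pipe write.
    Request request = new Request("POST",
        "/_ml/trained_models/test_model/deployment/_infer?timeout=1ms");
    request.setJsonEntity("{\"docs\": [{\"input\": \"my words\"}]}");
    ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
    assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(429));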

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -83,10 +83,11 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
         task.infer(
             request.getDocs().get(0),
             request.getUpdate(),
-            request.getTimeout(),
+            request.getInferenceTimeout(),
             ActionListener.wrap(
                 pyTorchResult -> listener.onResponse(new InferTrainedModelDeploymentAction.Response(pyTorchResult)),
-                listener::onFailure)
+                listener::onFailure
+            )
         );
     }
 }

+ 17 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -6,6 +6,8 @@
  */
 package org.elasticsearch.xpack.ml.action;
 
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
@@ -14,6 +16,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -158,10 +161,22 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         executeAsyncWithOrigin(client,
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
-            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc)),
+            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc), null),
             ActionListener.wrap(
                 r -> listener.onResponse(r.getResults()),
-                e -> listener.onResponse(new WarningInferenceResults(e.getMessage()))
+                e -> {
+                    Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
+                    if (unwrapped instanceof ElasticsearchStatusException) {
+                        ElasticsearchStatusException ex = (ElasticsearchStatusException) unwrapped;
+                        if (ex.status().equals(RestStatus.TOO_MANY_REQUESTS)) {
+                            listener.onFailure(ex);
+                        } else {
+                            listener.onResponse(new WarningInferenceResults(ex.getMessage()));
+                        }
+                    } else {
+                        listener.onResponse(new WarningInferenceResults(e.getMessage()));
+                    }
+                }
             )
         );
     }
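
The failure handler above encodes the routing rule from the commit message: a 429 (the inference timed out, i.e. the process queue is saturated) is bubbled up as a real failure so callers can back off, while any other error keeps the previous behavior of degrading to a WarningInferenceResults. Condensed, the rule is (classes as in the diff; the listener is assumed):

    Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
    boolean timedOut = unwrapped instanceof ElasticsearchStatusException
        && ((ElasticsearchStatusException) unwrapped).status() == RestStatus.TOO_MANY_REQUESTS;
    if (timedOut) {
        listener.onFailure((ElasticsearchStatusException) unwrapped);
    } else {
        listener.onResponse(new WarningInferenceResults(e.getMessage()));
    }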

+ 140 - 70
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -27,6 +27,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -58,6 +59,7 @@ import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
 
@@ -74,6 +76,7 @@ public class DeploymentManager {
     private final PyTorchProcessFactory pyTorchProcessFactory;
     private final ExecutorService executorServiceForDeployment;
     private final ExecutorService executorServiceForProcess;
+    private final ThreadPool threadPool;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
 
     public DeploymentManager(Client client, NamedXContentRegistry xContentRegistry,
@@ -81,6 +84,7 @@ public class DeploymentManager {
         this.client = Objects.requireNonNull(client);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
+        this.threadPool = Objects.requireNonNull(threadPool);
         this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
     }
@@ -92,8 +96,8 @@ public class DeploymentManager {
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
         return Optional.ofNullable(processContextByAllocation.get(task.getId()))
             .map(processContext ->
-                new ModelStats(processContext.resultProcessor.getTimingStats(),
-                    processContext.resultProcessor.getLastUsed())
+                new ModelStats(processContext.getResultProcessor().getTimingStats(),
+                    processContext.getResultProcessor().getLastUsed())
             );
     }
 
@@ -117,7 +121,7 @@ public class DeploymentManager {
 
         ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(
             success -> {
-                executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get()));
+                executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
                 listener.onResponse(task);
             },
             listener::onFailure
@@ -226,83 +230,145 @@ public class DeploymentManager {
             return;
         }
 
-        final String requestId = String.valueOf(requestIdCounter.getAndIncrement());
+        final long requestId = requestIdCounter.getAndIncrement();
+        executorServiceForProcess.execute(new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener));
+    }
 
-        executorServiceForProcess.execute(new AbstractRunnable() {
-            @Override
-            public void onFailure(Exception e) {
-                listener.onFailure(e);
-            }
+    static class InferenceAction extends AbstractRunnable {
+        private final long requestId;
+        private final TimeValue timeout;
+        private final Scheduler.Cancellable timeoutHandler;
+        private final ProcessContext processContext;
+        private final InferenceConfig config;
+        private final Map<String, Object> doc;
+        private final ActionListener<InferenceResults> listener;
+        private final AtomicBoolean notified = new AtomicBoolean();
+
+        InferenceAction(
+            long requestId,
+            TimeValue timeout,
+            ProcessContext processContext,
+            InferenceConfig config,
+            Map<String, Object> doc,
+            ThreadPool threadPool,
+            ActionListener<InferenceResults> listener
+        ) {
+            this.requestId = requestId;
+            this.timeout = timeout;
+            this.processContext = processContext;
+            this.config = config;
+            this.doc = doc;
+            this.listener = listener;
+            this.timeoutHandler = threadPool.schedule(
+                this::onTimeout,
+                ExceptionsHelper.requireNonNull(timeout, "timeout"),
+                MachineLearning.UTILITY_THREAD_POOL_NAME
+            );
+        }
 
-            @Override
-            protected void doRun() {
-                try {
-                    // The request builder expect a list of inputs which are then batched.
-                    // TODO batching was implemented for expected use-cases such as zero-shot
-                    // classification but is not used here.
-                    List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
-                    NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
-                    processor.validateInputs(text);
-                    assert config instanceof NlpConfig;
-                    NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestId);
-                    logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
-                    PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
-                    processContext.process.get().writeInferenceRequest(request.processInput);
-                    waitForResult(
-                        processContext,
-                        pendingResult,
-                        request.tokenization,
-                        requestId,
-                        timeout,
-                        processor.getResultProcessor((NlpConfig) config),
-                        listener
-                    );
-                } catch (IOException e) {
-                    logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
-                    onFailure(ExceptionsHelper.serverError("error writing to process", e));
-                } catch (Exception e) {
-                    onFailure(e);
-                } finally {
-                    processContext.resultProcessor.requestAccepted(requestId);
-                }
+        void onTimeout() {
+            if (notified.compareAndSet(false, true)) {
+                processContext.getResultProcessor().requestIgnored(String.valueOf(requestId));
+                listener.onFailure(
+                    new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
+                );
+                return;
             }
-        });
-    }
+            logger.debug("request [{}] received timeout after [{}] but listener already alerted", requestId, timeout);
+        }
 
-    private void waitForResult(ProcessContext processContext,
-                               PyTorchResultProcessor.PendingResult pendingResult,
-                               TokenizationResult tokenization,
-                               String requestId,
-                               TimeValue timeout,
-                               NlpTask.ResultProcessor inferenceResultsProcessor,
-                               ActionListener<InferenceResults> listener) {
-        try {
-            PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(
-                processContext.process.get(),
-                requestId,
-                pendingResult,
-                timeout
-            );
-            if (pyTorchResult == null) {
-                listener.onFailure(new ElasticsearchStatusException("timeout [{}] waiting for inference result",
-                    RestStatus.TOO_MANY_REQUESTS, timeout));
+        void onSuccess(InferenceResults inferenceResults) {
+            timeoutHandler.cancel();
+            if (notified.compareAndSet(false, true)) {
+                listener.onResponse(inferenceResults);
                 return;
             }
+            logger.debug("request [{}] received inference response but listener already notified", requestId);
+        }
 
-            if (pyTorchResult.isError()) {
-                listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(),
-                    RestStatus.INTERNAL_SERVER_ERROR));
+        @Override
+        public void onFailure(Exception e) {
+            timeoutHandler.cancel();
+            if (notified.compareAndSet(false, true)) {
+                listener.onFailure(e);
                 return;
             }
+            logger.debug(
+                () -> new ParameterizedMessage("request [{}] received failure but listener already notified", requestId),
+                e
+            );
+        }
 
-            logger.debug(() -> new ParameterizedMessage(
-                "[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId));
-            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
-            logger.debug(() -> new ParameterizedMessage(
-                "[{}] processed result for request [{}]", processContext.task.getModelId(), requestId));
-            listener.onResponse(results);
-        } catch (InterruptedException e) {
-            listener.onFailure(e);
+        @Override
+        protected void doRun() throws Exception {
+            final String requestIdStr = String.valueOf(requestId);
+            try {
+                // The request builder expect a list of inputs which are then batched.
+                // TODO batching was implemented for expected use-cases such as zero-shot
+                // classification but is not used here.
+                List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
+                NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
+                processor.validateInputs(text);
+                assert config instanceof NlpConfig;
+                NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr);
+                logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
+                PyTorchResultProcessor.PendingResult pendingResult = processContext.getResultProcessor().registerRequest(requestIdStr);
+                processContext.process.get().writeInferenceRequest(request.processInput);
+                waitForResult(
+                    processContext,
+                    pendingResult,
+                    request.tokenization,
+                    requestIdStr,
+                    timeout,
+                    processor.getResultProcessor((NlpConfig) config),
+                    ActionListener.wrap(this::onSuccess,this::onFailure)
+                );
+            } catch (IOException e) {
+                logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
+                onFailure(ExceptionsHelper.serverError("error writing to process", e));
+            } catch (Exception e) {
+                onFailure(e);
+            } finally {
+                processContext.getResultProcessor().requestIgnored(String.valueOf(requestId));
+            }
+        }
+
+        private void waitForResult(ProcessContext processContext,
+                                   PyTorchResultProcessor.PendingResult pendingResult,
+                                   TokenizationResult tokenization,
+                                   String requestId,
+                                   TimeValue timeout,
+                                   NlpTask.ResultProcessor inferenceResultsProcessor,
+                                   ActionListener<InferenceResults> listener) {
+            try {
+                PyTorchResult pyTorchResult = processContext.getResultProcessor().waitForResult(
+                    processContext.process.get(),
+                    requestId,
+                    pendingResult,
+                    timeout
+                );
+                if (pyTorchResult == null) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
+                    );
+                    return;
+                }
+
+                if (pyTorchResult.isError()) {
+                    listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(),
+                        RestStatus.INTERNAL_SERVER_ERROR));
+                    return;
+                }
+
+                logger.debug(() -> new ParameterizedMessage(
+                    "[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId));
+                InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
+                logger.debug(() -> new ParameterizedMessage(
+                    "[{}] processed result for request [{}]", processContext.task.getModelId(), requestId));
+                listener.onResponse(results);
+            } catch (InterruptedException e) {
+                listener.onFailure(e);
+            }
         }
     }
 
@@ -321,6 +387,10 @@ public class DeploymentManager {
             this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
         }
 
+        PyTorchResultProcessor getResultProcessor() {
+            return resultProcessor;
+        }
+
         synchronized void startProcess() {
             process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
         }

+ 12 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -44,8 +44,16 @@ public class PyTorchResultProcessor {
         return pendingResults.computeIfAbsent(requestId, k -> new PendingResult());
     }
 
-    public void requestAccepted(String requestId) {
-        pendingResults.remove(requestId);
+    /**
+     * Call this method when the caller is no longer waiting on the request response.
+     *
+     * @param requestId The request ID that is no longer being waited on
+     */
+    public void requestIgnored(String requestId) {
+        PendingResult pendingResult = pendingResults.remove(requestId);
+        if (pendingResult != null) {
+            pendingResult.latch.countDown();
+        }
     }
 
     public void process(NativePyTorchProcess process) {
@@ -55,9 +63,9 @@ public class PyTorchResultProcessor {
                 PyTorchResult result = iterator.next();
                 logger.trace(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", deploymentId, result.getRequestId()));
                 processResult(result);
-                PendingResult pendingResult = pendingResults.get(result.getRequestId());
+                PendingResult pendingResult = pendingResults.remove(result.getRequestId());
                 if (pendingResult == null) {
-                    logger.warn(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
+                    logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", deploymentId, result.getRequestId()));
                 } else {
                     pendingResult.result.set(result);
                     pendingResult.latch.countDown();
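
Note the interplay with requestIgnored above: removing a pending result must also count down its latch, otherwise a thread already blocked waiting on that result would sleep out its full timeout. A self-contained sketch of the rendezvous (plain java.util.concurrent, all names hypothetical):

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicReference;

    final class PendingResults {
        static final class PendingResult {
            final CountDownLatch latch = new CountDownLatch(1);
            final AtomicReference<String> result = new AtomicReference<>();
        }

        private final ConcurrentMap<String, PendingResult> pending = new ConcurrentHashMap<>();

        PendingResult register(String requestId) {
            return pending.computeIfAbsent(requestId, k -> new PendingResult());
        }

        // Result-reader thread: publish the result and release the waiter.
        void deliver(String requestId, String result) {
            PendingResult p = pending.remove(requestId);
            if (p != null) {
                p.result.set(result);
                p.latch.countDown();
            }
        }

        // Caller no longer cares (e.g. it timed out): remove AND count down,
        // so a waiter blocked in await() below wakes up immediately.
        void ignore(String requestId) {
            PendingResult p = pending.remove(requestId);
            if (p != null) {
                p.latch.countDown();
            }
        }

        // Returns null when nothing arrived within the timeout.
        String await(PendingResult p, long millis) throws InterruptedException {
            return p.latch.await(millis, TimeUnit.MILLISECONDS) ? p.result.get() : null;
        }
    }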

+ 7 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestInferTrainedModelDeploymentAction.java

@@ -45,15 +45,19 @@ public class RestInferTrainedModelDeploymentAction extends BaseRestHandler {
         if (restRequest.hasContent() == false) {
             throw ExceptionsHelper.badRequestException("requires body");
         }
-        InferTrainedModelDeploymentAction.Request request =
+        InferTrainedModelDeploymentAction.Request.Builder request =
             InferTrainedModelDeploymentAction.Request.parseRequest(deploymentId, restRequest.contentParser());
 
         if (restRequest.hasParam(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName())) {
             TimeValue inferTimeout = restRequest.paramAsTime(InferTrainedModelDeploymentAction.Request.TIMEOUT.getPreferredName(),
                 InferTrainedModelDeploymentAction.Request.DEFAULT_TIMEOUT);
-            request.setTimeout(inferTimeout);
+            request.setInferenceTimeout(inferTimeout);
         }
 
-        return channel -> client.execute(InferTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
+        return channel -> client.execute(
+            InferTrainedModelDeploymentAction.INSTANCE,
+            request.build(),
+            new RestToXContentListener<>(channel)
+        );
     }
 }

+ 137 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.deployment;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ScalingExecutorBuilder;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
+import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.ml.MachineLearning.JOB_COMMS_THREAD_POOL_NAME;
+import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class DeploymentManagerTests extends ESTestCase {
+
+    private DeploymentManager deploymentManager;
+    private ThreadPool tp;
+
+    @Before
+    public void managerSetup() {
+        tp = new TestThreadPool(
+            "DeploymentManagerTests",
+            new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME,1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool"),
+            new ScalingExecutorBuilder(JOB_COMMS_THREAD_POOL_NAME,1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.job_comms_thread_pool")
+        );
+        deploymentManager = new DeploymentManager(
+            mock(Client.class),
+            xContentRegistry(),
+            tp,
+            (task, executorService, onProcessCrash) -> mock(NativePyTorchProcess.class)
+        );
+    }
+
+    @After
+    public void shutdownThreadpool() {
+        tp.shutdown();
+    }
+
+    public void testInferListenerOnlyCalledOnce() {
+        PyTorchResultProcessor resultProcessor = new PyTorchResultProcessor("1");
+        DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
+        when(processContext.getResultProcessor()).thenReturn(resultProcessor);
+
+        ListenerCounter listener = new ListenerCounter();
+        DeploymentManager.InferenceAction action = new DeploymentManager.InferenceAction(
+            1,
+            TimeValue.MAX_VALUE,
+            processContext,
+            new PassThroughConfig(null, null, null),
+            Map.of(),
+            tp,
+            listener
+        );
+
+        action.onSuccess(new WarningInferenceResults("foo"));
+        for (int i = 0; i < 10; i++) {
+            action.onSuccess(new WarningInferenceResults("foo"));
+            action.onFailure(new Exception("foo"));
+            action.onTimeout();
+        }
+        assertThat(listener.failureCounts, equalTo(0));
+        assertThat(listener.responseCounts, equalTo(1));
+
+        action = new DeploymentManager.InferenceAction(
+            1,
+            TimeValue.MAX_VALUE,
+            processContext,
+            new PassThroughConfig(null, null, null),
+            Map.of(),
+            tp,
+            listener
+        );
+
+        action.onTimeout();
+        for (int i = 0; i < 10; i++) {
+            action.onSuccess(new WarningInferenceResults("foo"));
+            action.onFailure(new Exception("foo"));
+            action.onTimeout();
+        }
+        assertThat(listener.failureCounts, equalTo(1));
+        assertThat(listener.responseCounts, equalTo(1));
+
+        action = new DeploymentManager.InferenceAction(
+            1,
+            TimeValue.MAX_VALUE,
+            processContext,
+            new PassThroughConfig(null, null, null),
+            Map.of(),
+            tp,
+            listener
+        );
+
+        action.onFailure(new Exception("bar"));
+        for (int i = 0; i < 10; i++) {
+            action.onSuccess(new WarningInferenceResults("foo"));
+            action.onFailure(new Exception("foo"));
+            action.onTimeout();
+        }
+        assertThat(listener.failureCounts, equalTo(2));
+        assertThat(listener.responseCounts, equalTo(1));
+    }
+
+    private static class ListenerCounter implements ActionListener<InferenceResults> {
+        private int responseCounts;
+        private int failureCounts;
+
+        @Override
+        public void onResponse(InferenceResults inferenceResults) {
+            responseCounts++;
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            failureCounts++;
+        }
+    }
+
+}