Ver código fonte

[ML] Ensure inference queue is cleared after shutdown (#96738)

David Kyle 2 anos atrás
pai
commit
307425f406

+ 5 - 0
docs/changelog/96738.yaml

@@ -0,0 +1,5 @@
+pr: 96738
+summary: Ensure NLP model inference queue is always cleared after shutdown or failure
+area: Machine Learning
+type: bug
+issues: []

+ 8 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -446,7 +446,14 @@ public class DeploymentManager {
             isStopped = true;
             resultProcessor.stop();
             stateStreamer.cancel();
-            priorityProcessWorker.shutdown();
+
+            if (priorityProcessWorker.isShutdown()) {
+                // The worker has already shut down, most likely because a crash
+                // or exception stopped its thread. Notify any requests that are
+                // still waiting in the work queue.
+            } else {
+                priorityProcessWorker.shutdown();
+            }
             killProcessIfPresent();
             if (nlpTaskProcessor.get() != null) {
                 nlpTaskProcessor.get().close();

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/PriorityProcessWorkerExecutorService.java

@@ -81,6 +81,7 @@ public class PriorityProcessWorkerExecutorService extends AbstractProcessWorkerE
         if (isShutdown()) {
             EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
             command.onRejection(rejected);
+            notifyQueueRunnables();
             return;
         }
 
@@ -93,6 +94,10 @@ public class PriorityProcessWorkerExecutorService extends AbstractProcessWorkerE
 
         // PriorityBlockingQueue::offer always returns true
         queue.offer(new OrderedRunnable(priority, tieBreaker, contextHolder.preserveContext(command)));
+        if (isShutdown()) {
+            // the worker was shut down while this method was running: drain and notify whatever is queued, including the runnable just added
+            notifyQueueRunnables();
+        }
     }
 
     @Override

+ 13 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -121,35 +121,25 @@ public class PyTorchResultProcessor {
             if (isStopping == false) {
                 logger.error(() -> "[" + modelId + "] Error processing results", e);
             }
-            pendingResults.forEach(
-                (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(
-                        id,
-                        null,
-                        null,
-                        null,
-                        null,
-                        null,
-                        new ErrorResult(
-                            isStopping
-                                ? "inference canceled as process is stopping"
-                                : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
-                        )
-                    )
-                )
+            var errorResult = new ErrorResult(
+                isStopping
+                    ? "inference canceled as process is stopping"
+                    : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
             );
-            pendingResults.clear();
+            notifyAndClearPendingResults(errorResult);
         } finally {
-            pendingResults.forEach(
-                (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(id, false, null, null, null, null, new ErrorResult("inference canceled as process is stopping"))
-                )
-            );
-            pendingResults.clear();
+            notifyAndClearPendingResults(new ErrorResult("inference canceled as process is stopping"));
         }
         logger.debug(() -> "[" + modelId + "] Results processing finished");
     }
 
+    private void notifyAndClearPendingResults(ErrorResult errorResult) { // answers every pending listener with errorResult, then clears pendingResults
+        pendingResults.forEach(
+            (id, pendingResult) -> pendingResult.listener.onResponse(new PyTorchResult(id, null, null, null, null, null, errorResult))
+        );
+        pendingResults.clear(); // entries are removed only after all listeners have been called
+    }
+
     void processInferenceResult(PyTorchResult result) {
         PyTorchInferenceResult inferenceResult = result.inferenceResult();
         assert inferenceResult != null;

+ 31 - 19
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/AbstractProcessWorkerExecutorService.java

@@ -109,29 +109,41 @@ public abstract class AbstractProcessWorkerExecutorService<T extends Runnable> e
                 }
             }
 
-            synchronized (this) {
-                // if shutdown with tasks pending notify the handlers
-                if (queue.isEmpty() == false) {
-                    List<Runnable> notExecuted = new ArrayList<>();
-                    queue.drainTo(notExecuted);
-
-                    String msg = "unable to process as " + processName + " worker service has shutdown";
-                    Exception ex = error.get();
-                    for (Runnable runnable : notExecuted) {
-                        if (runnable instanceof AbstractRunnable ar) {
-                            if (ex != null) {
-                                ar.onFailure(ex);
-                            } else {
-                                ar.onRejection(new EsRejectedExecutionException(msg, true));
-                            }
-                        }
-                    }
-                }
-            }
+            notifyQueueRunnables();
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
         } finally {
             awaitTermination.countDown();
         }
     }
+
+    /**
+     * Drains the queue of runnables and notifies each one as either
+     * a failure or a rejected execution.
+     *
+     * If {@code error} holds an exception, every drained
+     * {@code AbstractRunnable} receives it through {@code onFailure};
+     * otherwise each one is rejected with an
+     * {@code EsRejectedExecutionException} stating that the worker
+     * service has shut down. Nothing happens when the queue is empty.
+     *
+     * Although public, this method should be used with caution.
+     * It should only be called _after_ the worker has shut down;
+     * this precondition is checked with an assertion.
+     *
+     * The method is synchronised to protect against concurrent calls.
+     */
+    public synchronized void notifyQueueRunnables() {
+        assert isShutdown() : "Queue runnables should only be drained and notified after the worker is shutdown";
+
+        if (queue.isEmpty() == false) {
+            List<Runnable> notExecuted = new ArrayList<>();
+            queue.drainTo(notExecuted); // empties the queue so each runnable is notified from the local list
+
+            String msg = "unable to process as " + processName + " worker service has shutdown";
+            Exception ex = error.get(); // read once so every runnable in this batch sees the same outcome
+            for (Runnable runnable : notExecuted) {
+                if (runnable instanceof AbstractRunnable ar) { // runnables of any other type are dropped without notification
+                    if (ex != null) {
+                        ar.onFailure(ex);
+                    } else {
+                        ar.onRejection(new EsRejectedExecutionException(msg, true));
+                    }
+                }
+            }
+        }
+    }
 }