
[ML] System auditor notifications for pytorch restarts (#105411)

* Add systemAudit logs to process restarts

* Updated error message for system audit / notifications

* Added system audit message for not restarting pytorch process

* Switch to inferenceAuditor and update error message
Max Hniebergall, 1 year ago
Parent commit: 05d2375e61
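
In short, this commit threads the ML plugin's `InferenceAuditor` into `DeploymentManager` so that PyTorch process crashes produce audit notifications (visible as ML notifications) in addition to server-side logs. A condensed, illustrative summary of the two audited outcomes follows; this method does not exist in the commit — it collapses the behavior of `onProcessCrashHandleRestarts` and `finishClosingProcess` from the diff below, and it assumes the surrounding class fields (`logger`, `threadPool`, `inferenceAuditor`, `NUM_RESTART_ATTEMPTS`) shown in that diff:

```java
// Hedged sketch: condensed summary of the crash handling added below.
// Assumes the DeploymentManager fields introduced/used in this commit.
void auditCrash(String deploymentId, String reason, AtomicInteger startsCount) {
    if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
        // Restart path: audited as a warning.
        String msg = "Inference process [" + deploymentId + "] failed due to ["
            + reason + "]. This is the [" + startsCount.get()
            + "] failure in 24 hours, and the process will be restarted.";
        logger.info(msg);
        // Dispatched on the utility pool so the crash handler never
        // blocks on persisting the notification.
        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
            .execute(() -> inferenceAuditor.warning(deploymentId, msg));
    } else {
        // Give-up path: audited as an error, process is not restarted.
        String msg = "[" + deploymentId + "] inference process failed after ["
            + startsCount.get() + "] starts in 24 hours, not restarting again.";
        logger.warn(msg);
        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
            .execute(() -> inferenceAuditor.error(deploymentId, msg));
    }
}
```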

+ 8 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -1151,7 +1151,14 @@ public class MachineLearning extends Plugin
         );
 
         this.deploymentManager.set(
-            new DeploymentManager(client, xContentRegistry, threadPool, pyTorchProcessFactory, getMaxModelDeploymentsPerNode())
+            new DeploymentManager(
+                client,
+                xContentRegistry,
+                threadPool,
+                pyTorchProcessFactory,
+                getMaxModelDeploymentsPerNode(),
+                inferenceAuditor
+            )
         );
 
         // Data frame analytics components

+ 40 - 14
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -53,6 +53,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactor
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.io.IOException;
 import java.time.Duration;
@@ -85,6 +86,7 @@ public class DeploymentManager {
     private final ExecutorService executorServiceForDeployment;
     private final ExecutorService executorServiceForProcess;
     private final ThreadPool threadPool;
+    private final InferenceAuditor inferenceAuditor;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final int maxProcesses;
 
@@ -93,12 +95,14 @@ public class DeploymentManager {
         NamedXContentRegistry xContentRegistry,
         ThreadPool threadPool,
         PyTorchProcessFactory pyTorchProcessFactory,
-        int maxProcesses
+        int maxProcesses,
+        InferenceAuditor inferenceAuditor
     ) {
         this.client = Objects.requireNonNull(client);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
         this.threadPool = Objects.requireNonNull(threadPool);
+        this.inferenceAuditor = Objects.requireNonNull(inferenceAuditor);
         this.executorServiceForDeployment = threadPool.executor(UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
         this.maxProcesses = maxProcesses;
@@ -523,7 +527,7 @@ public class DeploymentManager {
                     task,
                     executorServiceForProcess,
                     () -> resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES),
-                    onProcessCrashHandleRestarts(startsCount)
+                    onProcessCrashHandleRestarts(startsCount, task.getDeploymentId())
                 )
             );
             startTime = Instant.now();
@@ -546,16 +550,19 @@ public class DeploymentManager {
             }
         }
 
-        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount) {
+        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount, String deploymentId) {
             return (reason) -> {
                 if (isThisProcessOlderThan1Day()) {
                     startsCount.set(1);
-                    logger.error(
-                        "[{}] inference process crashed due to reason [{}]. This process was started more than 24 hours ago; "
-                            + "the starts count is reset to 1.",
-                        task.getDeploymentId(),
-                        reason
-                    );
+                    {
+                        String logMessage = "["
+                            + task.getDeploymentId()
+                            + "] inference process crashed due to reason ["
+                            + reason
+                            + "]. This process was started more than 24 hours ago; "
+                            + "the starts count is reset to 1.";
+                        logger.error(logMessage);
+                    }
                 } else {
                     logger.error("[{}] inference process crashed due to reason [{}]", task.getDeploymentId(), reason);
                 }
@@ -566,20 +573,32 @@ public class DeploymentManager {
                 stateStreamer.cancel();
 
                 if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
-                    logger.info("[{}] restarting inference process after [{}] starts", task.getDeploymentId(), startsCount.get());
+                    {
+                        String logAndAuditMessage = "Inference process ["
+                            + task.getDeploymentId()
+                            + "] failed due to ["
+                            + reason
+                            + "]. This is the ["
+                            + startsCount.get()
+                            + "] failure in 24 hours, and the process will be restarted.";
+                        logger.info(logAndAuditMessage);
+                        threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
+                            .execute(() -> inferenceAuditor.warning(deploymentId, logAndAuditMessage));
+                    }
                     priorityProcessWorker.shutdownNow(); // TODO what to do with these tasks?
                     ActionListener<TrainedModelDeploymentTask> errorListener = ActionListener.wrap((trainedModelDeploymentTask -> {
                         logger.debug("Completed restart of inference process, the [{}] start", startsCount);
                     }),
                         (e) -> finishClosingProcess(
                             startsCount,
-                            "Failed to restart inference process because of error [" + e.getMessage() + "]"
+                            "Failed to restart inference process because of error [" + e.getMessage() + "]",
+                            deploymentId
                         )
                     );
 
                     startDeployment(task, startsCount.incrementAndGet(), errorListener);
                 } else {
-                    finishClosingProcess(startsCount, reason);
+                    finishClosingProcess(startsCount, reason, deploymentId);
                 }
             };
         }
@@ -588,8 +607,15 @@ public class DeploymentManager {
             return startTime.isBefore(Instant.now().minus(Duration.ofDays(1)));
         }
 
-        private void finishClosingProcess(AtomicInteger startsCount, String reason) {
-            logger.warn("[{}] inference process failed after [{}] starts, not restarting again", task.getDeploymentId(), startsCount.get());
+        private void finishClosingProcess(AtomicInteger startsCount, String reason, String deploymentId) {
+            String logAndAuditMessage = "["
+                + task.getDeploymentId()
+                + "] inference process failed after ["
+                + startsCount.get()
+                + "] starts in 24 hours, not restarting again.";
+            logger.warn(logAndAuditMessage);
+            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
+                .execute(() -> inferenceAuditor.error(deploymentId, logAndAuditMessage));
             priorityProcessWorker.shutdownNowWithError(new IllegalStateException(reason));
             if (nlpTaskProcessor.get() != null) {
                 nlpTaskProcessor.get().close();
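
The restart policy exercised above, in isolation: a crash more than 24 hours after process start resets the budget to 1 before the restart decision is made, and the restart path then bumps the counter via `startsCount.incrementAndGet()` when calling `startDeployment`. A self-contained sketch of just the decision (the value of `NUM_RESTART_ATTEMPTS` is assumed here; the real constant is defined elsewhere in `DeploymentManager` and does not appear in this diff):

```java
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.atomic.AtomicInteger;

class RestartBudget {
    // Assumed value for illustration; the real constant lives in DeploymentManager.
    static final int NUM_RESTART_ATTEMPTS = 3;

    /** Mirrors the decision made by onProcessCrashHandleRestarts above. */
    static boolean shouldRestart(AtomicInteger startsCount, Instant processStartTime) {
        if (processStartTime.isBefore(Instant.now().minus(Duration.ofDays(1)))) {
            // A crash after more than a day of uptime does not count
            // against the budget: reset the counter before deciding.
            startsCount.set(1);
        }
        return startsCount.get() <= NUM_RESTART_ATTEMPTS;
    }
}
```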

+ 5 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 import org.junit.After;
 import org.junit.Before;
 
@@ -36,6 +37,7 @@ import static org.mockito.Mockito.when;
 public class DeploymentManagerTests extends ESTestCase {
 
     private ThreadPool tp;
+    private InferenceAuditor inferenceAuditor;
 
     @Before
     public void managerSetup() {
@@ -58,6 +60,7 @@ public class DeploymentManagerTests extends ESTestCase {
                 "xpack.ml.native_inference_comms_thread_pool"
             )
         );
+        inferenceAuditor = mock(InferenceAuditor.class);
     }
 
     @After
@@ -78,7 +81,8 @@ public class DeploymentManagerTests extends ESTestCase {
             mock(NamedXContentRegistry.class),
             tp,
             mock(PyTorchProcessFactory.class),
-            10
+            10,
+            inferenceAuditor
         );
 
         PriorityProcessWorkerExecutorService priorityExecutorService = new PriorityProcessWorkerExecutorService(
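
With the auditor now mocked in the test setup, a follow-up assertion could check that a simulated crash actually reaches the auditor. For example (hypothetical — no such assertion appears in this diff, and the deployment id is made up):

```java
// Hypothetical assertion, not part of this commit. Assumes the usual
// Mockito static imports alongside those already used in this test class:
//   import static org.mockito.ArgumentMatchers.anyString;
//   import static org.mockito.ArgumentMatchers.eq;
//   import static org.mockito.Mockito.verify;
// After driving a crash for deployment "test-deployment" through the
// manager, confirm the warning notification was emitted:
verify(inferenceAuditor).warning(eq("test-deployment"), anyString());
```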