
[ML] Free resources correctly when model loading is cancelled (#92204)

When a model deployment is stopped during the loading process,
or the loading fails, all resources must be freed properly. In
particular, this fixes a bug where a worker thread could be started but
might never be stopped.
David Kyle 2 years ago
parent
commit
3bc652ba45
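
The fix follows a simple guard pattern: both the loading path and the stop path observe a shared stopped flag, and whichever side runs second is responsible for tearing down whatever the other side already created. Below is a minimal sketch of that pattern, using hypothetical class and method names rather than the actual DeploymentManager API:

class LoadGuardSketch {                           // hypothetical name, for illustration only
    private volatile boolean stopped;
    private Thread worker;                        // stands in for the priority process worker

    synchronized void startAndLoad(Runnable loadModel) {
        if (stopped) {
            // The deployment was stopped before loading began: start nothing.
            return;
        }
        worker = new Thread(loadModel, "model-loader");
        worker.start();
    }

    synchronized void stop() {
        stopped = true;
        // If loading already started a worker, make sure it is stopped as well.
        if (worker != null) {
            worker.interrupt();
        }
    }
}

In the actual change, ProcessContext.stopProcess() plays the role of stop(): it sets isStopped, cancels the state streamer, shuts down the priority process worker, and kills the native process if it was already started.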

+ 5 - 0
docs/changelog/92204.yaml

@@ -0,0 +1,5 @@
+pr: 92204
+summary: Free resources correctly when model loading is cancelled
+area: Machine Learning
+type: bug
+issues: []

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java

@@ -57,7 +57,7 @@ public abstract class PyTorchModelRestTestCase extends ESRestTestCase {
         Request loggingSettings = new Request("PUT", "_cluster/settings");
         loggingSettings.setJsonEntity("""
             {"persistent" : {
-                    "logger.org.elasticsearch.xpack.ml.inference.assignment" : "DEBUG",
+                    "logger.org.elasticsearch.xpack.ml.inference.assignment" : "TRACE",
                     "logger.org.elasticsearch.xpack.ml.inference.deployment" : "DEBUG",
                     "logger.org.elasticsearch.xpack.ml.inference.pytorch" : "DEBUG",
                     "logger.org.elasticsearch.xpack.ml.process.logging" : "DEBUG"

+ 88 - 61
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -17,6 +17,7 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
@@ -91,15 +92,11 @@ public class DeploymentManager {
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
         this.threadPool = Objects.requireNonNull(threadPool);
-        this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
+        this.executorServiceForDeployment = threadPool.executor(UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
         this.maxProcesses = maxProcesses;
     }
 
-    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> listener) {
-        doStartDeployment(task, listener);
-    }
-
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
         return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
             var stats = processContext.getResultProcessor().getResultStats();
@@ -127,19 +124,23 @@ public class DeploymentManager {
 
     // function exposed for testing
     ProcessContext addProcessContext(Long id, ProcessContext processContext) {
-        if (processContextByAllocation.size() >= maxProcesses) {
-            throw ExceptionsHelper.serverError(
-                "[{}] Could not start inference process as the node reached the max number [{}] of processes",
-                processContext.task.getModelId(),
-                maxProcesses
-            );
-        }
         return processContextByAllocation.putIfAbsent(id, processContext);
     }
 
-    private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
+    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
         logger.info("[{}] Starting model deployment", task.getModelId());
 
+        if (processContextByAllocation.size() >= maxProcesses) {
+            finalListener.onFailure(
+                ExceptionsHelper.serverError(
+                    "[{}] Could not start inference process as the node reached the max number [{}] of processes",
+                    task.getModelId(),
+                    maxProcesses
+                )
+            );
+            return;
+        }
+
         ProcessContext processContext = new ProcessContext(task);
         if (addProcessContext(task.getId(), processContext) != null) {
             finalListener.onFailure(
@@ -148,15 +149,18 @@ public class DeploymentManager {
             return;
         }
 
-        ActionListener<TrainedModelDeploymentTask> listener = ActionListener.wrap(finalListener::onResponse, failure -> {
-            processContextByAllocation.remove(task.getId());
+        ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(finalListener::onResponse, failure -> {
+            ProcessContext failedContext = processContextByAllocation.remove(task.getId());
+            if (failedContext != null) {
+                failedContext.stopProcess();
+            }
             finalListener.onFailure(failure);
         });
 
         ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(success -> {
             executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
-            listener.onResponse(task);
-        }, listener::onFailure);
+            finalListener.onResponse(task);
+        }, failedDeploymentListener::onFailure);
 
         ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap(getModelResponse -> {
             assert getModelResponse.getResources().results().size() == 1;
@@ -169,7 +173,7 @@ public class DeploymentManager {
                 SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                 executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
                     if (searchVocabResponse.getHits().getHits().length == 0) {
-                        listener.onFailure(
+                        failedDeploymentListener.onFailure(
                             new ResourceNotFoundException(
                                 Messages.getMessage(
                                     Messages.VOCABULARY_NOT_FOUND,
@@ -188,12 +192,10 @@ public class DeploymentManager {
                     // here, we are being called back on the searching thread, which MAY be a network thread
                     // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
                     // executor.
-                    executorServiceForDeployment.execute(
-                        () -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)
-                    );
-                }, listener::onFailure));
+                    executorServiceForDeployment.execute(() -> processContext.startAndLoad(modelConfig.getLocation(), modelLoadedListener));
+                }, failedDeploymentListener::onFailure));
             } else {
-                listener.onFailure(
+                failedDeploymentListener.onFailure(
                     new IllegalArgumentException(
                         format(
                             "[%s] must be a pytorch model; found inference config of kind [%s]",
@@ -203,7 +205,7 @@ public class DeploymentManager {
                     )
                 );
             }
-        }, listener::onFailure);
+        }, failedDeploymentListener::onFailure);
 
         executeAsyncWithOrigin(
             client,
@@ -239,25 +241,8 @@ public class DeploymentManager {
         }
     }
 
-    private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
-        try {
-            processContext.startProcess();
-            processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
-                assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
-                    : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
-                processContext.startPriorityProcessWorker();
-                loadedListener.onResponse(success);
-            }, loadedListener::onFailure));
-        } catch (Exception e) {
-            loadedListener.onFailure(e);
-        }
-    }
-
     public void stopDeployment(TrainedModelDeploymentTask task) {
-        ProcessContext processContext;
-        synchronized (processContextByAllocation) {
-            processContext = processContextByAllocation.get(task.getId());
-        }
+        ProcessContext processContext = processContextByAllocation.remove(task.getId());
         if (processContext != null) {
             logger.info("[{}] Stopping deployment, reason [{}]", task.getModelId(), task.stoppedReason().orElse("unknown"));
             processContext.stopProcess();
@@ -275,6 +260,13 @@ public class DeploymentManager {
         Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
+        assert ((EsThreadPoolExecutor) executorServiceForProcess).getPoolSize() % 3 == 0
+            : "Thread pool size ["
+                + ((EsThreadPoolExecutor) executorServiceForProcess).getPoolSize()
+                + "] should be a multiple of 3. Num contexts = ["
+                + processContextByAllocation.size()
+                + "]";
+
         var processContext = getProcessContext(task, listener::onFailure);
         if (processContext == null) {
             // error reporting handled in the call to getProcessContext
@@ -397,6 +389,7 @@ public class DeploymentManager {
         private volatile Integer numAllocations;
         private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
         private final AtomicInteger timeoutCount = new AtomicInteger();
+        private volatile boolean isStopped;
 
         ProcessContext(TrainedModelDeploymentTask task) {
             this.task = Objects.requireNonNull(task);
@@ -421,9 +414,36 @@ public class DeploymentManager {
             return resultProcessor;
         }
 
-        synchronized void startProcess() {
-            process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
+        synchronized void startAndLoad(TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
+            assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
+                : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
+
+            if (isStopped) {
+                logger.debug("[{}] model stopped before it is started", task.getModelId());
+                loadedListener.onFailure(new IllegalArgumentException("model stopped before it is started"));
+                return;
+            }
+
+            logger.debug("[{}] start and load", task.getModelId());
+            process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, this::onProcessCrash));
             startTime = Instant.now();
+            logger.debug("[{}] process started", task.getModelId());
+            try {
+                loadModel(modelLocation, ActionListener.wrap(success -> {
+                    if (isStopped) {
+                        logger.debug("[{}] model loaded but process is stopped", task.getModelId());
+                        killProcessIfPresent();
+                        loadedListener.onFailure(new IllegalStateException("model loaded but process is stopped"));
+                        return;
+                    }
+
+                    logger.debug("[{}] model loaded, starting priority process worker thread", task.getModelId());
+                    startPriorityProcessWorker();
+                    loadedListener.onResponse(success);
+                }, loadedListener::onFailure));
+            } catch (Exception e) {
+                loadedListener.onFailure(e);
+            }
         }
 
         void startPriorityProcessWorker() {
@@ -431,38 +451,45 @@ public class DeploymentManager {
         }
 
         synchronized void stopProcess() {
+            isStopped = true;
             resultProcessor.stop();
+            stateStreamer.cancel();
             priorityProcessWorker.shutdown();
+            killProcessIfPresent();
+            if (nlpTaskProcessor.get() != null) {
+                nlpTaskProcessor.get().close();
+            }
+        }
+
+        private void killProcessIfPresent() {
             try {
                 if (process.get() == null) {
                     return;
                 }
-                stateStreamer.cancel();
                 process.get().kill(true);
-                processContextByAllocation.remove(task.getId());
             } catch (IOException e) {
                 logger.error(() -> "[" + task.getModelId() + "] Failed to kill process", e);
-            } finally {
-                if (nlpTaskProcessor.get() != null) {
-                    nlpTaskProcessor.get().close();
-                }
             }
         }
 
-        private Consumer<String> onProcessCrash() {
-            return reason -> {
-                logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
-                resultProcessor.stop();
-                priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
-                processContextByAllocation.remove(task.getId());
-                if (nlpTaskProcessor.get() != null) {
-                    nlpTaskProcessor.get().close();
-                }
-                task.setFailed("inference process crashed due to reason [" + reason + "]");
-            };
+        private void onProcessCrash(String reason) {
+            logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
+            processContextByAllocation.remove(task.getId());
+            isStopped = true;
+            resultProcessor.stop();
+            stateStreamer.cancel();
+            priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
+            if (nlpTaskProcessor.get() != null) {
+                nlpTaskProcessor.get().close();
+            }
+            task.setFailed("inference process crashed due to reason [" + reason + "]");
         }
 
         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
+            if (isStopped) {
+                listener.onFailure(new IllegalArgumentException("Process has stopped, model loading canceled"));
+                return;
+            }
             if (modelLocation instanceof IndexLocation indexLocation) {
                 // Loading the model happens on the inference thread pool but when we get the callback
                 // we need to return to the utility thread pool to avoid leaking the thread we used.