
[ML] Load pytorch models from a thread of the inference pool (#91882)

In #91661 we fixed a bug where we could leak a thread of the inference
thread pool after loading the model through the callbacks. This could
exhaust that pool, so a deployment could hang whilst waiting in vain
for a thread. The fix was to do the loading on a thread of the utility
thread pool.
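
To make that failure mode concrete, here is a minimal, self-contained sketch
of such a thread leak using plain JDK executors. All names here (inferencePool,
loadModel, ThreadLeakSketch) are hypothetical and only illustrate the shape of
the problem, not the actual Elasticsearch code: the completion callback keeps
running long-lived work on the pool thread that delivered it, so the bounded
pool is drained one deployment at a time.

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.function.Consumer;

    class ThreadLeakSketch {
        // A small bounded pool standing in for the inference thread pool.
        static final ExecutorService inferencePool = Executors.newFixedThreadPool(2);

        // Pretend this streams the model state on an inference-pool thread,
        // then invokes the callback on that same thread.
        static void loadModel(Consumer<Boolean> onLoaded) {
            inferencePool.submit(() -> onLoaded.accept(true));
        }

        public static void main(String[] args) {
            loadModel(success -> {
                // BUG: the callback never returns, so the inference-pool thread
                // that delivered it is never released. After a couple of such
                // deployments the pool is exhausted and new deployments hang.
                while (true) {
                    Thread.onSpinWait(); // stand-in for long-lived post-load work
                }
            });
        }
    }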

However, for debugging purposes it is nicer to do the loading on a
thread of the inference thread pool, as that makes it easier to see
that such a thread belongs to a PyTorch model that is being loaded.

This commit moves model loading back onto a thread of the inference
thread pool, but this time we take care to return on a thread of the
utility thread pool in the callback.
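
That hop is visible in the loadModel change below, where the listener calls are
re-submitted to executorServiceForDeployment. As a rough, self-contained
illustration of the same idea with plain JDK executors (all names here, such as
inferencePool and utilityPool, are hypothetical), the continuation leaves the
bounded pool as soon as loading completes:

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.TimeUnit;
    import java.util.function.Consumer;

    class ExecutorHopSketch {
        static final ExecutorService inferencePool = Executors.newFixedThreadPool(2);
        static final ExecutorService utilityPool = Executors.newCachedThreadPool();

        // Run the (possibly long) load on the inference pool, then bounce the
        // callback onto the utility pool so the inference-pool thread is freed
        // as soon as loading finishes.
        static void loadModel(Consumer<Boolean> listener) {
            inferencePool.submit(() -> {
                boolean success = true; // pretend the model state was streamed here
                utilityPool.submit(() -> listener.accept(success));
            });
        }

        public static void main(String[] args) throws InterruptedException {
            loadModel(success ->
                System.out.println("loaded=" + success + " on " + Thread.currentThread().getName()));
            inferencePool.shutdown();
            inferencePool.awaitTermination(5, TimeUnit.SECONDS);
            utilityPool.shutdown();
        }
    }

In the actual change, the same hop is done with ActionListener.wrap, resubmitting
onResponse/onFailure to executorServiceForDeployment, and an assert in the success
callback checks that the thread name contains UTILITY_THREAD_POOL_NAME.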
Dimitris Athanasiou · 2 years ago
commit b456653fb9

+ 21 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -64,6 +64,7 @@ import java.util.function.Consumer;
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
 
 public class DeploymentManager {
 
@@ -242,6 +243,8 @@ public class DeploymentManager {
         try {
             processContext.startProcess();
             processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
+                assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
+                    : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
                 processContext.startPriorityProcessWorker();
                 loadedListener.onResponse(success);
             }, loadedListener::onFailure));
@@ -401,9 +404,12 @@ public class DeploymentManager {
                 this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                 this.numAllocations = threadSettings.numAllocations();
             });
-            // We want to use the utility thread pool to load the model and not one of the process
-            // threads that are dedicated to processing done throughout the lifetime of the process.
-            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForDeployment, xContentRegistry);
+            // We want to use the inference thread pool to load the model as it is a possibly long operation
+            // and knowing it is an inference thread would enable better understanding during debugging.
+            // Even though we account for 3 threads per process in the thread pool, loading the model
+            // happens before we start input/output so it should be ok to use a thread from that pool for loading
+            // the model.
+            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
             this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                 threadPool.getThreadContext(),
                 "inference process",
@@ -458,7 +464,18 @@ public class DeploymentManager {
 
         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
             if (modelLocation instanceof IndexLocation indexLocation) {
-                process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
+                // Loading the model happens on the inference thread pool but when we get the callback
+                // we need to return to the utility thread pool to avoid leaking the thread we used.
+                process.get()
+                    .loadModel(
+                        task.getModelId(),
+                        indexLocation.getIndexName(),
+                        stateStreamer,
+                        ActionListener.wrap(
+                            r -> executorServiceForDeployment.submit(() -> listener.onResponse(r)),
+                            e -> executorServiceForDeployment.submit(() -> listener.onFailure(e))
+                        )
+                    );
             } else {
                 listener.onFailure(
                     new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")