@@ -64,6 +64,7 @@ import java.util.function.Consumer;
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
 
 public class DeploymentManager {
 
@@ -242,6 +243,8 @@ public class DeploymentManager {
         try {
             processContext.startProcess();
             processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
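+                // This callback runs on the utility thread pool: loadModel (below) forks the
+                // listener back to executorServiceForDeployment once the model has loaded.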
+                assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
+                    : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
                 processContext.startPriorityProcessWorker();
                 loadedListener.onResponse(success);
             }, loadedListener::onFailure));
@@ -401,9 +404,12 @@ public class DeploymentManager {
                 this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                 this.numAllocations = threadSettings.numAllocations();
             });
-            // We want to use the utility thread pool to load the model and not one of the process
-            // threads that are dedicated to processing done throughout the lifetime of the process.
-            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForDeployment, xContentRegistry);
+            // We want to use the inference thread pool to load the model, as loading is a
+            // potentially long operation, and knowing the thread is an inference thread makes
+            // debugging easier. Even though we account for 3 threads per process in that pool,
+            // loading the model happens before input/output starts, so it is ok to use a thread
+            // from that pool for loading.
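+            // executorServiceForProcess is backed by the inference thread pool.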
+            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
             this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                 threadPool.getThreadContext(),
                 "inference process",
@@ -458,7 +464,18 @@ public class DeploymentManager {
 
         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
             if (modelLocation instanceof IndexLocation indexLocation) {
-                process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
+                // Loading the model happens on the inference thread pool, but when we get the callback
+                // we need to return to the utility thread pool to avoid tying up the inference thread.
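+                // The assert in the load-model callback above relies on this hand-off back to
+                // executorServiceForDeployment (the utility thread pool).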
+                process.get()
+                    .loadModel(
+                        task.getModelId(),
+                        indexLocation.getIndexName(),
+                        stateStreamer,
+                        ActionListener.wrap(
+                            r -> executorServiceForDeployment.submit(() -> listener.onResponse(r)),
+                            e -> executorServiceForDeployment.submit(() -> listener.onFailure(e))
+                        )
+                    );
             } else {
                 listener.onFailure(
                     new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")