@@ -17,6 +17,7 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
@@ -91,15 +92,11 @@ public class DeploymentManager {
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
         this.threadPool = Objects.requireNonNull(threadPool);
-        this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
+        this.executorServiceForDeployment = threadPool.executor(UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
         this.maxProcesses = maxProcesses;
     }

-    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> listener) {
-        doStartDeployment(task, listener);
-    }
-
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
         return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
             var stats = processContext.getResultProcessor().getResultStats();
@@ -127,19 +124,23 @@ public class DeploymentManager {

     // function exposed for testing
     ProcessContext addProcessContext(Long id, ProcessContext processContext) {
-        if (processContextByAllocation.size() >= maxProcesses) {
-            throw ExceptionsHelper.serverError(
-                "[{}] Could not start inference process as the node reached the max number [{}] of processes",
-                processContext.task.getModelId(),
-                maxProcesses
-            );
-        }
         return processContextByAllocation.putIfAbsent(id, processContext);
     }

-    private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
+    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
         logger.info("[{}] Starting model deployment", task.getModelId());

+        if (processContextByAllocation.size() >= maxProcesses) {
+            finalListener.onFailure(
+                ExceptionsHelper.serverError(
+                    "[{}] Could not start inference process as the node reached the max number [{}] of processes",
+                    task.getModelId(),
+                    maxProcesses
+                )
+            );
+            return;
+        }
+
         ProcessContext processContext = new ProcessContext(task);
         if (addProcessContext(task.getId(), processContext) != null) {
             finalListener.onFailure(
@@ -148,15 +149,18 @@ public class DeploymentManager {
             return;
         }

-        ActionListener<TrainedModelDeploymentTask> listener = ActionListener.wrap(finalListener::onResponse, failure -> {
-            processContextByAllocation.remove(task.getId());
+        ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(finalListener::onResponse, failure -> {
+            ProcessContext failedContext = processContextByAllocation.remove(task.getId());
+            if (failedContext != null) {
+                failedContext.stopProcess();
+            }
             finalListener.onFailure(failure);
         });

         ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(success -> {
             executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
-            listener.onResponse(task);
-        }, listener::onFailure);
+            finalListener.onResponse(task);
+        }, failedDeploymentListener::onFailure);

         ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap(getModelResponse -> {
             assert getModelResponse.getResources().results().size() == 1;
@@ -169,7 +173,7 @@ public class DeploymentManager {
                 SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                 executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
                     if (searchVocabResponse.getHits().getHits().length == 0) {
-                        listener.onFailure(
+                        failedDeploymentListener.onFailure(
                             new ResourceNotFoundException(
                                 Messages.getMessage(
                                     Messages.VOCABULARY_NOT_FOUND,
@@ -188,12 +192,10 @@ public class DeploymentManager {
                     // here, we are being called back on the searching thread, which MAY be a network thread
                     // `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
                     // executor.
-                    executorServiceForDeployment.execute(
-                        () -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)
-                    );
-                }, listener::onFailure));
+                    executorServiceForDeployment.execute(() -> processContext.startAndLoad(modelConfig.getLocation(), modelLoadedListener));
+                }, failedDeploymentListener::onFailure));
             } else {
-                listener.onFailure(
+                failedDeploymentListener.onFailure(
                     new IllegalArgumentException(
                         format(
                             "[%s] must be a pytorch model; found inference config of kind [%s]",
@@ -203,7 +205,7 @@ public class DeploymentManager {
                     )
                 );
             }
-        }, listener::onFailure);
+        }, failedDeploymentListener::onFailure);

         executeAsyncWithOrigin(
             client,
@@ -239,25 +241,8 @@ public class DeploymentManager {
         }
     }

-    private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
-        try {
-            processContext.startProcess();
-            processContext.loadModel(modelLocation, ActionListener.wrap(success -> {
-                assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
-                    : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
-                processContext.startPriorityProcessWorker();
-                loadedListener.onResponse(success);
-            }, loadedListener::onFailure));
-        } catch (Exception e) {
-            loadedListener.onFailure(e);
-        }
-    }
-
     public void stopDeployment(TrainedModelDeploymentTask task) {
-        ProcessContext processContext;
-        synchronized (processContextByAllocation) {
-            processContext = processContextByAllocation.get(task.getId());
-        }
+        ProcessContext processContext = processContextByAllocation.remove(task.getId());
         if (processContext != null) {
             logger.info("[{}] Stopping deployment, reason [{}]", task.getModelId(), task.stoppedReason().orElse("unknown"));
             processContext.stopProcess();
@@ -275,6 +260,13 @@ public class DeploymentManager {
         Task parentActionTask,
         ActionListener<InferenceResults> listener
     ) {
+        assert ((EsThreadPoolExecutor) executorServiceForProcess).getPoolSize() % 3 == 0
+            : "Thread pool size ["
+                + ((EsThreadPoolExecutor) executorServiceForProcess).getPoolSize()
+                + "] should be a multiple of 3. Num contexts = ["
+                + processContextByAllocation.size()
+                + "]";
+
         var processContext = getProcessContext(task, listener::onFailure);
         if (processContext == null) {
             // error reporting handled in the call to getProcessContext
@@ -397,6 +389,7 @@ public class DeploymentManager {
         private volatile Integer numAllocations;
         private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
         private final AtomicInteger timeoutCount = new AtomicInteger();
+        private volatile boolean isStopped;

         ProcessContext(TrainedModelDeploymentTask task) {
             this.task = Objects.requireNonNull(task);
@@ -421,9 +414,36 @@ public class DeploymentManager {
             return resultProcessor;
         }

-        synchronized void startProcess() {
-            process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
+        synchronized void startAndLoad(TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
+            assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
+                : format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());
+
+            if (isStopped) {
+                logger.debug("[{}] model stopped before it is started", task.getModelId());
+                loadedListener.onFailure(new IllegalArgumentException("model stopped before it is started"));
+                return;
+            }
+
+            logger.debug("[{}] start and load", task.getModelId());
+            process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, this::onProcessCrash));
             startTime = Instant.now();
+            logger.debug("[{}] process started", task.getModelId());
+            try {
+                loadModel(modelLocation, ActionListener.wrap(success -> {
+                    if (isStopped) {
+                        logger.debug("[{}] model loaded but process is stopped", task.getModelId());
+                        killProcessIfPresent();
+                        loadedListener.onFailure(new IllegalStateException("model loaded but process is stopped"));
+                        return;
+                    }
+
+                    logger.debug("[{}] model loaded, starting priority process worker thread", task.getModelId());
+                    startPriorityProcessWorker();
+                    loadedListener.onResponse(success);
+                }, loadedListener::onFailure));
+            } catch (Exception e) {
+                loadedListener.onFailure(e);
+            }
         }

         void startPriorityProcessWorker() {
@@ -431,38 +451,45 @@ public class DeploymentManager {
         }

         synchronized void stopProcess() {
+            isStopped = true;
             resultProcessor.stop();
+            stateStreamer.cancel();
             priorityProcessWorker.shutdown();
+            killProcessIfPresent();
+            if (nlpTaskProcessor.get() != null) {
+                nlpTaskProcessor.get().close();
+            }
+        }
+
+        private void killProcessIfPresent() {
             try {
                 if (process.get() == null) {
                     return;
                 }
-                stateStreamer.cancel();
                 process.get().kill(true);
-                processContextByAllocation.remove(task.getId());
             } catch (IOException e) {
                 logger.error(() -> "[" + task.getModelId() + "] Failed to kill process", e);
-            } finally {
-                if (nlpTaskProcessor.get() != null) {
-                    nlpTaskProcessor.get().close();
-                }
             }
         }

-        private Consumer<String> onProcessCrash() {
-            return reason -> {
-                logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
-                resultProcessor.stop();
-                priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
-                processContextByAllocation.remove(task.getId());
-                if (nlpTaskProcessor.get() != null) {
-                    nlpTaskProcessor.get().close();
-                }
-                task.setFailed("inference process crashed due to reason [" + reason + "]");
-            };
+        private void onProcessCrash(String reason) {
+            logger.error("[{}] inference process crashed due to reason [{}]", task.getModelId(), reason);
+            processContextByAllocation.remove(task.getId());
+            isStopped = true;
+            resultProcessor.stop();
+            stateStreamer.cancel();
+            priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));
+            if (nlpTaskProcessor.get() != null) {
+                nlpTaskProcessor.get().close();
+            }
+            task.setFailed("inference process crashed due to reason [" + reason + "]");
         }

         void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
+            if (isStopped) {
+                listener.onFailure(new IllegalArgumentException("Process has stopped, model loading canceled"));
+                return;
+            }
             if (modelLocation instanceof IndexLocation indexLocation) {
                 // Loading the model happens on the inference thread pool but when we get the callback
                 // we need to return to the utility thread pool to avoid leaking the thread we used.
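
Reviewer note: the heart of this change is a stop/start race guard inside ProcessContext. A volatile isStopped flag is flipped by the synchronized stopProcess(), checked once before the native process is created and again when the asynchronous model load completes, and all teardown is funneled through a single killProcessIfPresent(). The standalone sketch below condenses that pattern into a minimal, self-contained Java program; every name in it is hypothetical (ProcessContextSketch, simulateSlowModelLoad, the String field standing in for the native process handle), not the Elasticsearch API.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

// Hypothetical, simplified model of the race guard added in this diff.
final class ProcessContextSketch {
    private volatile boolean isStopped;   // written by stopProcess(), read by the load callback
    private volatile String process;      // stands in for the native inference process handle

    // Refuse to start once stopped; re-check the flag when the async load completes,
    // because stopProcess() may have run while the model was still loading.
    synchronized void startAndLoad(ExecutorService loader, Runnable onLoaded, Consumer<Exception> onFailure) {
        if (isStopped) {
            onFailure.accept(new IllegalStateException("model stopped before it is started"));
            return;
        }
        process = "pytorch-process";
        loader.execute(() -> {
            simulateSlowModelLoad();
            if (isStopped) {              // a stop raced with the load: clean up, do not leak the process
                killProcessIfPresent();
                onFailure.accept(new IllegalStateException("model loaded but process is stopped"));
                return;
            }
            onLoaded.run();
        });
    }

    synchronized void stopProcess() {
        isStopped = true;                 // set the flag first so an in-flight load observes it
        killProcessIfPresent();
    }

    // Single cleanup path, shared by stopProcess() and the load callback.
    private synchronized void killProcessIfPresent() {
        if (process != null) {
            System.out.println("killing " + process);
            process = null;
        }
    }

    private static void simulateSlowModelLoad() {
        try {
            Thread.sleep(100);            // stands in for streaming the model from the index
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService loader = Executors.newSingleThreadExecutor();
        ProcessContextSketch ctx = new ProcessContextSketch();
        ctx.startAndLoad(loader, () -> System.out.println("loaded"), e -> System.out.println("rejected: " + e.getMessage()));
        ctx.stopProcess();                // arrives while the load is still in flight
        loader.shutdown();
        loader.awaitTermination(1, TimeUnit.SECONDS);
    }
}

The same single-cleanup-path idea drives the new failedDeploymentListener in startDeployment: any failure on the asynchronous start path removes the ProcessContext from the map and calls stopProcess() on it, so a half-started deployment is torn down the same way as a stopped one.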