@@ -53,6 +53,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactor
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
+import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

 import java.io.IOException;
 import java.time.Duration;
@@ -85,6 +86,8 @@ public class DeploymentManager {
     private final ExecutorService executorServiceForDeployment;
     private final ExecutorService executorServiceForProcess;
     private final ThreadPool threadPool;
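+    // Used to audit inference process restarts and terminal failures.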
+    private final InferenceAuditor inferenceAuditor;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
     private final int maxProcesses;

@@ -93,12 +96,14 @@
         NamedXContentRegistry xContentRegistry,
         ThreadPool threadPool,
         PyTorchProcessFactory pyTorchProcessFactory,
-        int maxProcesses
+        int maxProcesses,
+        InferenceAuditor inferenceAuditor
     ) {
         this.client = Objects.requireNonNull(client);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
         this.threadPool = Objects.requireNonNull(threadPool);
+        this.inferenceAuditor = Objects.requireNonNull(inferenceAuditor);
         this.executorServiceForDeployment = threadPool.executor(UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
         this.maxProcesses = maxProcesses;
@@ -523,7 +528,7 @@
                     task,
                     executorServiceForProcess,
                     () -> resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES),
-                    onProcessCrashHandleRestarts(startsCount)
+                    onProcessCrashHandleRestarts(startsCount, task.getDeploymentId())
                 )
             );
             startTime = Instant.now();
@@ -546,16 +551,18 @@
             }
         }

-        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount) {
+        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount, String deploymentId) {
             return (reason) -> {
                 if (isThisProcessOlderThan1Day()) {
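+                    // The process ran for more than a day before crashing, so earlier starts
+                    // no longer count toward the restart limit; treat this as the first failure.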
                     startsCount.set(1);
-                    logger.error(
-                        "[{}] inference process crashed due to reason [{}]. This process was started more than 24 hours ago; "
-                            + "the starts count is reset to 1.",
-                        task.getDeploymentId(),
-                        reason
-                    );
+                    String logMessage = "["
+                        + task.getDeploymentId()
+                        + "] inference process crashed due to reason ["
+                        + reason
+                        + "]. This process was started more than 24 hours ago; "
+                        + "the starts count is reset to 1.";
+                    logger.error(logMessage);
                 } else {
                     logger.error("[{}] inference process crashed due to reason [{}]", task.getDeploymentId(), reason);
                 }
@@ -566,20 +574,32 @@
                 stateStreamer.cancel();

                 if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
-                    logger.info("[{}] restarting inference process after [{}] starts", task.getDeploymentId(), startsCount.get());
+                    String logAndAuditMessage = "Inference process ["
+                        + task.getDeploymentId()
+                        + "] failed due to ["
+                        + reason
+                        + "]. This is the ["
+                        + startsCount.get()
+                        + "] failure in 24 hours, and the process will be restarted.";
+                    logger.info(logAndAuditMessage);
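+                    // Fork the audit write to the utility pool so an index request cannot block the crash handler.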
+                    threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
+                        .execute(() -> inferenceAuditor.warning(deploymentId, logAndAuditMessage));
                     priorityProcessWorker.shutdownNow(); // TODO what to do with these tasks?
                     ActionListener<TrainedModelDeploymentTask> errorListener = ActionListener.wrap((trainedModelDeploymentTask -> {
                         logger.debug("Completed restart of inference process, the [{}] start", startsCount);
                     }),
                         (e) -> finishClosingProcess(
                             startsCount,
-                            "Failed to restart inference process because of error [" + e.getMessage() + "]"
+                            "Failed to restart inference process because of error [" + e.getMessage() + "]",
+                            deploymentId
                         )
                     );

                     startDeployment(task, startsCount.incrementAndGet(), errorListener);
                 } else {
-                    finishClosingProcess(startsCount, reason);
+                    finishClosingProcess(startsCount, reason, deploymentId);
                 }
             };
         }
@@ -588,8 +607,16 @@
             return startTime.isBefore(Instant.now().minus(Duration.ofDays(1)));
         }

-        private void finishClosingProcess(AtomicInteger startsCount, String reason) {
-            logger.warn("[{}] inference process failed after [{}] starts, not restarting again", task.getDeploymentId(), startsCount.get());
+        private void finishClosingProcess(AtomicInteger startsCount, String reason, String deploymentId) {
+            String logAndAuditMessage = "["
+                + task.getDeploymentId()
+                + "] inference process failed after ["
+                + startsCount.get()
+                + "] starts in 24 hours, not restarting again.";
+            logger.warn(logAndAuditMessage);
+            threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
+                .execute(() -> inferenceAuditor.error(deploymentId, logAndAuditMessage));
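+            // No further restarts: fail any queued inference requests with the terminal reason.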
             priorityProcessWorker.shutdownNowWithError(new IllegalStateException(reason));
             if (nlpTaskProcessor.get() != null) {
                 nlpTaskProcessor.get().close();