@@ -27,6 +27,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -58,6 +59,7 @@ import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;

@@ -74,6 +76,7 @@ public class DeploymentManager {
     private final PyTorchProcessFactory pyTorchProcessFactory;
     private final ExecutorService executorServiceForDeployment;
     private final ExecutorService executorServiceForProcess;
+    private final ThreadPool threadPool;
     private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();

     public DeploymentManager(Client client, NamedXContentRegistry xContentRegistry,
@@ -81,6 +84,7 @@ public class DeploymentManager {
         this.client = Objects.requireNonNull(client);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
+        this.threadPool = Objects.requireNonNull(threadPool);
         this.executorServiceForDeployment = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
         this.executorServiceForProcess = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
     }
@@ -92,8 +96,8 @@ public class DeploymentManager {
     public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
         return Optional.ofNullable(processContextByAllocation.get(task.getId()))
             .map(processContext ->
-                new ModelStats(processContext.resultProcessor.getTimingStats(),
-                    processContext.resultProcessor.getLastUsed())
+                new ModelStats(processContext.getResultProcessor().getTimingStats(),
+                    processContext.getResultProcessor().getLastUsed())
             );
     }

@@ -117,7 +121,7 @@ public class DeploymentManager {

         ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(
             success -> {
-                executorServiceForProcess.execute(() -> processContext.resultProcessor.process(processContext.process.get()));
+                executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
                 listener.onResponse(task);
             },
             listener::onFailure
@@ -226,83 +230,145 @@ public class DeploymentManager {
             return;
         }

-        final String requestId = String.valueOf(requestIdCounter.getAndIncrement());
+        final long requestId = requestIdCounter.getAndIncrement();
+        executorServiceForProcess.execute(new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener));
+    }

-        executorServiceForProcess.execute(new AbstractRunnable() {
-            @Override
-            public void onFailure(Exception e) {
-                listener.onFailure(e);
-            }
+    static class InferenceAction extends AbstractRunnable {
+        private final long requestId;
+        private final TimeValue timeout;
+        private final Scheduler.Cancellable timeoutHandler;
+        private final ProcessContext processContext;
+        private final InferenceConfig config;
+        private final Map<String, Object> doc;
+        private final ActionListener<InferenceResults> listener;
+        private final AtomicBoolean notified = new AtomicBoolean();
+
+        InferenceAction(
+            long requestId,
+            TimeValue timeout,
+            ProcessContext processContext,
+            InferenceConfig config,
+            Map<String, Object> doc,
+            ThreadPool threadPool,
+            ActionListener<InferenceResults> listener
+        ) {
+            this.requestId = requestId;
+            this.timeout = timeout;
+            this.processContext = processContext;
+            this.config = config;
+            this.doc = doc;
+            this.listener = listener;
+            this.timeoutHandler = threadPool.schedule(
+                this::onTimeout,
+                ExceptionsHelper.requireNonNull(timeout, "timeout"),
+                MachineLearning.UTILITY_THREAD_POOL_NAME
+            );
+        }

-            @Override
-            protected void doRun() {
-                try {
-                    // The request builder expect a list of inputs which are then batched.
-                    // TODO batching was implemented for expected use-cases such as zero-shot
-                    // classification but is not used here.
-                    List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
-                    NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
-                    processor.validateInputs(text);
-                    assert config instanceof NlpConfig;
-                    NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestId);
-                    logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
-                    PyTorchResultProcessor.PendingResult pendingResult = processContext.resultProcessor.registerRequest(requestId);
-                    processContext.process.get().writeInferenceRequest(request.processInput);
-                    waitForResult(
-                        processContext,
-                        pendingResult,
-                        request.tokenization,
-                        requestId,
-                        timeout,
-                        processor.getResultProcessor((NlpConfig) config),
-                        listener
-                    );
-                } catch (IOException e) {
-                    logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
-                    onFailure(ExceptionsHelper.serverError("error writing to process", e));
-                } catch (Exception e) {
-                    onFailure(e);
-                } finally {
-                    processContext.resultProcessor.requestAccepted(requestId);
-                }
+        void onTimeout() {
+            if (notified.compareAndSet(false, true)) {
+                processContext.getResultProcessor().requestIgnored(String.valueOf(requestId));
+                listener.onFailure(
+                    new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
+                );
+                return;
             }
-        });
-    }
+            logger.debug("request [{}] received timeout after [{}] but listener already alerted", requestId, timeout);
+        }

-    private void waitForResult(ProcessContext processContext,
-                               PyTorchResultProcessor.PendingResult pendingResult,
-                               TokenizationResult tokenization,
-                               String requestId,
-                               TimeValue timeout,
-                               NlpTask.ResultProcessor inferenceResultsProcessor,
-                               ActionListener<InferenceResults> listener) {
-        try {
-            PyTorchResult pyTorchResult = processContext.resultProcessor.waitForResult(
-                processContext.process.get(),
-                requestId,
-                pendingResult,
-                timeout
-            );
-            if (pyTorchResult == null) {
-                listener.onFailure(new ElasticsearchStatusException("timeout [{}] waiting for inference result",
-                    RestStatus.TOO_MANY_REQUESTS, timeout));
+        void onSuccess(InferenceResults inferenceResults) {
+            timeoutHandler.cancel();
+            if (notified.compareAndSet(false, true)) {
+                listener.onResponse(inferenceResults);
                 return;
             }
+            logger.debug("request [{}] received inference response but listener already notified", requestId);
+        }

-            if (pyTorchResult.isError()) {
-                listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(),
-                    RestStatus.INTERNAL_SERVER_ERROR));
+        @Override
+        public void onFailure(Exception e) {
+            timeoutHandler.cancel();
+            if (notified.compareAndSet(false, true)) {
+                listener.onFailure(e);
                 return;
             }
+            logger.debug(
+                () -> new ParameterizedMessage("request [{}] received failure but listener already notified", requestId),
+                e
+            );
+        }

-            logger.debug(() -> new ParameterizedMessage(
-                "[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId));
-            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
-            logger.debug(() -> new ParameterizedMessage(
-                "[{}] processed result for request [{}]", processContext.task.getModelId(), requestId));
-            listener.onResponse(results);
-        } catch (InterruptedException e) {
-            listener.onFailure(e);
+        @Override
+        protected void doRun() throws Exception {
+            final String requestIdStr = String.valueOf(requestId);
+            try {
+                // The request builder expect a list of inputs which are then batched.
+                // TODO batching was implemented for expected use-cases such as zero-shot
+                // classification but is not used here.
+                List<String> text = Collections.singletonList(NlpTask.extractInput(processContext.modelInput.get(), doc));
+                NlpTask.Processor processor = processContext.nlpTaskProcessor.get();
+                processor.validateInputs(text);
+                assert config instanceof NlpConfig;
+                NlpTask.Request request = processor.getRequestBuilder((NlpConfig) config).buildRequest(text, requestIdStr);
+                logger.trace(() -> "Inference Request "+ request.processInput.utf8ToString());
+                PyTorchResultProcessor.PendingResult pendingResult = processContext.getResultProcessor().registerRequest(requestIdStr);
+                processContext.process.get().writeInferenceRequest(request.processInput);
+                waitForResult(
+                    processContext,
+                    pendingResult,
+                    request.tokenization,
+                    requestIdStr,
+                    timeout,
+                    processor.getResultProcessor((NlpConfig) config),
+                    ActionListener.wrap(this::onSuccess,this::onFailure)
+                );
+            } catch (IOException e) {
+                logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
+                onFailure(ExceptionsHelper.serverError("error writing to process", e));
+            } catch (Exception e) {
+                onFailure(e);
+            } finally {
+                processContext.getResultProcessor().requestIgnored(String.valueOf(requestId));
+            }
+        }
+
+        private void waitForResult(ProcessContext processContext,
+                                   PyTorchResultProcessor.PendingResult pendingResult,
+                                   TokenizationResult tokenization,
+                                   String requestId,
+                                   TimeValue timeout,
+                                   NlpTask.ResultProcessor inferenceResultsProcessor,
+                                   ActionListener<InferenceResults> listener) {
+            try {
+                PyTorchResult pyTorchResult = processContext.getResultProcessor().waitForResult(
+                    processContext.process.get(),
+                    requestId,
+                    pendingResult,
+                    timeout
+                );
+                if (pyTorchResult == null) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.TOO_MANY_REQUESTS, timeout)
+                    );
+                    return;
+                }
+
+                if (pyTorchResult.isError()) {
+                    listener.onFailure(new ElasticsearchStatusException(pyTorchResult.getError(),
+                        RestStatus.INTERNAL_SERVER_ERROR));
+                    return;
+                }
+
+                logger.debug(() -> new ParameterizedMessage(
+                    "[{}] retrieved result for request [{}]", processContext.task.getModelId(), requestId));
+                InferenceResults results = inferenceResultsProcessor.processResult(tokenization, pyTorchResult);
+                logger.debug(() -> new ParameterizedMessage(
+                    "[{}] processed result for request [{}]", processContext.task.getModelId(), requestId));
+                listener.onResponse(results);
+            } catch (InterruptedException e) {
+                listener.onFailure(e);
+            }
         }
     }
@@ -321,6 +387,10 @@ public class DeploymentManager {
             this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
         }

+        PyTorchResultProcessor getResultProcessor() {
+            return resultProcessor;
+        }
+
        synchronized void startProcess() {
            process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
        }
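
The core of the change is the notify-once guard in InferenceAction: an AtomicBoolean plus a timeout scheduled in the constructor guarantee the listener is called exactly once, whichever of timeout, success, or failure arrives first, while late arrivals are only logged. Below is a minimal standalone sketch of that pattern, assuming plain java.util.concurrent types (a ScheduledExecutorService in place of ThreadPool.schedule, Consumer callbacks in place of ActionListener) and a made-up class name; it is an illustration, not the Elasticsearch code above.

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

// Standalone sketch of the notify-once pattern: whichever of timeout, success,
// or failure fires first reaches the listener; later calls are logged and dropped.
class NotifyOnceAction {
    private final AtomicBoolean notified = new AtomicBoolean();
    private final Consumer<String> onResponse;
    private final Consumer<Exception> onFailure;
    private final ScheduledFuture<?> timeoutHandler;

    NotifyOnceAction(long timeoutMillis,
                     ScheduledExecutorService scheduler,   // stand-in for ThreadPool.schedule
                     Consumer<String> onResponse,
                     Consumer<Exception> onFailure) {
        this.onResponse = onResponse;
        this.onFailure = onFailure;
        // Schedule the timeout up front, as the InferenceAction constructor does.
        this.timeoutHandler = scheduler.schedule(this::onTimeout, timeoutMillis, TimeUnit.MILLISECONDS);
    }

    void onTimeout() {
        if (notified.compareAndSet(false, true)) {
            onFailure.accept(new RuntimeException("timeout waiting for inference result"));
            return;
        }
        System.out.println("timeout fired but listener already notified");
    }

    void onSuccess(String result) {
        timeoutHandler.cancel(false);   // a pending timeout must not fire after a real result
        if (notified.compareAndSet(false, true)) {
            onResponse.accept(result);
            return;
        }
        System.out.println("result arrived but listener already notified");
    }

    void onError(Exception e) {
        timeoutHandler.cancel(false);
        if (notified.compareAndSet(false, true)) {
            onFailure.accept(e);
            return;
        }
        System.out.println("failure arrived but listener already notified");
    }

    public static void main(String[] args) throws InterruptedException {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        NotifyOnceAction action = new NotifyOnceAction(
            100, scheduler,
            r -> System.out.println("response: " + r),
            e -> System.out.println("failure: " + e.getMessage()));
        Thread.sleep(200);               // let the timeout win
        action.onSuccess("late result"); // dropped: the listener was already notified
        scheduler.shutdown();
    }
}

Running it prints the timeout failure first and then the dropped-result message, mirroring the debug logging the diff adds for the already-notified cases.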