|
@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
import org.elasticsearch.action.support.HandledTransportAction;
|
|
|
import org.elasticsearch.client.Client;
|
|
|
+import org.elasticsearch.client.ParentTaskAssigningClient;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.inject.Inject;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
@@ -20,6 +21,7 @@ import org.elasticsearch.license.LicenseUtils;
|
|
|
import org.elasticsearch.license.XPackLicenseState;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.tasks.Task;
|
|
|
+import org.elasticsearch.tasks.TaskId;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.XPackField;
|
|
@@ -77,7 +79,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
|
|
|
if (MachineLearningField.ML_API_FEATURE.check(licenseState)) {
|
|
|
responseBuilder.setLicensed(true);
|
|
|
- doInfer(request, responseBuilder, listener);
|
|
|
+ doInfer(task, request, responseBuilder, listener);
|
|
|
} else {
|
|
|
trainedModelProvider.getTrainedModel(
|
|
|
request.getModelId(),
|
|
@@ -88,7 +90,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
boolean allowed = trainedModelConfig.getLicenseLevel() == License.OperationMode.BASIC;
|
|
|
responseBuilder.setLicensed(allowed);
|
|
|
if (allowed || request.isPreviouslyLicensed()) {
|
|
|
- doInfer(request, responseBuilder, listener);
|
|
|
+ doInfer(task, request, responseBuilder, listener);
|
|
|
} else {
|
|
|
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
|
|
}
|
|
@@ -97,9 +99,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private void doInfer(Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
|
|
|
+ private void doInfer(Task task, Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
|
|
|
if (isAllocatedModel(request.getModelId())) {
|
|
|
- inferAgainstAllocatedModel(request, responseBuilder, listener);
|
|
|
+ inferAgainstAllocatedModel(task, request, responseBuilder, listener);
|
|
|
} else {
|
|
|
getModelAndInfer(request, responseBuilder, listener);
|
|
|
}
|
|
@@ -138,7 +140,12 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
|
|
|
}
|
|
|
|
|
|
- private void inferAgainstAllocatedModel(Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
|
|
|
+ private void inferAgainstAllocatedModel(
|
|
|
+ Task task,
|
|
|
+ Request request,
|
|
|
+ Response.Builder responseBuilder,
|
|
|
+ ActionListener<Response> listener
|
|
|
+ ) {
|
|
|
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<>(
|
|
|
client.threadPool().executor(ThreadPool.Names.SAME),
|
|
|
// run through all tasks
|
|
@@ -150,6 +157,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
.forEach(
|
|
|
stringObjectMap -> typedChainTaskExecutor.add(
|
|
|
chainedTask -> inferSingleDocAgainstAllocatedModel(
|
|
|
+ task,
|
|
|
request.getModelId(),
|
|
|
request.getUpdate(),
|
|
|
stringObjectMap,
|
|
@@ -169,21 +177,25 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
|
|
|
}
|
|
|
|
|
|
private void inferSingleDocAgainstAllocatedModel(
|
|
|
+ Task task,
|
|
|
String modelId,
|
|
|
InferenceConfigUpdate inferenceConfigUpdate,
|
|
|
Map<String, Object> doc,
|
|
|
ActionListener<InferenceResults> listener
|
|
|
) {
|
|
|
+ TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId());
|
|
|
+ InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
|
|
|
+ modelId,
|
|
|
+ inferenceConfigUpdate,
|
|
|
+ Collections.singletonList(doc),
|
|
|
+ TimeValue.MAX_VALUE
|
|
|
+ );
|
|
|
+ request.setParentTaskId(taskId);
|
|
|
executeAsyncWithOrigin(
|
|
|
- client,
|
|
|
+ new ParentTaskAssigningClient(client, taskId),
|
|
|
ML_ORIGIN,
|
|
|
InferTrainedModelDeploymentAction.INSTANCE,
|
|
|
- new InferTrainedModelDeploymentAction.Request(
|
|
|
- modelId,
|
|
|
- inferenceConfigUpdate,
|
|
|
- Collections.singletonList(doc),
|
|
|
- TimeValue.MAX_VALUE
|
|
|
- ),
|
|
|
+ request,
|
|
|
ActionListener.wrap(r -> listener.onResponse(r.getResults()), e -> {
|
|
|
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
|
|
|
if (unwrapped instanceof ElasticsearchStatusException) {
|