Prechádzať zdrojové kódy

[ML] Set parent task when calling infer action from internal infer (#80731)

This commit sets the parent task id to the trained model infer action
when it is called from the internal infer action (ingest use case).
Dimitris Athanasiou 3 rokov pred
rodič
commit
d4d4211e39

+ 0 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -98,7 +98,6 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
         int nodeIndex = Randomness.get().nextInt(randomRunningNode.length);
         request.setNodes(randomRunningNode[nodeIndex]);
         super.doExecute(task, request, listener);
-
     }
 
     @Override

+ 24 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.core.TimeValue;
@@ -20,6 +21,7 @@ import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.XPackField;
@@ -77,7 +79,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
 
         if (MachineLearningField.ML_API_FEATURE.check(licenseState)) {
             responseBuilder.setLicensed(true);
-            doInfer(request, responseBuilder, listener);
+            doInfer(task, request, responseBuilder, listener);
         } else {
             trainedModelProvider.getTrainedModel(
                 request.getModelId(),
@@ -88,7 +90,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
                     boolean allowed = trainedModelConfig.getLicenseLevel() == License.OperationMode.BASIC;
                     responseBuilder.setLicensed(allowed);
                     if (allowed || request.isPreviouslyLicensed()) {
-                        doInfer(request, responseBuilder, listener);
+                        doInfer(task, request, responseBuilder, listener);
                     } else {
                         listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
                     }
@@ -97,9 +99,9 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         }
     }
 
-    private void doInfer(Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
+    private void doInfer(Task task, Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
         if (isAllocatedModel(request.getModelId())) {
-            inferAgainstAllocatedModel(request, responseBuilder, listener);
+            inferAgainstAllocatedModel(task, request, responseBuilder, listener);
         } else {
             getModelAndInfer(request, responseBuilder, listener);
         }
@@ -138,7 +140,12 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         modelLoadingService.getModelForPipeline(request.getModelId(), getModelListener);
     }
 
-    private void inferAgainstAllocatedModel(Request request, Response.Builder responseBuilder, ActionListener<Response> listener) {
+    private void inferAgainstAllocatedModel(
+        Task task,
+        Request request,
+        Response.Builder responseBuilder,
+        ActionListener<Response> listener
+    ) {
         TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<>(
             client.threadPool().executor(ThreadPool.Names.SAME),
             // run through all tasks
@@ -150,6 +157,7 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
             .forEach(
                 stringObjectMap -> typedChainTaskExecutor.add(
                     chainedTask -> inferSingleDocAgainstAllocatedModel(
+                        task,
                         request.getModelId(),
                         request.getUpdate(),
                         stringObjectMap,
@@ -169,21 +177,25 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
     }
 
     private void inferSingleDocAgainstAllocatedModel(
+        Task task,
         String modelId,
         InferenceConfigUpdate inferenceConfigUpdate,
         Map<String, Object> doc,
         ActionListener<InferenceResults> listener
     ) {
+        TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId());
+        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(
+            modelId,
+            inferenceConfigUpdate,
+            Collections.singletonList(doc),
+            TimeValue.MAX_VALUE
+        );
+        request.setParentTaskId(taskId);
         executeAsyncWithOrigin(
-            client,
+            new ParentTaskAssigningClient(client, taskId),
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
-            new InferTrainedModelDeploymentAction.Request(
-                modelId,
-                inferenceConfigUpdate,
-                Collections.singletonList(doc),
-                TimeValue.MAX_VALUE
-            ),
+            request,
             ActionListener.wrap(r -> listener.onResponse(r.getResults()), e -> {
                 Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
                 if (unwrapped instanceof ElasticsearchStatusException) {