Browse Source

Preserve parent task id for data frame analytics (#55046)

This change makes sure that all internal client requests spawned by the
data frame analytics persistent task executor and that use the end user
security credentials, have the parent task id assigned. The objective here
is to permit auditing (as well as tracking for debugging purposes) of all
the end-user requests executed on its behalf by persistent tasks.
Because data frame analytics taks already implements graceful shutdown
of child tasks, this change does not interfere with it by opting out of
the persistent task cancellation of child tasks.

Relates #54943 #52314
Albert Zaharovits 5 năm trước cách đây
mục cha
commit
f7809dd096

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/AnomalyDetectorsIndex.java

@@ -78,7 +78,8 @@ public final class AnomalyDetectorsIndex {
      * Creates the .ml-state-000001 index (if necessary)
      * Creates the .ml-state-write alias for the .ml-state-000001 index (if necessary)
      */
-    public static void createStateIndexAndAliasIfNecessary(Client client, ClusterState state, IndexNameExpressionResolver resolver,
+    public static void createStateIndexAndAliasIfNecessary(Client client, ClusterState state,
+                                                           IndexNameExpressionResolver resolver,
                                                            final ActionListener<Boolean> finalListener) {
         MlIndexAndAlias.createIndexAndAliasIfNecessary(
             client,

+ 4 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java

@@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -83,7 +84,8 @@ public class TransportExplainDataFrameAnalyticsAction
 
     private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
                          ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
-        ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client);
+        ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory =
+                new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId()));
         extractedFieldsDetectorFactory.createFromSource(
             request.getConfig(),
             ActionListener.wrap(
@@ -115,7 +117,7 @@ public class TransportExplainDataFrameAnalyticsAction
                                      ActionListener<MemoryEstimation> listener) {
         final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
         DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
-            client, estimateMemoryTaskId, request.getConfig(), extractedFields);
+            new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields);
         processManager.runJobAsync(
             estimateMemoryTaskId,
             request.getConfig(),

+ 11 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java

@@ -19,6 +19,7 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -200,7 +201,7 @@ public class TransportStartDataFrameAnalyticsAction
         );
 
         // Get start context
-        getStartContext(request.getId(), startContextListener);
+        getStartContext(request.getId(), task, startContextListener);
     }
 
     private void estimateMemoryUsageAndUpdateMemoryTracker(StartContext startContext, ActionListener<StartContext> listener) {
@@ -240,8 +241,9 @@ public class TransportStartDataFrameAnalyticsAction
 
     }
 
-    private void getStartContext(String id, ActionListener<StartContext> finalListener) {
+    private void getStartContext(String id, Task task, ActionListener<StartContext> finalListener) {
 
+        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
         // Step 7. Validate that there are analyzable data in the source index
         ActionListener<StartContext> validateMappingsMergeListener = ActionListener.wrap(
             startContext -> validateSourceIndexHasRows(startContext, finalListener),
@@ -250,7 +252,7 @@ public class TransportStartDataFrameAnalyticsAction
 
         // Step 6. Validate mappings can be merged
         ActionListener<StartContext> toValidateMappingsListener = ActionListener.wrap(
-            startContext -> MappingsMerger.mergeMappings(client, startContext.config.getHeaders(),
+            startContext -> MappingsMerger.mergeMappings(parentTaskClient, startContext.config.getHeaders(),
                 startContext.config.getSource(), ActionListener.wrap(
                 mappings -> validateMappingsMergeListener.onResponse(startContext), finalListener::onFailure)),
             finalListener::onFailure
@@ -261,7 +263,7 @@ public class TransportStartDataFrameAnalyticsAction
             startContext -> {
                 switch (startContext.startingState) {
                     case FIRST_TIME:
-                        checkDestIndexIsEmptyIfExists(startContext, toValidateMappingsListener);
+                        checkDestIndexIsEmptyIfExists(parentTaskClient, startContext, toValidateMappingsListener);
                         break;
                     case RESUMING_REINDEXING:
                     case RESUMING_ANALYZING:
@@ -283,7 +285,7 @@ public class TransportStartDataFrameAnalyticsAction
         // Step 4. Check data extraction is possible
         ActionListener<StartContext> toValidateExtractionPossibleListener = ActionListener.wrap(
             startContext -> {
-                new ExtractedFieldsDetectorFactory(client).createFromSource(startContext.config, ActionListener.wrap(
+                new ExtractedFieldsDetectorFactory(parentTaskClient).createFromSource(startContext.config, ActionListener.wrap(
                     extractedFieldsDetector -> {
                         startContext.extractedFields = extractedFieldsDetector.detect().v1();
                         toValidateDestEmptyListener.onResponse(startContext);
@@ -361,13 +363,14 @@ public class TransportStartDataFrameAnalyticsAction
         ));
     }
 
-    private void checkDestIndexIsEmptyIfExists(StartContext startContext, ActionListener<StartContext> listener) {
+    private void checkDestIndexIsEmptyIfExists(ParentTaskAssigningClient parentTaskClient, StartContext startContext,
+                                               ActionListener<StartContext> listener) {
         String destIndex = startContext.config.getDest().getIndex();
         SearchRequest destEmptySearch = new SearchRequest(destIndex);
         destEmptySearch.source().size(0);
         destEmptySearch.allowPartialSearchResults(false);
-        ClientHelper.executeWithHeadersAsync(startContext.config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE,
-            destEmptySearch, ActionListener.wrap(
+        ClientHelper.executeWithHeadersAsync(startContext.config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient,
+                SearchAction.INSTANCE, destEmptySearch, ActionListener.wrap(
                 searchResponse -> {
                     if (searchResponse.getHits().getTotalHits().value > 0) {
                         listener.onFailure(ExceptionsHelper.badRequestException("dest index [{}] must be empty", destIndex));

+ 18 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

@@ -20,6 +20,8 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
@@ -113,15 +115,16 @@ public class DataFrameAnalyticsManager {
 
         // Make sure the stats index and alias exist
         ActionListener<Boolean> stateAliasListener = ActionListener.wrap(
-            aBoolean -> createStatsIndexAndUpdateMappingsIfNecessary(clusterState, statsIndexListener),
-            configListener::onFailure
+            aBoolean -> createStatsIndexAndUpdateMappingsIfNecessary(new ParentTaskAssigningClient(client, task.getParentTaskId()),
+                    clusterState, statsIndexListener), configListener::onFailure
         );
 
         // Make sure the state index and alias exist
-        AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary(client, clusterState, expressionResolver, stateAliasListener);
+        AnomalyDetectorsIndex.createStateIndexAndAliasIfNecessary(new ParentTaskAssigningClient(client, task.getParentTaskId()),
+                clusterState, expressionResolver, stateAliasListener);
     }
 
-    private void createStatsIndexAndUpdateMappingsIfNecessary(ClusterState clusterState, ActionListener<Boolean> listener) {
+    private void createStatsIndexAndUpdateMappingsIfNecessary(Client client, ClusterState clusterState, ActionListener<Boolean> listener) {
         ActionListener<Boolean> createIndexListener = ActionListener.wrap(
             aBoolean -> ElasticsearchMappings.addDocMappingIfMissing(
                     MlStatsIndex.writeAlias(),
@@ -175,7 +178,7 @@ public class DataFrameAnalyticsManager {
             task.markAsCompleted();
             return;
         }
-        ClientHelper.executeAsyncWithOrigin(client,
+        ClientHelper.executeAsyncWithOrigin(new ParentTaskAssigningClient(client, task.getParentTaskId()),
             ML_ORIGIN,
             DeleteIndexAction.INSTANCE,
             new DeleteIndexRequest(config.getDest().getIndex()),
@@ -200,6 +203,8 @@ public class DataFrameAnalyticsManager {
             return;
         }
 
+        final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
+
         // Reindexing is complete; start analytics
         ActionListener<BulkByScrollResponse> reindexCompletedListener = ActionListener.wrap(
             reindexResponse -> {
@@ -239,8 +244,9 @@ public class DataFrameAnalyticsManager {
                 reindexRequest.getSearchRequest().source().fetchSource(config.getSource().getSourceFiltering());
                 reindexRequest.setDestIndex(config.getDest().getIndex());
                 reindexRequest.setScript(new Script("ctx._source." + DestinationIndex.ID_COPY + " = ctx._id"));
+                reindexRequest.setParentTask(task.getParentTaskId());
 
-                final ThreadContext threadContext = client.threadPool().getThreadContext();
+                final ThreadContext threadContext = parentTaskClient.threadPool().getThreadContext();
                 final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
                 try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(ML_ORIGIN)) {
                     LOGGER.info("[{}] Started reindexing", config.getId());
@@ -261,7 +267,7 @@ public class DataFrameAnalyticsManager {
                     config.getId(),
                     Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_REUSING_DEST_INDEX, indexResponse.indices()[0]));
                 LOGGER.info("[{}] Using existing destination index [{}]", config.getId(), indexResponse.indices()[0]);
-                DestinationIndex.updateMappingsToDestIndex(client, config, indexResponse, ActionListener.wrap(
+                DestinationIndex.updateMappingsToDestIndex(parentTaskClient, config, indexResponse, ActionListener.wrap(
                     acknowledgedResponse -> copyIndexCreatedListener.onResponse(null),
                     copyIndexCreatedListener::onFailure
                 ));
@@ -272,14 +278,14 @@ public class DataFrameAnalyticsManager {
                         config.getId(),
                         Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_CREATING_DEST_INDEX, config.getDest().getIndex()));
                     LOGGER.info("[{}] Creating destination index [{}]", config.getId(), config.getDest().getIndex());
-                    DestinationIndex.createDestinationIndex(client, Clock.systemUTC(), config, copyIndexCreatedListener);
+                    DestinationIndex.createDestinationIndex(parentTaskClient, Clock.systemUTC(), config, copyIndexCreatedListener);
                 } else {
                     copyIndexCreatedListener.onFailure(e);
                 }
             }
         );
 
-        ClientHelper.executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, GetIndexAction.INSTANCE,
+        ClientHelper.executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, GetIndexAction.INSTANCE,
                 new GetIndexRequest().indices(config.getDest().getIndex()), destIndexListener);
     }
 
@@ -289,6 +295,7 @@ public class DataFrameAnalyticsManager {
             task.markAsCompleted();
             return;
         }
+        final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
         // Update state to ANALYZING and start process
         ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
             dataExtractorFactory -> {
@@ -325,14 +332,14 @@ public class DataFrameAnalyticsManager {
                 // TODO This could fail with errors. In that case we get stuck with the copied index.
                 // We could delete the index in case of failure or we could try building the factory before reindexing
                 // to catch the error early on.
-                DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener);
+                DataFrameDataExtractorFactory.createForDestinationIndex(parentTaskClient, config, dataExtractorFactoryListener);
             },
             dataExtractorFactoryListener::onFailure
         );
 
         // First we need to refresh the dest index to ensure data is searchable in case the job
         // was stopped after reindexing was complete but before the index was refreshed.
-        executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, RefreshAction.INSTANCE,
+        executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, parentTaskClient, RefreshAction.INSTANCE,
             new RefreshRequest(config.getDest().getIndex()), refreshListener);
     }
 

+ 8 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsTask.java

@@ -22,6 +22,7 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.unit.TimeValue;
@@ -75,7 +76,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
                                   Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
                                   DataFrameAnalyticsAuditor auditor, StartDataFrameAnalyticsAction.TaskParams taskParams) {
         super(id, type, action, MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + taskParams.getId(), parentTask, headers);
-        this.client = Objects.requireNonNull(client);
+        this.client = new ParentTaskAssigningClient(Objects.requireNonNull(client), parentTask);
         this.clusterService = Objects.requireNonNull(clusterService);
         this.analyticsManager = Objects.requireNonNull(analyticsManager);
         this.auditor = Objects.requireNonNull(auditor);
@@ -109,6 +110,12 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
         markAsCompleted();
     }
 
+    @Override
+    public boolean shouldCancelChildrenOnCancellation() {
+        // onCancelled implements graceful shutdown of children
+        return false;
+    }
+
     @Override
     public void markAsCompleted() {
         // It is possible that the stop API has been called in the meantime and that

+ 19 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -155,14 +156,14 @@ public class AnalyticsProcessManager {
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));
 
+        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
         DataFrameAnalyticsConfig config = processContext.config;
         DataFrameDataExtractor dataExtractor = processContext.dataExtractor.get();
         AnalyticsProcess<AnalyticsResult> process = processContext.process.get();
         AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
         try {
             writeHeaderRecord(dataExtractor, process);
-            writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
-                task.getStatsHolder().getDataCountsTracker());
+            writeDataRows(dataExtractor, process, config, task);
             processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
                 DataCounts::documentId);
             process.writeEndOfDataMessage();
@@ -177,8 +178,8 @@ public class AnalyticsProcessManager {
             resultProcessor.awaitForCompletion();
             processContext.setFailureReason(resultProcessor.getFailure());
 
-            refreshDest(config);
-            refreshIndices(config.getId());
+            refreshDest(parentTaskClient, config);
+            refreshIndices(parentTaskClient, config.getId());
             LOGGER.info("[{}] Result processor has completed", config.getId());
         } catch (Exception e) {
             if (task.isStopping()) {
@@ -214,11 +215,12 @@ public class AnalyticsProcessManager {
     }
 
     private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
-                               DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
-        throws IOException {
-
-        CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
-            .create();
+                               DataFrameAnalyticsConfig config, DataFrameAnalyticsTask task) throws IOException {
+        ProgressTracker progressTracker = task.getStatsHolder().getProgressTracker();
+        DataCountsTracker dataCountsTracker = task.getStatsHolder().getDataCountsTracker();
+        CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(
+                new ParentTaskAssigningClient(client, task.getParentTaskId()), config, dataExtractor.getFieldNames())
+                .create();
 
         // The extra fields are for the doc hash and the control field (should be an empty string)
         String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@@ -312,12 +314,12 @@ public class AnalyticsProcessManager {
         };
     }
 
-    private void refreshDest(DataFrameAnalyticsConfig config) {
-        ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
-            () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet());
+    private void refreshDest(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsConfig config) {
+        ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, parentTaskClient,
+            () -> parentTaskClient.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet());
     }
 
-    private void refreshIndices(String jobId) {
+    private void refreshIndices(ParentTaskAssigningClient parentTaskClient, String jobId) {
         RefreshRequest refreshRequest = new RefreshRequest(
             AnomalyDetectorsIndex.jobStateIndexPattern(),
             MlStatsIndex.indexPattern()
@@ -327,8 +329,8 @@ public class AnalyticsProcessManager {
         LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
             jobId, Arrays.toString(refreshRequest.indices())));
 
-        try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            client.admin().indices().refresh(refreshRequest).actionGet();
+        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
+            parentTaskClient.admin().indices().refresh(refreshRequest).actionGet();
         }
     }
 
@@ -455,7 +457,8 @@ public class AnalyticsProcessManager {
         private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
                                                                DataFrameDataExtractorFactory dataExtractorFactory) {
             DataFrameRowsJoiner dataFrameRowsJoiner =
-                new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
+                new DataFrameRowsJoiner(config.getId(), task.getParentTaskId(),
+                        dataExtractorFactory.newExtractor(true), resultsPersisterService);
             return new AnalyticsResultProcessor(
                 config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
                 dataExtractor.get().getAllExtractedFields());

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java

@@ -13,6 +13,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -35,6 +36,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
     private static final int RESULTS_BATCH_SIZE = 1000;
 
     private final String analyticsId;
+    private final TaskId parentTaskId;
     private final DataFrameDataExtractor dataExtractor;
     private final ResultsPersisterService resultsPersisterService;
     private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
@@ -42,8 +44,10 @@ class DataFrameRowsJoiner implements AutoCloseable {
     private volatile String failure;
     private volatile boolean isCancelled;
 
-    DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor, ResultsPersisterService resultsPersisterService) {
+    DataFrameRowsJoiner(String analyticsId, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
+                        ResultsPersisterService resultsPersisterService) {
         this.analyticsId = Objects.requireNonNull(analyticsId);
+        this.parentTaskId = Objects.requireNonNull(parentTaskId);
         this.dataExtractor = Objects.requireNonNull(dataExtractor);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
         this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
@@ -90,6 +94,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
             bulkRequest.add(createIndexRequest(result, row.getHit()));
         }
         if (bulkRequest.numberOfActions() > 0) {
+            bulkRequest.setParentTask(parentTaskId);
             resultsPersisterService.bulkIndexWithHeadersWithRetry(
                 dataExtractor.getHeaders(),
                 bulkRequest,
@@ -117,6 +122,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
         indexRequest.id(hit.getId());
         indexRequest.source(source);
         indexRequest.opType(DocWriteRequest.OpType.INDEX);
+        indexRequest.setParentTask(parentTaskId);
         return indexRequest;
     }
 

+ 5 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@@ -93,6 +94,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         task = mock(DataFrameAnalyticsTask.class);
         when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
         when(task.getStatsHolder()).thenReturn(new StatsHolder());
+        when(task.getParentTaskId()).thenReturn(new TaskId(""));
         dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
             false,
             OutlierDetectionTests.createRandom()).build();
@@ -133,6 +135,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         inOrder.verify(task).isStopping();
         inOrder.verify(task).getAllocationId();
         inOrder.verify(task).isStopping();
+        inOrder.verify(task).getParentTaskId();
         inOrder.verify(task).getStatsHolder();
         inOrder.verify(task).isStopping();
         inOrder.verify(task).getAllocationId();
@@ -168,6 +171,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         inOrder.verify(dataExtractor).collectDataSummary();
         inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
         inOrder.verify(process).isProcessAlive();
+        inOrder.verify(task).getParentTaskId();
         inOrder.verify(task).getStatsHolder();
         inOrder.verify(dataExtractor).getAllExtractedFields();
         inOrder.verify(executorServiceForProcess, times(2)).execute(any());  // 'processData' and 'processResults' threads
@@ -226,6 +230,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         inOrder.verify(dataExtractor).collectDataSummary();
         inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
         inOrder.verify(process).isProcessAlive();
+        inOrder.verify(task).getParentTaskId();
         inOrder.verify(task).getStatsHolder();
         inOrder.verify(dataExtractor).getAllExtractedFields();
         // stop

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -222,7 +223,7 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
     }
 
     private void givenProcessResults(List<RowResults> results) {
-        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
+        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, new TaskId(""), dataExtractor, resultsPersisterService)) {
             results.forEach(joiner::processRowResults);
         }
     }