ソースを参照

[ML][Transforms] protecting doSaveState with optimistic concurrency (#46156)

* [ML][Transforms] protecting doSaveState with optimistic concurrency

* task code cleanup
Benjamin Trent 6 年 前
コミット
91f7a0e3cd

+ 4 - 3
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportUpdateDataFrameTransformAction.java

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.core.security.support.Exceptions;
 import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
 import org.elasticsearch.xpack.dataframe.persistence.DataframeIndex;
+import org.elasticsearch.xpack.dataframe.persistence.SeqNoPrimaryTermAndIndex;
 import org.elasticsearch.xpack.dataframe.transforms.SourceDestValidator;
 import org.elasticsearch.xpack.dataframe.transforms.pivot.Pivot;
 
@@ -138,7 +139,7 @@ public class TransportUpdateDataFrameTransformAction extends TransportMasterNode
     private void handlePrivsResponse(String username,
                                      Request request,
                                      DataFrameTransformConfig config,
-                                     DataFrameTransformsConfigManager.SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
+                                     SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
                                      ClusterState clusterState,
                                      HasPrivilegesResponse privilegesResponse,
                                      ActionListener<Response> listener) {
@@ -161,7 +162,7 @@ public class TransportUpdateDataFrameTransformAction extends TransportMasterNode
     private void validateAndUpdateDataFrame(Request request,
                                             ClusterState clusterState,
                                             DataFrameTransformConfig config,
-                                            DataFrameTransformsConfigManager.SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
+                                            SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
                                             ActionListener<Response> listener) {
         try {
             SourceDestValidator.validate(config, clusterState, indexNameExpressionResolver, request.isDeferValidation());
@@ -186,7 +187,7 @@ public class TransportUpdateDataFrameTransformAction extends TransportMasterNode
     }
     private void updateDataFrame(Request request,
                                  DataFrameTransformConfig config,
-                                 DataFrameTransformsConfigManager.SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
+                                 SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
                                  ClusterState clusterState,
                                  ActionListener<Response> listener) {
 

+ 26 - 34
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/persistence/DataFrameTransformsConfigManager.java

@@ -227,7 +227,8 @@ public class DataFrameTransformsConfigManager {
                 .id(DataFrameTransformConfig.documentId(transformConfig.getId()))
                 .source(source);
             if (seqNoPrimaryTermAndIndex != null) {
-                indexRequest.setIfSeqNo(seqNoPrimaryTermAndIndex.seqNo).setIfPrimaryTerm(seqNoPrimaryTermAndIndex.primaryTerm);
+                indexRequest.setIfSeqNo(seqNoPrimaryTermAndIndex.getSeqNo())
+                    .setIfPrimaryTerm(seqNoPrimaryTermAndIndex.getPrimaryTerm());
             }
             executeAsyncWithOrigin(client, DATA_FRAME_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(r -> {
                 listener.onResponse(true);
@@ -433,21 +434,31 @@ public class DataFrameTransformsConfigManager {
         }));
     }
 
-    public void putOrUpdateTransformStoredDoc(DataFrameTransformStoredDoc stats, ActionListener<Boolean> listener) {
+    public void putOrUpdateTransformStoredDoc(DataFrameTransformStoredDoc stats,
+                                              SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex,
+                                              ActionListener<SeqNoPrimaryTermAndIndex> listener) {
         try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
             XContentBuilder source = stats.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS));
 
             IndexRequest indexRequest = new IndexRequest(DataFrameInternalIndex.LATEST_INDEX_NAME)
                 .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
                 .id(DataFrameTransformStoredDoc.documentId(stats.getId()))
-                .opType(DocWriteRequest.OpType.INDEX)
                 .source(source);
-
+            if (seqNoPrimaryTermAndIndex != null &&
+                seqNoPrimaryTermAndIndex.getIndex().equals(DataFrameInternalIndex.LATEST_INDEX_NAME)) {
+                indexRequest.opType(DocWriteRequest.OpType.INDEX)
+                    .setIfSeqNo(seqNoPrimaryTermAndIndex.getSeqNo())
+                    .setIfPrimaryTerm(seqNoPrimaryTermAndIndex.getPrimaryTerm());
+            } else {
+                // If the stored sequence info is null or refers to an index other than the latest one, this doc has not
+                // been created in the latest index yet, so use a create operation without the seqNo and primaryTerm set
+                indexRequest.opType(DocWriteRequest.OpType.CREATE);
+            }
             executeAsyncWithOrigin(client, DATA_FRAME_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(
-                r -> listener.onResponse(true),
+                r -> listener.onResponse(SeqNoPrimaryTermAndIndex.fromIndexResponse(r)),
                 e -> listener.onFailure(new RuntimeException(
-                    DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_FAILED_TO_PERSIST_STATS, stats.getId()),
-                    e))
+                        DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_FAILED_TO_PERSIST_STATS, stats.getId()),
+                        e))
             ));
         } catch (IOException e) {
             // not expected to happen but for the sake of completeness
@@ -457,13 +468,15 @@ public class DataFrameTransformsConfigManager {
         }
     }
 
-    public void getTransformStoredDoc(String transformId, ActionListener<DataFrameTransformStoredDoc> resultListener) {
+    public void getTransformStoredDoc(String transformId,
+                                      ActionListener<Tuple<DataFrameTransformStoredDoc, SeqNoPrimaryTermAndIndex>> resultListener) {
         QueryBuilder queryBuilder = QueryBuilders.termQuery("_id", DataFrameTransformStoredDoc.documentId(transformId));
         SearchRequest searchRequest = client.prepareSearch(DataFrameInternalIndex.INDEX_NAME_PATTERN)
             .setQuery(queryBuilder)
             // use sort to get the last
             .addSort("_index", SortOrder.DESC)
             .setSize(1)
+            .seqNoAndPrimaryTerm(true)
             .request();
 
         executeAsyncWithOrigin(client, DATA_FRAME_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.<SearchResponse>wrap(
@@ -473,11 +486,14 @@ public class DataFrameTransformsConfigManager {
                         DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_UNKNOWN_TRANSFORM_STATS, transformId)));
                     return;
                 }
-                BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef();
+                SearchHit searchHit = searchResponse.getHits().getHits()[0];
+                BytesReference source = searchHit.getSourceRef();
                 try (InputStream stream = source.streamInput();
                     XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                         .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
-                    resultListener.onResponse(DataFrameTransformStoredDoc.fromXContent(parser));
+                    resultListener.onResponse(
+                        Tuple.tuple(DataFrameTransformStoredDoc.fromXContent(parser),
+                        SeqNoPrimaryTermAndIndex.fromSearchHit(searchHit)));
                 } catch (Exception e) {
                     logger.error(DataFrameMessages.getMessage(DataFrameMessages.FAILED_TO_PARSE_TRANSFORM_STATISTICS_CONFIGURATION,
                             transformId), e);
@@ -595,28 +611,4 @@ public class DataFrameTransformsConfigManager {
         }
         return new Tuple<>(status, reason);
     }
-
-    public static class SeqNoPrimaryTermAndIndex {
-        private final long seqNo;
-        private final long primaryTerm;
-        private final String index;
-
-        public SeqNoPrimaryTermAndIndex(long seqNo, long primaryTerm, String index) {
-            this.seqNo = seqNo;
-            this.primaryTerm = primaryTerm;
-            this.index = index;
-        }
-
-        public long getSeqNo() {
-            return seqNo;
-        }
-
-        public long getPrimaryTerm() {
-            return primaryTerm;
-        }
-
-        public String getIndex() {
-            return index;
-        }
-    }
 }

+ 73 - 0
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/persistence/SeqNoPrimaryTermAndIndex.java

@@ -0,0 +1,73 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.dataframe.persistence;
+
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.search.SearchHit;
+
+import java.util.Objects;
+
+/**
+ * Simple class to keep track of information needed for optimistic concurrency
+ */
+public class SeqNoPrimaryTermAndIndex {
+    private final long seqNo;
+    private final long primaryTerm;
+    private final String index;
+
+    public static SeqNoPrimaryTermAndIndex fromSearchHit(SearchHit hit) {
+        return new SeqNoPrimaryTermAndIndex(hit.getSeqNo(), hit.getPrimaryTerm(), hit.getIndex());
+    }
+
+    public static SeqNoPrimaryTermAndIndex fromIndexResponse(IndexResponse response) {
+        return new SeqNoPrimaryTermAndIndex(response.getSeqNo(), response.getPrimaryTerm(), response.getIndex());
+    }
+
+    SeqNoPrimaryTermAndIndex(long seqNo, long primaryTerm, String index) {
+        this.seqNo = seqNo;
+        this.primaryTerm = primaryTerm;
+        this.index = index;
+    }
+
+    public long getSeqNo() {
+        return seqNo;
+    }
+
+    public long getPrimaryTerm() {
+        return primaryTerm;
+    }
+
+    public String getIndex() {
+        return index;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(seqNo, primaryTerm, index);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        SeqNoPrimaryTermAndIndex other = (SeqNoPrimaryTermAndIndex) obj;
+        return Objects.equals(seqNo, other.seqNo)
+            && Objects.equals(primaryTerm, other.primaryTerm)
+            && Objects.equals(index, other.index);
+    }
+
+    @Override
+    public String toString() {
+        return "{seqNo=" + seqNo + ",primaryTerm=" + primaryTerm + ",index=" + index + "}";
+    }
+}

+ 11 - 5
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java

@@ -20,6 +20,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
 import org.elasticsearch.persistent.PersistentTaskState;
@@ -42,6 +43,7 @@ import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpoin
 import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameInternalIndex;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
+import org.elasticsearch.xpack.dataframe.persistence.SeqNoPrimaryTermAndIndex;
 import org.elasticsearch.xpack.dataframe.transforms.pivot.SchemaUtil;
 
 import java.util.ArrayList;
@@ -189,8 +191,12 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
         // <3> Set the previous stats (if they exist), initialize the indexer, start the task (If it is STOPPED)
         // Since we don't create the task until `_start` is called, if we see that the task state is stopped, attempt to start
         // Schedule execution regardless
-        ActionListener<DataFrameTransformStoredDoc> transformStatsActionListener = ActionListener.wrap(
-            stateAndStats -> {
+        ActionListener<Tuple<DataFrameTransformStoredDoc, SeqNoPrimaryTermAndIndex>> transformStatsActionListener = ActionListener.wrap(
+            stateAndStatsAndSeqNoPrimaryTermAndIndex -> {
+                DataFrameTransformStoredDoc stateAndStats = stateAndStatsAndSeqNoPrimaryTermAndIndex.v1();
+                SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex = stateAndStatsAndSeqNoPrimaryTermAndIndex.v2();
+                // Since we have not set the value for this yet, it SHOULD be null
+                buildTask.updateSeqNoPrimaryTermAndIndex(null, seqNoPrimaryTermAndIndex);
                 logger.trace("[{}] initializing state and stats: [{}]", transformId, stateAndStats.toString());
                 indexerBuilder.setInitialStats(stateAndStats.getTransformStats())
                     .setInitialPosition(stateAndStats.getTransformState().getPosition())
@@ -217,10 +223,10 @@ public class DataFrameTransformPersistentTasksExecutor extends PersistentTasksEx
                     String msg = DataFrameMessages.getMessage(DataFrameMessages.FAILED_TO_LOAD_TRANSFORM_STATE, transformId);
                     logger.error(msg, error);
                     markAsFailed(buildTask, msg);
+                } else {
+                    logger.trace("[{}] No stats found (new transform), starting the task", transformId);
+                    startTask(buildTask, indexerBuilder, null, startTaskListener);
                 }
-
-                logger.trace("[{}] No stats found(new transform), starting the task", transformId);
-                startTask(buildTask, indexerBuilder, null, startTaskListener);
             }
         );
 

+ 52 - 51
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java

@@ -22,6 +22,7 @@ import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.TimeValue;
@@ -54,6 +55,7 @@ import org.elasticsearch.xpack.dataframe.checkpoint.CheckpointProvider;
 import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpointService;
 import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
+import org.elasticsearch.xpack.dataframe.persistence.SeqNoPrimaryTermAndIndex;
 import org.elasticsearch.xpack.dataframe.transforms.pivot.AggregationResultUtils;
 
 import java.time.Instant;
@@ -98,6 +100,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
 
     private final AtomicReference<DataFrameTransformTaskState> taskState;
     private final AtomicReference<String> stateReason;
+    private final AtomicReference<SeqNoPrimaryTermAndIndex> seqNoPrimaryTermAndIndex = new AtomicReference<>(null);
     // the checkpoint of this data frame, storing the checkpoint until data indexing from source to dest is _complete_
     // Note: Each indexer run creates a new future checkpoint which becomes the current checkpoint only after the indexer run finished
     private final AtomicLong currentCheckpoint;
@@ -216,31 +219,6 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 ));
     }
 
-    public DataFrameTransformCheckpoint getLastCheckpoint() {
-        return getIndexer().getLastCheckpoint();
-    }
-
-    public DataFrameTransformCheckpoint getNextCheckpoint() {
-        return getIndexer().getNextCheckpoint();
-    }
-
-    /**
-     * Get the in-progress checkpoint
-     *
-     * @return checkpoint in progress or 0 if task/indexer is not active
-     */
-    public long getInProgressCheckpoint() {
-        if (getIndexer() == null) {
-            return 0;
-        } else {
-            return indexer.get().getState().equals(IndexerState.INDEXING) ? currentCheckpoint.get() + 1L : 0;
-        }
-    }
-
-    public synchronized void setTaskStateStopped() {
-        taskState.set(DataFrameTransformTaskState.STOPPED);
-    }
-
     /**
      * Start the background indexer and set the task's state to started
      * @param startingCheckpoint Set the current checkpoint to this value. If null the
@@ -270,6 +248,15 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 msg));
             return;
         }
+        // If we are already in a `STARTED` state, we should not attempt to call `.start` on the indexer again.
+        if (taskState.get() == DataFrameTransformTaskState.STARTED) {
+            listener.onFailure(new ElasticsearchStatusException(
+                "Cannot start transform [{}] as it is already STARTED.",
+                RestStatus.CONFLICT,
+                getTransformId()
+            ));
+            return;
+        }
         final IndexerState newState = getIndexer().start();
         if (Arrays.stream(RUNNING_STATES).noneMatch(newState::equals)) {
             listener.onFailure(new ElasticsearchException("Cannot start task for data frame transform [{}], because state was [{}]",
@@ -325,7 +312,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
             return;
         }
 
-        if (getIndexer().getState() == IndexerState.STOPPED) {
+        if (getIndexer().getState() == IndexerState.STOPPED || getIndexer().getState() == IndexerState.STOPPING) {
             return;
         }
 
@@ -339,10 +326,11 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
 
         IndexerState state = getIndexer().stop();
         stateReason.set(null);
-        // We just don't want it to be failed if it is failed
-        // Either we are running, and the STATE is already started or failed
-        // doSaveState should transfer the state to STOPPED when it needs to.
-        taskState.set(DataFrameTransformTaskState.STARTED);
+        // No reason to keep it in the potentially failed state.
+        // Since we have called `stop` against the indexer, we have no more fear of triggering again.
+        // But, since `doSaveState` is asynchronous, it is best to set the state as STARTED so that another `start` call cannot be
+        // executed while we are wrapping up.
+        taskState.compareAndSet(DataFrameTransformTaskState.FAILED, DataFrameTransformTaskState.STARTED);
         if (state == IndexerState.STOPPED) {
             getIndexer().onStop();
             getIndexer().doSaveState(state, getIndexer().getPosition(), () -> {});
@@ -361,8 +349,10 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
             return;
         }
 
-        if (taskState.get() == DataFrameTransformTaskState.FAILED) {
-            logger.debug("[{}] schedule was triggered for transform but task is failed. Ignoring trigger.", getTransformId());
+        if (taskState.get() == DataFrameTransformTaskState.FAILED || taskState.get() == DataFrameTransformTaskState.STOPPED) {
+            logger.debug("[{}] schedule was triggered for transform but task is [{}]. Ignoring trigger.",
+                getTransformId(),
+                taskState.get());
             return;
         }
 
@@ -379,7 +369,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
 
         // if it runs for the 1st time we just do it, if not we check for changes
         if (currentCheckpoint.get() == 0) {
-            logger.debug("Trigger initial run.");
+            logger.debug("[{}] trigger initial run.", getTransformId());
             getIndexer().maybeTriggerAsyncJob(System.currentTimeMillis());
         } else if (getIndexer().isContinuous()) {
             getIndexer().maybeTriggerAsyncJob(System.currentTimeMillis());
@@ -395,17 +385,6 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
         markAsCompleted();
     }
 
-    public DataFrameTransformProgress getProgress() {
-        if (indexer.get() == null) {
-            return null;
-        }
-        DataFrameTransformProgress indexerProgress = indexer.get().getProgress();
-        if (indexerProgress == null) {
-            return null;
-        }
-        return new DataFrameTransformProgress(indexerProgress);
-    }
-
     void persistStateToClusterState(DataFrameTransformState state,
                                     ActionListener<PersistentTasksCustomMetaData.PersistentTask<?>> listener) {
         updatePersistentTaskState(state, ActionListener.wrap(
@@ -520,6 +499,19 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
         indexer.set(indexerBuilder.build(this));
     }
 
+    void updateSeqNoPrimaryTermAndIndex(SeqNoPrimaryTermAndIndex expectedValue, SeqNoPrimaryTermAndIndex newValue) {
+        boolean updated = seqNoPrimaryTermAndIndex.compareAndSet(expectedValue, newValue);
+        // This should never happen. We ONLY ever update this value at initialization or right after we finish updating the document
+        // famous last words...
+        assert updated :
+            "[" + getTransformId() + "] unexpected change to seqNoPrimaryTermAndIndex.";
+    }
+
+    @Nullable
+    SeqNoPrimaryTermAndIndex getSeqNoPrimaryTermAndIndex() {
+        return seqNoPrimaryTermAndIndex.get();
+    }
+
     static class ClientDataFrameIndexerBuilder {
         private Client client;
         private DataFrameTransformsConfigManager transformsConfigManager;
@@ -879,6 +871,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 next.run();
                 return;
             }
+
             // This means that the indexer was triggered to discover changes, found none, and exited early.
             // If the state is `STOPPED` this means that DataFrameTransformTask#stop was called while we were checking for changes.
             // Allow the stop call path to continue
@@ -886,12 +879,6 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 next.run();
                 return;
             }
-            // If we are `STOPPED` on a `doSaveState` call, that indicates we transitioned to `STOPPED` from `STOPPING`
-            // OR we called `doSaveState` manually as the indexer was not actively running.
-            // Since we save the state to an index, we should make sure that our task state is in parity with the indexer state
-            if (indexerState.equals(IndexerState.STOPPED)) {
-                transformTask.setTaskStateStopped();
-            }
 
             DataFrameTransformTaskState taskState = transformTask.taskState.get();
 
@@ -899,13 +886,21 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 && transformTask.currentCheckpoint.get() == 1
                 && this.isContinuous() == false) {
                 // set both to stopped so they are persisted as such
-                taskState = DataFrameTransformTaskState.STOPPED;
                 indexerState = IndexerState.STOPPED;
 
                 auditor.info(transformConfig.getId(), "Data frame finished indexing all data, initiating stop");
                 logger.info("[{}] data frame transform finished indexing all data, initiating stop.", transformConfig.getId());
             }
 
+            // If we are `STOPPED` on a `doSaveState` call, that indicates we transitioned to `STOPPED` from `STOPPING`
+            // OR we called `doSaveState` manually as the indexer was not actively running.
+            // Since we save the state to an index, we should make sure that our task state is in parity with the indexer state
+            if (indexerState.equals(IndexerState.STOPPED)) {
+                // We don't want to adjust the stored taskState because as soon as it is `STOPPED` a user could call
+                // .start again.
+                taskState = DataFrameTransformTaskState.STOPPED;
+            }
+
             final DataFrameTransformState state = new DataFrameTransformState(
                 taskState,
                 indexerState,
@@ -915,13 +910,18 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 getProgress());
             logger.debug("[{}] updating persistent state of transform to [{}].", transformConfig.getId(), state.toString());
 
+            // This could be `null` but the putOrUpdateTransformStoredDoc handles that case just fine
+            SeqNoPrimaryTermAndIndex seqNoPrimaryTermAndIndex = transformTask.getSeqNoPrimaryTermAndIndex();
+
             // Persist the current state and stats in the internal index. The interval of this method being
             // called is controlled by AsyncTwoPhaseIndexer#onBulkResponse which calls doSaveState every so
             // often when doing bulk indexing calls or at the end of one indexing run.
             transformsConfigManager.putOrUpdateTransformStoredDoc(
                     new DataFrameTransformStoredDoc(transformId, state, getStats()),
+                    seqNoPrimaryTermAndIndex,
                     ActionListener.wrap(
                             r -> {
+                                transformTask.updateSeqNoPrimaryTermAndIndex(seqNoPrimaryTermAndIndex, r);
                                 // for auto stop shutdown the task
                                 if (state.getTaskState().equals(DataFrameTransformTaskState.STOPPED)) {
                                     transformTask.shutdown();
@@ -989,6 +989,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 nextCheckpoint = null;
                 // Reset our failure count as we have finished and may start again with a new checkpoint
                 failureCount.set(0);
+                transformTask.stateReason.set(null);
 
                 // With bucket_selector we could have read all the buckets and completed the transform
                 // but not "see" all the buckets since they were filtered out. Consequently, progress would

+ 3 - 1
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/DataFrameSingleNodeTestCase.java

@@ -23,6 +23,8 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public abstract class DataFrameSingleNodeTestCase extends ESSingleNodeTestCase {
 
     @Before
@@ -56,7 +58,7 @@ public abstract class DataFrameSingleNodeTestCase extends ESSingleNodeTestCase {
             if (expected == null) {
                 fail("expected an exception but got a response");
             } else {
-                assertEquals(r, expected);
+                assertThat(r, equalTo(expected));
             }
             if (onAnswer != null) {
                 onAnswer.accept(r);

+ 38 - 8
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/persistence/DataFrameTransformsConfigManagerTests.java

@@ -261,23 +261,50 @@ public class DataFrameTransformsConfigManagerTests extends DataFrameSingleNodeTe
         String transformId = "transform_test_stored_doc_create_read_update";
 
         DataFrameTransformStoredDoc storedDocs = DataFrameTransformStoredDocTests.randomDataFrameTransformStoredDoc(transformId);
+        SeqNoPrimaryTermAndIndex firstIndex = new SeqNoPrimaryTermAndIndex(0, 1, DataFrameInternalIndex.LATEST_INDEX_NAME);
 
-        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(storedDocs, listener), Boolean.TRUE, null, null);
-        assertAsync(listener -> transformsConfigManager.getTransformStoredDoc(transformId, listener), storedDocs, null, null);
+        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(storedDocs, null, listener),
+            firstIndex,
+            null,
+            null);
+        assertAsync(listener -> transformsConfigManager.getTransformStoredDoc(transformId, listener),
+            Tuple.tuple(storedDocs, firstIndex),
+            null,
+            null);
 
+        SeqNoPrimaryTermAndIndex secondIndex = new SeqNoPrimaryTermAndIndex(1, 1, DataFrameInternalIndex.LATEST_INDEX_NAME);
         DataFrameTransformStoredDoc updated = DataFrameTransformStoredDocTests.randomDataFrameTransformStoredDoc(transformId);
-        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(updated, listener), Boolean.TRUE, null, null);
-        assertAsync(listener -> transformsConfigManager.getTransformStoredDoc(transformId, listener), updated, null, null);
+        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(updated, firstIndex, listener),
+            secondIndex,
+            null,
+            null);
+        assertAsync(listener -> transformsConfigManager.getTransformStoredDoc(transformId, listener),
+            Tuple.tuple(updated, secondIndex),
+            null,
+            null);
+
+        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(updated, firstIndex, listener),
+            (SeqNoPrimaryTermAndIndex)null,
+            r -> fail("did not fail with version conflict."),
+            e -> assertThat(
+                e.getMessage(),
+                equalTo("Failed to persist data frame statistics for transform [transform_test_stored_doc_create_read_update]"))
+            );
     }
 
     public void testGetStoredDocMultiple() throws InterruptedException {
         int numStats = randomIntBetween(10, 15);
         List<DataFrameTransformStoredDoc> expectedDocs = new ArrayList<>();
         for (int i=0; i<numStats; i++) {
+            SeqNoPrimaryTermAndIndex initialSeqNo =
+                new SeqNoPrimaryTermAndIndex(i, 1, DataFrameInternalIndex.LATEST_INDEX_NAME);
             DataFrameTransformStoredDoc stat =
-                    DataFrameTransformStoredDocTests.randomDataFrameTransformStoredDoc(randomAlphaOfLength(6));
+                    DataFrameTransformStoredDocTests.randomDataFrameTransformStoredDoc(randomAlphaOfLength(6) + i);
             expectedDocs.add(stat);
-            assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(stat, listener), Boolean.TRUE, null, null);
+            assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(stat, null, listener),
+                initialSeqNo,
+                null,
+                null);
         }
 
         // remove one of the put docs so we don't retrieve all
@@ -338,8 +365,11 @@ public class DataFrameTransformsConfigManagerTests extends DataFrameSingleNodeTe
             client().index(request).actionGet();
         }
 
-        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(dataFrameTransformStoredDoc, listener),
-            true,
+        // Put when referencing the old index should create the doc in the new index, even if we have seqNo|primaryTerm info
+        assertAsync(listener -> transformsConfigManager.putOrUpdateTransformStoredDoc(dataFrameTransformStoredDoc,
+            new SeqNoPrimaryTermAndIndex(3, 1, oldIndex),
+            listener),
+            new SeqNoPrimaryTermAndIndex(0, 1, DataFrameInternalIndex.LATEST_INDEX_NAME),
             null,
             null);
 

+ 56 - 0
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/persistence/SeqNoPrimaryTermAndIndexTests.java

@@ -0,0 +1,56 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.dataframe.persistence;
+
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class SeqNoPrimaryTermAndIndexTests extends ESTestCase {
+
+    public void testEquals() {
+        for (int i = 0; i < 30; i++) {
+            long seqNo = randomLongBetween(-2, 10_000);
+            long primaryTerm = randomLongBetween(-2, 10_000);
+            String index = randomAlphaOfLength(10);
+            SeqNoPrimaryTermAndIndex first = new SeqNoPrimaryTermAndIndex(seqNo, primaryTerm, index);
+            SeqNoPrimaryTermAndIndex second = new SeqNoPrimaryTermAndIndex(seqNo, primaryTerm, index);
+            assertThat(first, equalTo(second));
+        }
+    }
+
+    public void testFromSearchHit() {
+        SearchHit searchHit = new SearchHit(1);
+        long seqNo = randomLongBetween(-2, 10_000);
+        long primaryTerm = randomLongBetween(-2, 10_000);
+        String index = randomAlphaOfLength(10);
+        searchHit.setSeqNo(seqNo);
+        searchHit.setPrimaryTerm(primaryTerm);
+        searchHit.shard(new SearchShardTarget("anynode", new ShardId(index, randomAlphaOfLength(10), 1), null, null));
+        assertThat(SeqNoPrimaryTermAndIndex.fromSearchHit(searchHit), equalTo(new SeqNoPrimaryTermAndIndex(seqNo, primaryTerm, index)));
+    }
+
+    public void testFromIndexResponse() {
+        long seqNo = randomLongBetween(-2, 10_000);
+        long primaryTerm = randomLongBetween(-2, 10_000);
+        String index = randomAlphaOfLength(10);
+        IndexResponse indexResponse = new IndexResponse(new ShardId(index, randomAlphaOfLength(10), 1),
+            "_doc",
+            "asdf",
+            seqNo,
+            primaryTerm,
+            1,
+        randomBoolean());
+
+        assertThat(SeqNoPrimaryTermAndIndex.fromIndexResponse(indexResponse),
+            equalTo(new SeqNoPrimaryTermAndIndex(seqNo, primaryTerm, index)));
+    }
+}