Browse Source

[ML-DataFrame] introduce an abstraction for checkpointing (#44900)

introduces an abstraction for how checkpointing and synchronization works, covering

 - retrieval of checkpoints
 - check for updates
 - retrieving stats information
Hendrik Muhs 6 years ago
parent
commit
185c583bc3
12 changed files with 591 additions and 295 deletions
  1. 0 5
      client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameTransformIT.java
  2. 0 4
      x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformIT.java
  3. 30 23
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportGetDataFrameTransformsStatsAction.java
  4. 72 0
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/CheckpointProvider.java
  5. 19 233
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java
  6. 314 0
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DefaultCheckpointProvider.java
  7. 90 0
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/TimeBasedCheckpointProvider.java
  8. 8 0
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java
  9. 48 20
      x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java
  10. 4 4
      x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformCheckpointServiceNodeTests.java
  11. 4 4
      x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointServiceTests.java
  12. 2 2
      x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/ClientDataFrameIndexerTests.java

+ 0 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameTransformIT.java

@@ -69,7 +69,6 @@ import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
@@ -387,10 +386,6 @@ public class DataFrameTransformIT extends ESRestHighLevelClientTestCase {
             assertNotEquals(zeroIndexerStats, stateAndStats.getIndexerStats());
             assertThat(stateAndStats.getTaskState(),
                 is(oneOf(DataFrameTransformTaskState.STARTED, DataFrameTransformTaskState.STOPPED)));
-            assertNotNull(stateAndStats.getCheckpointingInfo().getNext().getCheckpointProgress());
-            assertThat(stateAndStats.getCheckpointingInfo().getNext().getCheckpointProgress().getPercentComplete(), equalTo(100.0));
-            assertThat(stateAndStats.getCheckpointingInfo().getNext().getCheckpointProgress().getTotalDocs(), greaterThan(0L));
-            assertThat(stateAndStats.getCheckpointingInfo().getNext().getCheckpointProgress().getRemainingDocs(), equalTo(0L));
             assertThat(stateAndStats.getReason(), is(nullValue()));
         });
     }

+ 0 - 4
x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformIT.java

@@ -11,7 +11,6 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.RequestOptions;
-import org.elasticsearch.client.core.IndexerState;
 import org.elasticsearch.client.dataframe.transforms.DataFrameTransformConfig;
 import org.elasticsearch.client.dataframe.transforms.DataFrameTransformTaskState;
 import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig;
@@ -64,9 +63,6 @@ public class DataFrameTransformIT extends DataFrameIntegTestCase {
 
         waitUntilCheckpoint(config.getId(), 1L);
 
-        // It will eventually be stopped
-        assertBusy(() -> assertThat(getDataFrameTransformStats(config.getId())
-                .getTransformsStats().get(0).getCheckpointingInfo().getNext().getIndexerState(), equalTo(IndexerState.STOPPED)));
         stopDataFrameTransform(config.getId());
 
         DataFrameTransformConfig storedConfig = getDataFrameTransform(config.getId()).getTransformConfigurations().get(0);

+ 30 - 23
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/action/TransportGetDataFrameTransformsStatsAction.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.dataframe.action.GetDataFrameTransformsStats
 import org.elasticsearch.xpack.core.dataframe.action.GetDataFrameTransformsStatsAction.Request;
 import org.elasticsearch.xpack.core.dataframe.action.GetDataFrameTransformsStatsAction.Response;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformStats;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformStoredDoc;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformTaskState;
@@ -85,26 +86,29 @@ public class TransportGetDataFrameTransformsStatsAction extends
         ClusterState state = clusterService.state();
         String nodeId = state.nodes().getLocalNode().getId();
         if (task.isCancelled() == false) {
-            transformsCheckpointService.getCheckpointStats(task.getTransformId(), task.getCheckpoint(), task.getInProgressCheckpoint(),
-                task.getState().getIndexerState(), task.getState().getPosition(), task.getProgress(),
-                ActionListener.wrap(checkpointStats -> listener.onResponse(new Response(
-                        Collections.singletonList(new DataFrameTransformStats(task.getTransformId(),
-                            task.getState().getTaskState(),
-                            task.getState().getReason(),
-                            null,
-                            task.getStats(),
-                            checkpointStats)),
-                        1L)),
-                    e -> listener.onResponse(new Response(
-                        Collections.singletonList(new DataFrameTransformStats(task.getTransformId(),
-                            task.getState().getTaskState(),
-                            task.getState().getReason(),
-                            null,
-                            task.getStats(),
-                            DataFrameTransformCheckpointingInfo.EMPTY)),
-                        1L,
-                        Collections.emptyList(),
-                        Collections.singletonList(new FailedNodeException(nodeId, "Failed to retrieve checkpointing info", e))))
+            DataFrameTransformState transformState = task.getState();
+            task.getCheckpointingInfo(transformsCheckpointService, ActionListener.wrap(
+                checkpointingInfo -> listener.onResponse(new Response(
+                    Collections.singletonList(new DataFrameTransformStats(task.getTransformId(),
+                        transformState.getTaskState(),
+                        transformState.getReason(),
+                        null,
+                        task.getStats(),
+                        checkpointingInfo)),
+                    1L)),
+                e -> {
+                    logger.warn("Failed to retrieve checkpointing info for transform [" + task.getTransformId() + "]", e);
+                    listener.onResponse(new Response(
+                    Collections.singletonList(new DataFrameTransformStats(task.getTransformId(),
+                        transformState.getTaskState(),
+                        transformState.getReason(),
+                        null,
+                        task.getStats(),
+                        DataFrameTransformCheckpointingInfo.EMPTY)),
+                    1L,
+                    Collections.emptyList(),
+                    Collections.singletonList(new FailedNodeException(nodeId, "Failed to retrieve checkpointing info", e))));
+                }
                 ));
         } else {
             listener.onResponse(new Response(Collections.emptyList(), 0L));
@@ -214,9 +218,12 @@ public class TransportGetDataFrameTransformsStatsAction extends
 
     private void populateSingleStoppedTransformStat(DataFrameTransformStoredDoc transform,
                                                     ActionListener<DataFrameTransformCheckpointingInfo> listener) {
-        transformsCheckpointService.getCheckpointStats(transform.getId(), transform.getTransformState().getCheckpoint(),
-            transform.getTransformState().getCheckpoint() + 1, transform.getTransformState().getIndexerState(),
-            transform.getTransformState().getPosition(), transform.getTransformState().getProgress(),
+        transformsCheckpointService.getCheckpointingInfo(
+            transform.getId(),
+            transform.getTransformState().getCheckpoint(),
+            transform.getTransformState().getIndexerState(),
+            transform.getTransformState().getPosition(),
+            transform.getTransformState().getProgress(),
             ActionListener.wrap(
                 listener::onResponse,
                 e -> {

+ 72 - 0
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/CheckpointProvider.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.dataframe.checkpoint;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerPosition;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress;
+import org.elasticsearch.xpack.core.indexing.IndexerState;
+
+/**
+ * Interface for checkpoint creation, checking for changes and getting statistics about checkpoints
+ */
+public interface CheckpointProvider {
+
+    /**
+     * Create a new checkpoint
+     *
+     * @param lastCheckpoint the last checkpoint
+     * @param listener listener to call after inner request returned
+     */
+    void createNextCheckpoint(DataFrameTransformCheckpoint lastCheckpoint, ActionListener<DataFrameTransformCheckpoint> listener);
+
+    /**
+     * Determines whether the data frame needs updating
+     *
+     * @param lastCheckpoint the last checkpoint
+     * @param listener listener to send the result to
+     */
+    void sourceHasChanged(DataFrameTransformCheckpoint lastCheckpoint, ActionListener<Boolean> listener);
+
+    /**
+     * Get checkpoint statistics for a running data frame
+     *
+     * For running data frames most information is available in-memory.
+     *
+     * @param lastCheckpoint the last checkpoint
+     * @param nextCheckpoint the next checkpoint
+     * @param nextCheckpointIndexerState indexer state for the next checkpoint
+     * @param nextCheckpointPosition position for the next checkpoint
+     * @param nextCheckpointProgress progress for the next checkpoint
+     * @param listener listener to retrieve the result
+     */
+    void getCheckpointingInfo(DataFrameTransformCheckpoint lastCheckpoint,
+                              DataFrameTransformCheckpoint nextCheckpoint,
+                              IndexerState nextCheckpointIndexerState,
+                              DataFrameIndexerPosition nextCheckpointPosition,
+                              DataFrameTransformProgress nextCheckpointProgress,
+                              ActionListener<DataFrameTransformCheckpointingInfo> listener);
+
+    /**
+     * Get checkpoint statistics for a stopped data frame
+     *
+     * For stopped data frames we need to do lookups in the internal index.
+     *
+     * @param lastCheckpointNumber the last checkpoint number
+     * @param nextCheckpointIndexerState indexer state for the next checkpoint
+     * @param nextCheckpointPosition position for the next checkpoint
+     * @param nextCheckpointProgress progress for the next checkpoint
+     * @param listener listener to retrieve the result
+     */
+    void getCheckpointingInfo(long lastCheckpointNumber,
+                              IndexerState nextCheckpointIndexerState,
+                              DataFrameIndexerPosition nextCheckpointPosition,
+                              DataFrameTransformProgress nextCheckpointProgress,
+                              ActionListener<DataFrameTransformCheckpointingInfo> listener);
+}

+ 19 - 233
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java

@@ -9,31 +9,15 @@ package org.elasticsearch.xpack.dataframe.checkpoint;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.admin.indices.get.GetIndexAction;
-import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
-import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction;
-import org.elasticsearch.action.admin.indices.stats.IndicesStatsRequest;
-import org.elasticsearch.action.admin.indices.stats.ShardStats;
-import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerPosition;
-import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint;
-import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointStats;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress;
-import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig;
 import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig;
 import org.elasticsearch.xpack.core.indexing.IndexerState;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
 
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-import java.util.TreeMap;
-
 /**
  * DataFrameTransform Checkpoint Service
  *
@@ -44,35 +28,6 @@ import java.util.TreeMap;
  */
 public class DataFrameTransformsCheckpointService {
 
-    private static class Checkpoints {
-        long lastCheckpointNumber;
-        long nextCheckpointNumber;
-        IndexerState nextCheckpointIndexerState;
-        DataFrameIndexerPosition nextCheckpointPosition;
-        DataFrameTransformProgress nextCheckpointProgress;
-        DataFrameTransformCheckpoint lastCheckpoint = DataFrameTransformCheckpoint.EMPTY;
-        DataFrameTransformCheckpoint nextCheckpoint = DataFrameTransformCheckpoint.EMPTY;
-        DataFrameTransformCheckpoint sourceCheckpoint = DataFrameTransformCheckpoint.EMPTY;
-
-        Checkpoints(long lastCheckpointNumber, long nextCheckpointNumber, IndexerState nextCheckpointIndexerState,
-                    DataFrameIndexerPosition nextCheckpointPosition, DataFrameTransformProgress nextCheckpointProgress) {
-            this.lastCheckpointNumber = lastCheckpointNumber;
-            this.nextCheckpointNumber = nextCheckpointNumber;
-            this.nextCheckpointIndexerState = nextCheckpointIndexerState;
-            this.nextCheckpointPosition = nextCheckpointPosition;
-            this.nextCheckpointProgress = nextCheckpointProgress;
-        }
-
-        DataFrameTransformCheckpointingInfo buildInfo() {
-            return new DataFrameTransformCheckpointingInfo(
-                new DataFrameTransformCheckpointStats(lastCheckpointNumber, null, null, null,
-                    lastCheckpoint.getTimestamp(), lastCheckpoint.getTimeUpperBound()),
-                new DataFrameTransformCheckpointStats(nextCheckpointNumber, nextCheckpointIndexerState, nextCheckpointPosition,
-                    nextCheckpointProgress, nextCheckpoint.getTimestamp(), nextCheckpoint.getTimeUpperBound()),
-                DataFrameTransformCheckpoint.getBehind(lastCheckpoint, sourceCheckpoint));
-        }
-    }
-
     private static final Logger logger = LogManager.getLogger(DataFrameTransformsCheckpointService.class);
 
     private final Client client;
@@ -84,144 +39,37 @@ public class DataFrameTransformsCheckpointService {
         this.dataFrameTransformsConfigManager = dataFrameTransformsConfigManager;
     }
 
-    /**
-     * Get an unnumbered checkpoint. These checkpoints are for persistence but comparing state.
-     *
-     * @param transformConfig the @link{DataFrameTransformConfig}
-     * @param listener listener to call after inner request returned
-     */
-    public void getCheckpoint(DataFrameTransformConfig transformConfig, ActionListener<DataFrameTransformCheckpoint> listener) {
-        getCheckpoint(transformConfig, -1L, listener);
-    }
-
-    /**
-     * Get a checkpoint, used to store a checkpoint.
-     *
-     * @param transformConfig the @link{DataFrameTransformConfig}
-     * @param checkpoint the number of the checkpoint
-     * @param listener listener to call after inner request returned
-     */
-    public void getCheckpoint(DataFrameTransformConfig transformConfig, long checkpoint,
-                              ActionListener<DataFrameTransformCheckpoint> listener) {
-        long timestamp = System.currentTimeMillis();
-
-        // for time based synchronization
-        long timeUpperBound = getTimeStampForTimeBasedSynchronization(transformConfig.getSyncConfig(), timestamp);
-
-        // 1st get index to see the indexes the user has access to
-        GetIndexRequest getIndexRequest = new GetIndexRequest()
-            .indices(transformConfig.getSource().getIndex())
-            .features(new GetIndexRequest.Feature[0])
-            .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
-
-        ClientHelper.executeWithHeadersAsync(transformConfig.getHeaders(), ClientHelper.DATA_FRAME_ORIGIN, client, GetIndexAction.INSTANCE,
-                getIndexRequest, ActionListener.wrap(getIndexResponse -> {
-                    Set<String> userIndices = new HashSet<>(Arrays.asList(getIndexResponse.getIndices()));
-                    // 2nd get stats request
-                    ClientHelper.executeAsyncWithOrigin(client,
-                        ClientHelper.DATA_FRAME_ORIGIN,
-                        IndicesStatsAction.INSTANCE,
-                        new IndicesStatsRequest()
-                            .indices(transformConfig.getSource().getIndex())
-                            .clear()
-                            .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN),
-                        ActionListener.wrap(
-                            response -> {
-                                if (response.getFailedShards() != 0) {
-                                    listener.onFailure(
-                                        new CheckpointException("Source has [" + response.getFailedShards() + "] failed shards"));
-                                    return;
-                                }
+    public CheckpointProvider getCheckpointProvider(final DataFrameTransformConfig transformConfig) {
+        if (transformConfig.getSyncConfig() instanceof TimeSyncConfig) {
+            return new TimeBasedCheckpointProvider(client, dataFrameTransformsConfigManager, transformConfig);
+        }
 
-                                Map<String, long[]> checkpointsByIndex = extractIndexCheckPoints(response.getShards(), userIndices);
-                                listener.onResponse(new DataFrameTransformCheckpoint(transformConfig.getId(),
-                                    timestamp,
-                                    checkpoint,
-                                    checkpointsByIndex,
-                                    timeUpperBound));
-                            },
-                            e-> listener.onFailure(new CheckpointException("Failed to create checkpoint", e))
-                        ));
-                },
-                e -> listener.onFailure(new CheckpointException("Failed to create checkpoint", e))
-            ));
+        return new DefaultCheckpointProvider(client, dataFrameTransformsConfigManager, transformConfig);
     }
 
     /**
-     * Get checkpointing stats for a data frame
+     * Get checkpointing stats for a stopped data frame
      *
      * @param transformId The data frame task
-     * @param lastCheckpoint the last checkpoint
-     * @param nextCheckpoint the next checkpoint
+     * @param lastCheckpointNumber the last checkpoint
      * @param nextCheckpointIndexerState indexer state for the next checkpoint
      * @param nextCheckpointPosition position for the next checkpoint
      * @param nextCheckpointProgress progress for the next checkpoint
      * @param listener listener to retrieve the result
      */
-    public void getCheckpointStats(String transformId,
-                                   long lastCheckpoint,
-                                   long nextCheckpoint,
-                                   IndexerState nextCheckpointIndexerState,
-                                   DataFrameIndexerPosition nextCheckpointPosition,
-                                   DataFrameTransformProgress nextCheckpointProgress,
-                                   ActionListener<DataFrameTransformCheckpointingInfo> listener) {
-
-        Checkpoints checkpoints =
-            new Checkpoints(lastCheckpoint, nextCheckpoint, nextCheckpointIndexerState, nextCheckpointPosition, nextCheckpointProgress);
-
-        // <3> notify the user once we have the last checkpoint
-        ActionListener<DataFrameTransformCheckpoint> lastCheckpointListener = ActionListener.wrap(
-            lastCheckpointObj -> {
-                checkpoints.lastCheckpoint = lastCheckpointObj;
-                listener.onResponse(checkpoints.buildInfo());
-            },
-            e -> {
-                logger.debug("Failed to retrieve last checkpoint [" +
-                    lastCheckpoint + "] for data frame [" + transformId + "]", e);
-                listener.onFailure(new CheckpointException("Failure during last checkpoint info retrieval", e));
-            }
-        );
-
-        // <2> after the next checkpoint, get the last checkpoint
-        ActionListener<DataFrameTransformCheckpoint> nextCheckpointListener = ActionListener.wrap(
-            nextCheckpointObj -> {
-                checkpoints.nextCheckpoint = nextCheckpointObj;
-                if (lastCheckpoint != 0) {
-                    dataFrameTransformsConfigManager.getTransformCheckpoint(transformId,
-                        lastCheckpoint,
-                        lastCheckpointListener);
-                } else {
-                    lastCheckpointListener.onResponse(DataFrameTransformCheckpoint.EMPTY);
-                }
-            },
-            e -> {
-                logger.debug("Failed to retrieve next checkpoint [" +
-                    nextCheckpoint + "] for data frame [" + transformId + "]", e);
-                listener.onFailure(new CheckpointException("Failure during next checkpoint info retrieval", e));
-            }
-        );
-
-        // <1> after the source checkpoint, get the in progress checkpoint
-        ActionListener<DataFrameTransformCheckpoint> sourceCheckpointListener = ActionListener.wrap(
-            sourceCheckpoint -> {
-                checkpoints.sourceCheckpoint = sourceCheckpoint;
-                if (nextCheckpoint != 0) {
-                    dataFrameTransformsConfigManager.getTransformCheckpoint(transformId,
-                        nextCheckpoint,
-                        nextCheckpointListener);
-                } else {
-                    nextCheckpointListener.onResponse(DataFrameTransformCheckpoint.EMPTY);
-                }
-            },
-            e -> {
-                logger.debug("Failed to retrieve source checkpoint for data frame [" + transformId + "]", e);
-                listener.onFailure(new CheckpointException("Failure during source checkpoint info retrieval", e));
-            }
-        );
-
-        // <0> get the transform and the source, transient checkpoint
+    public void getCheckpointingInfo(final String transformId,
+                                     final long lastCheckpointNumber,
+                                     final IndexerState nextCheckpointIndexerState,
+                                     final DataFrameIndexerPosition nextCheckpointPosition,
+                                     final DataFrameTransformProgress nextCheckpointProgress,
+                                     final ActionListener<DataFrameTransformCheckpointingInfo> listener) {
+
+        // we need to retrieve the config first before we can defer the rest to the corresponding provider
         dataFrameTransformsConfigManager.getTransformConfiguration(transformId, ActionListener.wrap(
-            transformConfig -> getCheckpoint(transformConfig, sourceCheckpointListener),
+            transformConfig -> {
+                getCheckpointProvider(transformConfig).getCheckpointingInfo(lastCheckpointNumber, nextCheckpointIndexerState,
+                            nextCheckpointPosition, nextCheckpointProgress, listener);
+                },
             transformError -> {
                 logger.warn("Failed to retrieve configuration for data frame [" + transformId + "]", transformError);
                 listener.onFailure(new CheckpointException("Failed to retrieve configuration", transformError));
@@ -229,66 +77,4 @@ public class DataFrameTransformsCheckpointService {
         );
     }
 
-    private long getTimeStampForTimeBasedSynchronization(SyncConfig syncConfig, long timestamp) {
-        if (syncConfig instanceof TimeSyncConfig) {
-            TimeSyncConfig timeSyncConfig = (TimeSyncConfig) syncConfig;
-            return timestamp - timeSyncConfig.getDelay().millis();
-        }
-
-        return 0L;
-    }
-
-    static Map<String, long[]> extractIndexCheckPoints(ShardStats[] shards, Set<String> userIndices) {
-        Map<String, TreeMap<Integer, Long>> checkpointsByIndex = new TreeMap<>();
-
-        for (ShardStats shard : shards) {
-            String indexName = shard.getShardRouting().getIndexName();
-
-            if (userIndices.contains(indexName)) {
-                // SeqNoStats could be `null`, assume the global checkpoint to be 0 in this case
-                long globalCheckpoint = shard.getSeqNoStats() == null ? 0 : shard.getSeqNoStats().getGlobalCheckpoint();
-                if (checkpointsByIndex.containsKey(indexName)) {
-                    // we have already seen this index, just check/add shards
-                    TreeMap<Integer, Long> checkpoints = checkpointsByIndex.get(indexName);
-                    if (checkpoints.containsKey(shard.getShardRouting().getId())) {
-                        // there is already a checkpoint entry for this index/shard combination, check if they match
-                        if (checkpoints.get(shard.getShardRouting().getId()) != globalCheckpoint) {
-                            throw new CheckpointException("Global checkpoints mismatch for index [" + indexName + "] between shards of id ["
-                                    + shard.getShardRouting().getId() + "]");
-                        }
-                    } else {
-                        // 1st time we see this shard for this index, add the entry for the shard
-                        checkpoints.put(shard.getShardRouting().getId(), globalCheckpoint);
-                    }
-                } else {
-                    // 1st time we see this index, create an entry for the index and add the shard checkpoint
-                    checkpointsByIndex.put(indexName, new TreeMap<>());
-                    checkpointsByIndex.get(indexName).put(shard.getShardRouting().getId(), globalCheckpoint);
-                }
-            }
-        }
-
-        // checkpoint extraction is done in 2 steps:
-        // 1. GetIndexRequest to retrieve the indices the user has access to
-        // 2. IndicesStatsRequest to retrieve stats about indices
-        // between 1 and 2 indices could get deleted or created
-        if (logger.isDebugEnabled()) {
-            Set<String> userIndicesClone = new HashSet<>(userIndices);
-
-            userIndicesClone.removeAll(checkpointsByIndex.keySet());
-            if (userIndicesClone.isEmpty() == false) {
-                logger.debug("Original set of user indices contained more indexes [{}]", userIndicesClone);
-            }
-        }
-
-        // create the final structure
-        Map<String, long[]> checkpointsByIndexReduced = new TreeMap<>();
-
-        checkpointsByIndex.forEach((indexName, checkpoints) -> {
-            checkpointsByIndexReduced.put(indexName, checkpoints.values().stream().mapToLong(l -> l).toArray());
-        });
-
-        return checkpointsByIndexReduced;
-    }
-
 }

+ 314 - 0
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DefaultCheckpointProvider.java

@@ -0,0 +1,314 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.dataframe.checkpoint;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.apache.logging.log4j.util.Supplier;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.indices.get.GetIndexAction;
+import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
+import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction;
+import org.elasticsearch.action.admin.indices.stats.IndicesStatsRequest;
+import org.elasticsearch.action.admin.indices.stats.ShardStats;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerPosition;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointStats;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress;
+import org.elasticsearch.xpack.core.indexing.IndexerState;
+import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+public class DefaultCheckpointProvider implements CheckpointProvider {
+
+    /**
+     * Builder for collecting checkpointing information for the purpose of _stats
+     */
+    private static class DataFrameTransformCheckpointingInfoBuilder {
+        private IndexerState nextCheckpointIndexerState;
+        private DataFrameIndexerPosition nextCheckpointPosition;
+        private DataFrameTransformProgress nextCheckpointProgress;
+        private DataFrameTransformCheckpoint lastCheckpoint;
+        private DataFrameTransformCheckpoint nextCheckpoint;
+        private DataFrameTransformCheckpoint sourceCheckpoint;
+
+        DataFrameTransformCheckpointingInfoBuilder() {
+        }
+
+        DataFrameTransformCheckpointingInfo build() {
+            if (lastCheckpoint == null) {
+                lastCheckpoint = DataFrameTransformCheckpoint.EMPTY;
+            }
+            if (nextCheckpoint == null) {
+                nextCheckpoint = DataFrameTransformCheckpoint.EMPTY;
+            }
+            if (sourceCheckpoint == null) {
+                sourceCheckpoint = DataFrameTransformCheckpoint.EMPTY;
+            }
+
+            // checkpointstats requires a non-negative checkpoint number
+            long lastCheckpointNumber = lastCheckpoint.getCheckpoint() > 0 ? lastCheckpoint.getCheckpoint() : 0;
+            long nextCheckpointNumber = nextCheckpoint.getCheckpoint() > 0 ? nextCheckpoint.getCheckpoint() : 0;
+
+            return new DataFrameTransformCheckpointingInfo(
+                new DataFrameTransformCheckpointStats(lastCheckpointNumber, null, null, null,
+                    lastCheckpoint.getTimestamp(), lastCheckpoint.getTimeUpperBound()),
+                new DataFrameTransformCheckpointStats(nextCheckpointNumber, nextCheckpointIndexerState, nextCheckpointPosition,
+                    nextCheckpointProgress, nextCheckpoint.getTimestamp(), nextCheckpoint.getTimeUpperBound()),
+                DataFrameTransformCheckpoint.getBehind(lastCheckpoint, sourceCheckpoint));
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setLastCheckpoint(DataFrameTransformCheckpoint lastCheckpoint) {
+            this.lastCheckpoint = lastCheckpoint;
+            return this;
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setNextCheckpoint(DataFrameTransformCheckpoint nextCheckpoint) {
+            this.nextCheckpoint = nextCheckpoint;
+            return this;
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setSourceCheckpoint(DataFrameTransformCheckpoint sourceCheckpoint) {
+            this.sourceCheckpoint = sourceCheckpoint;
+            return this;
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setNextCheckpointProgress(DataFrameTransformProgress nextCheckpointProgress) {
+            this.nextCheckpointProgress = nextCheckpointProgress;
+            return this;
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setNextCheckpointPosition(DataFrameIndexerPosition nextCheckpointPosition) {
+            this.nextCheckpointPosition = nextCheckpointPosition;
+            return this;
+        }
+
+        public DataFrameTransformCheckpointingInfoBuilder setNextCheckpointIndexerState(IndexerState nextCheckpointIndexerState) {
+            this.nextCheckpointIndexerState = nextCheckpointIndexerState;
+            return this;
+        }
+
+    }
+
+    private static final Logger logger = LogManager.getLogger(DefaultCheckpointProvider.class);
+
+    protected final Client client;
+    protected final DataFrameTransformsConfigManager dataFrameTransformsConfigManager;
+    protected final DataFrameTransformConfig transformConfig;
+
+    public DefaultCheckpointProvider(final Client client,
+                                     final DataFrameTransformsConfigManager dataFrameTransformsConfigManager,
+                                     final DataFrameTransformConfig transformConfig) {
+        this.client = client;
+        this.dataFrameTransformsConfigManager = dataFrameTransformsConfigManager;
+        this.transformConfig = transformConfig;
+    }
+
+    @Override
+    public void sourceHasChanged(final DataFrameTransformCheckpoint lastCheckpoint, final ActionListener<Boolean> listener) {
+        listener.onResponse(false);
+    }
+
+    @Override
+    public void createNextCheckpoint(final DataFrameTransformCheckpoint lastCheckpoint,
+                              final ActionListener<DataFrameTransformCheckpoint> listener) {
+        final long timestamp = System.currentTimeMillis();
+        final long checkpoint = lastCheckpoint != null ? lastCheckpoint.getCheckpoint() + 1 : 1;
+
+        getIndexCheckpoints(ActionListener.wrap(checkpointsByIndex -> {
+            listener.onResponse(new DataFrameTransformCheckpoint(transformConfig.getId(), timestamp, checkpoint, checkpointsByIndex, 0L));
+        }, listener::onFailure));
+    }
+
+    protected void getIndexCheckpoints (ActionListener<Map<String, long[]>> listener) {
+     // 1st get index to see the indexes the user has access to
+        GetIndexRequest getIndexRequest = new GetIndexRequest()
+            .indices(transformConfig.getSource().getIndex())
+            .features(new GetIndexRequest.Feature[0])
+            .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
+
+        ClientHelper.executeWithHeadersAsync(transformConfig.getHeaders(), ClientHelper.DATA_FRAME_ORIGIN, client, GetIndexAction.INSTANCE,
+                getIndexRequest, ActionListener.wrap(getIndexResponse -> {
+                    Set<String> userIndices = new HashSet<>(Arrays.asList(getIndexResponse.getIndices()));
+                    // 2nd get stats request
+                    ClientHelper.executeAsyncWithOrigin(client,
+                        ClientHelper.DATA_FRAME_ORIGIN,
+                        IndicesStatsAction.INSTANCE,
+                        new IndicesStatsRequest()
+                            .indices(transformConfig.getSource().getIndex())
+                            .clear()
+                            .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN),
+                        ActionListener.wrap(
+                            response -> {
+                                if (response.getFailedShards() != 0) {
+                                    listener.onFailure(
+                                        new CheckpointException("Source has [" + response.getFailedShards() + "] failed shards"));
+                                    return;
+                                }
+
+                                listener.onResponse(extractIndexCheckPoints(response.getShards(), userIndices));
+                            },
+                            e-> listener.onFailure(new CheckpointException("Failed to create checkpoint", e))
+                        ));
+                },
+                e -> listener.onFailure(new CheckpointException("Failed to create checkpoint", e))
+            ));
+    }
+
+    static Map<String, long[]> extractIndexCheckPoints(ShardStats[] shards, Set<String> userIndices) {
+        Map<String, TreeMap<Integer, Long>> checkpointsByIndex = new TreeMap<>();
+
+        for (ShardStats shard : shards) {
+            String indexName = shard.getShardRouting().getIndexName();
+
+            if (userIndices.contains(indexName)) {
+                // SeqNoStats could be `null`, assume the global checkpoint to be 0 in this case
+                long globalCheckpoint = shard.getSeqNoStats() == null ? 0 : shard.getSeqNoStats().getGlobalCheckpoint();
+                if (checkpointsByIndex.containsKey(indexName)) {
+                    // we have already seen this index, just check/add shards
+                    TreeMap<Integer, Long> checkpoints = checkpointsByIndex.get(indexName);
+                    if (checkpoints.containsKey(shard.getShardRouting().getId())) {
+                        // there is already a checkpoint entry for this index/shard combination, check if they match
+                        if (checkpoints.get(shard.getShardRouting().getId()) != globalCheckpoint) {
+                            throw new CheckpointException("Global checkpoints mismatch for index [" + indexName + "] between shards of id ["
+                                    + shard.getShardRouting().getId() + "]");
+                        }
+                    } else {
+                        // 1st time we see this shard for this index, add the entry for the shard
+                        checkpoints.put(shard.getShardRouting().getId(), globalCheckpoint);
+                    }
+                } else {
+                    // 1st time we see this index, create an entry for the index and add the shard checkpoint
+                    checkpointsByIndex.put(indexName, new TreeMap<>());
+                    checkpointsByIndex.get(indexName).put(shard.getShardRouting().getId(), globalCheckpoint);
+                }
+            }
+        }
+
+        // checkpoint extraction is done in 2 steps:
+        // 1. GetIndexRequest to retrieve the indices the user has access to
+        // 2. IndicesStatsRequest to retrieve stats about indices
+        // between 1 and 2 indices could get deleted or created
+        if (logger.isDebugEnabled()) {
+            Set<String> userIndicesClone = new HashSet<>(userIndices);
+
+            userIndicesClone.removeAll(checkpointsByIndex.keySet());
+            if (userIndicesClone.isEmpty() == false) {
+                logger.debug("Original set of user indices contained more indexes [{}]",
+                        userIndicesClone);
+            }
+        }
+
+        // create the final structure
+        Map<String, long[]> checkpointsByIndexReduced = new TreeMap<>();
+
+        checkpointsByIndex.forEach((indexName, checkpoints) -> {
+            checkpointsByIndexReduced.put(indexName, checkpoints.values().stream().mapToLong(l -> l).toArray());
+        });
+
+        return checkpointsByIndexReduced;
+    }
+
+    @Override
+    public void getCheckpointingInfo(DataFrameTransformCheckpoint lastCheckpoint,
+                                   DataFrameTransformCheckpoint nextCheckpoint,
+                                   IndexerState nextCheckpointIndexerState,
+                                   DataFrameIndexerPosition nextCheckpointPosition,
+                                   DataFrameTransformProgress nextCheckpointProgress,
+                                   ActionListener<DataFrameTransformCheckpointingInfo> listener) {
+
+        DataFrameTransformCheckpointingInfoBuilder checkpointingInfoBuilder = new DataFrameTransformCheckpointingInfoBuilder();
+
+        checkpointingInfoBuilder.setLastCheckpoint(lastCheckpoint)
+            .setNextCheckpoint(nextCheckpoint)
+            .setNextCheckpointIndexerState(nextCheckpointIndexerState)
+            .setNextCheckpointPosition(nextCheckpointPosition)
+            .setNextCheckpointProgress(nextCheckpointProgress);
+
+        long timestamp = System.currentTimeMillis();
+
+        getIndexCheckpoints(ActionListener.wrap(checkpointsByIndex -> {
+            checkpointingInfoBuilder.setSourceCheckpoint(
+                    new DataFrameTransformCheckpoint(transformConfig.getId(), timestamp, -1L, checkpointsByIndex, 0L));
+            listener.onResponse(checkpointingInfoBuilder.build());
+        }, listener::onFailure));
+    }
+
+    @Override
+    public void getCheckpointingInfo(long lastCheckpointNumber, IndexerState nextCheckpointIndexerState,
+            DataFrameIndexerPosition nextCheckpointPosition, DataFrameTransformProgress nextCheckpointProgress,
+            ActionListener<DataFrameTransformCheckpointingInfo> listener) {
+
+        DataFrameTransformCheckpointingInfoBuilder checkpointingInfoBuilder = new DataFrameTransformCheckpointingInfoBuilder();
+
+        checkpointingInfoBuilder.setNextCheckpointIndexerState(nextCheckpointIndexerState)
+            .setNextCheckpointPosition(nextCheckpointPosition)
+            .setNextCheckpointProgress(nextCheckpointProgress);
+
+        long timestamp = System.currentTimeMillis();
+
+        // <3> got the source checkpoint, notify the user
+        ActionListener<Map<String, long[]>> checkpointsByIndexListener = ActionListener.wrap(
+                checkpointsByIndex -> {
+                    checkpointingInfoBuilder.setSourceCheckpoint(
+                        new DataFrameTransformCheckpoint(transformConfig.getId(), timestamp, -1L, checkpointsByIndex, 0L));
+                    listener.onResponse(checkpointingInfoBuilder.build());
+                },
+                e -> {
+                    logger.debug((Supplier<?>) () -> new ParameterizedMessage(
+                            "Failed to retrieve source checkpoint for data frame [{}]", transformConfig.getId()), e);
+                    listener.onFailure(new CheckpointException("Failure during source checkpoint info retrieval", e));
+                }
+            );
+
+        // <2> got the next checkpoint, get the source checkpoint
+        ActionListener<DataFrameTransformCheckpoint> nextCheckpointListener = ActionListener.wrap(
+                nextCheckpointObj -> {
+                    checkpointingInfoBuilder.setNextCheckpoint(nextCheckpointObj);
+                    getIndexCheckpoints(checkpointsByIndexListener);
+                },
+                e -> {
+                    logger.debug((Supplier<?>) () -> new ParameterizedMessage(
+                            "Failed to retrieve next checkpoint [{}] for data frame [{}]", lastCheckpointNumber + 1,
+                            transformConfig.getId()), e);
+                    listener.onFailure(new CheckpointException("Failure during next checkpoint info retrieval", e));
+                }
+            );
+
+        // <1> got last checkpoint, get the next checkpoint
+        ActionListener<DataFrameTransformCheckpoint> lastCheckpointListener = ActionListener.wrap(
+            lastCheckpointObj -> {
+                checkpointingInfoBuilder.lastCheckpoint = lastCheckpointObj;
+                dataFrameTransformsConfigManager.getTransformCheckpoint(transformConfig.getId(), lastCheckpointNumber + 1,
+                        nextCheckpointListener);
+            },
+            e -> {
+                logger.debug((Supplier<?>) () -> new ParameterizedMessage(
+                        "Failed to retrieve last checkpoint [{}] for data frame [{}]", lastCheckpointNumber,
+                        transformConfig.getId()), e);
+                listener.onFailure(new CheckpointException("Failure during last checkpoint info retrieval", e));
+            }
+        );
+
+        if (lastCheckpointNumber != 0) {
+            dataFrameTransformsConfigManager.getTransformCheckpoint(transformConfig.getId(), lastCheckpointNumber, lastCheckpointListener);
+        } else {
+            getIndexCheckpoints(checkpointsByIndexListener);
+        }
+    }
+}

+ 90 - 0
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/TimeBasedCheckpointProvider.java

@@ -0,0 +1,90 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.dataframe.checkpoint;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
+import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig;
+import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
+
+public class TimeBasedCheckpointProvider extends DefaultCheckpointProvider {
+
+    private static final Logger logger = LogManager.getLogger(TimeBasedCheckpointProvider.class);
+
+    private final TimeSyncConfig timeSyncConfig;
+
+    TimeBasedCheckpointProvider(final Client client,
+                                final DataFrameTransformsConfigManager dataFrameTransformsConfigManager,
+                                final DataFrameTransformConfig transformConfig) {
+        super(client, dataFrameTransformsConfigManager, transformConfig);
+        timeSyncConfig = (TimeSyncConfig) transformConfig.getSyncConfig();
+    }
+
+    @Override
+    public void sourceHasChanged(DataFrameTransformCheckpoint lastCheckpoint,
+            ActionListener<Boolean> listener) {
+
+        final long timestamp = getTime();
+
+        SearchRequest searchRequest = new SearchRequest(transformConfig.getSource().getIndex())
+                .allowPartialSearchResults(false)
+                .indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
+                .size(0)
+                // we only want to know if there is at least 1 new document
+                .trackTotalHitsUpTo(1);
+
+        QueryBuilder queryBuilder = transformConfig.getSource().getQueryConfig().getQuery();
+        BoolQueryBuilder filteredQuery = new BoolQueryBuilder().
+                filter(queryBuilder).
+                filter(new RangeQueryBuilder(timeSyncConfig.getField()).
+                        gte(lastCheckpoint.getTimeUpperBound()).
+                        lt(timestamp - timeSyncConfig.getDelay().millis()).format("epoch_millis"));
+
+        sourceBuilder.query(filteredQuery);
+        searchRequest.source(sourceBuilder);
+
+        logger.trace("query for changes based on time: {}", sourceBuilder);
+
+        ClientHelper.executeWithHeadersAsync(transformConfig.getHeaders(), ClientHelper.DATA_FRAME_ORIGIN, client, SearchAction.INSTANCE,
+                searchRequest, ActionListener.wrap(r -> {
+                    listener.onResponse(r.getHits().getTotalHits().value > 0L);
+                }, listener::onFailure));
+    }
+
+    @Override
+    public void createNextCheckpoint(final DataFrameTransformCheckpoint lastCheckpoint,
+            final ActionListener<DataFrameTransformCheckpoint> listener) {
+        final long timestamp = getTime();
+        final long checkpoint = lastCheckpoint != null ? lastCheckpoint.getCheckpoint() + 1 : 1;
+
+        // for time based synchronization
+        long timeUpperBound = timestamp - timeSyncConfig.getDelay().millis();
+
+        getIndexCheckpoints(ActionListener.wrap(checkpointsByIndex -> {
+            listener.onResponse(
+                    new DataFrameTransformCheckpoint(transformConfig.getId(), timestamp, checkpoint, checkpointsByIndex, timeUpperBound));
+        }, listener::onFailure));
+    }
+
+    // for the purpose of testing
+    long getTime() {
+        return System.currentTimeMillis();
+    }
+}

+ 8 - 0
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java

@@ -131,6 +131,14 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<DataFrameInd
         return progress;
     }
 
+    public DataFrameTransformCheckpoint getLastCheckpoint() {
+        return lastCheckpoint;
+    }
+
+    public DataFrameTransformCheckpoint getNextCheckpoint() {
+        return nextCheckpoint;
+    }
+
     /**
      * Request a checkpoint
      */

+ 48 - 20
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerPositio
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint;
+import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress;
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState;
@@ -44,6 +45,7 @@ import org.elasticsearch.xpack.core.dataframe.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.indexing.IndexerState;
 import org.elasticsearch.xpack.core.scheduler.SchedulerEngine;
 import org.elasticsearch.xpack.core.scheduler.SchedulerEngine.Event;
+import org.elasticsearch.xpack.dataframe.checkpoint.CheckpointProvider;
 import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpointService;
 import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
@@ -176,6 +178,36 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
         return currentCheckpoint.get();
     }
 
+    public void getCheckpointingInfo(DataFrameTransformsCheckpointService transformsCheckpointService,
+            ActionListener<DataFrameTransformCheckpointingInfo> listener) {
+        ClientDataFrameIndexer indexer = getIndexer();
+        if (indexer == null) {
+            transformsCheckpointService.getCheckpointingInfo(
+                    transform.getId(),
+                    currentCheckpoint.get(),
+                    initialIndexerState,
+                    initialPosition,
+                    null,
+                    listener);
+            return;
+        }
+        indexer.getCheckpointProvider().getCheckpointingInfo(
+                indexer.getLastCheckpoint(),
+                indexer.getNextCheckpoint(),
+                indexer.getState(),
+                indexer.getPosition(),
+                indexer.getProgress(),
+                listener);
+    }
+
+    public DataFrameTransformCheckpoint getLastCheckpoint() {
+        return getIndexer().getLastCheckpoint();
+    }
+
+    public DataFrameTransformCheckpoint getNextCheckpoint() {
+        return getIndexer().getNextCheckpoint();
+    }
+
     /**
      * Get the in-progress checkpoint
      *
@@ -429,9 +461,11 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
         }
 
         ClientDataFrameIndexer build(DataFrameTransformTask parentTask) {
+            CheckpointProvider checkpointProvider = transformsCheckpointService.getCheckpointProvider(transformConfig);
+
             return new ClientDataFrameIndexer(this.transformId,
                 this.transformsConfigManager,
-                this.transformsCheckpointService,
+                checkpointProvider,
                 new AtomicReference<>(this.indexerState),
                 this.initialPosition,
                 this.client,
@@ -521,7 +555,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
         private long logCount = 0;
         private final Client client;
         private final DataFrameTransformsConfigManager transformsConfigManager;
-        private final DataFrameTransformsCheckpointService transformsCheckpointService;
+        private final CheckpointProvider checkpointProvider;
         private final String transformId;
         private final DataFrameTransformTask transformTask;
         private final AtomicInteger failureCount;
@@ -531,7 +565,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
 
         ClientDataFrameIndexer(String transformId,
                                DataFrameTransformsConfigManager transformsConfigManager,
-                               DataFrameTransformsCheckpointService transformsCheckpointService,
+                               CheckpointProvider checkpointProvider,
                                AtomicReference<IndexerState> initialState,
                                DataFrameIndexerPosition initialPosition,
                                Client client,
@@ -557,8 +591,8 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 nextCheckpoint);
             this.transformId = ExceptionsHelper.requireNonNull(transformId, "transformId");
             this.transformsConfigManager = ExceptionsHelper.requireNonNull(transformsConfigManager, "transformsConfigManager");
-            this.transformsCheckpointService = ExceptionsHelper.requireNonNull(transformsCheckpointService,
-                "transformsCheckpointService");
+            this.checkpointProvider = ExceptionsHelper.requireNonNull(checkpointProvider, "checkpointProvider");
+
             this.client = ExceptionsHelper.requireNonNull(client, "client");
             this.transformTask = parentTask;
             this.failureCount = new AtomicInteger(0);
@@ -595,6 +629,10 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
             return transformId;
         }
 
+        public CheckpointProvider getCheckpointProvider() {
+            return checkpointProvider;
+        }
+
         @Override
         public synchronized boolean maybeTriggerAsyncJob(long now) {
             if (transformTask.taskState.get() == DataFrameTransformTaskState.FAILED) {
@@ -742,7 +780,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
                 // super.onFinish() fortunately ignores the listener
                 super.onFinish(listener);
                 long checkpoint = transformTask.currentCheckpoint.getAndIncrement();
-                lastCheckpoint = nextCheckpoint;
+                lastCheckpoint = getNextCheckpoint();
                 nextCheckpoint = null;
                 // Reset our failure count as we have finished and may start again with a new checkpoint
                 failureCount.set(0);
@@ -804,9 +842,7 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
 
         @Override
         protected void createCheckpoint(ActionListener<DataFrameTransformCheckpoint> listener) {
-            transformsCheckpointService.getCheckpoint(transformConfig,
-                transformTask.currentCheckpoint.get() + 1,
-                ActionListener.wrap(
+            checkpointProvider.createNextCheckpoint(getLastCheckpoint(), ActionListener.wrap(
                     checkpoint -> transformsConfigManager.putTransformCheckpoint(checkpoint,
                         ActionListener.wrap(
                             putCheckPointResponse -> listener.onResponse(checkpoint),
@@ -826,18 +862,10 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
             }
 
             CountDownLatch latch = new CountDownLatch(1);
-
             SetOnce<Boolean> changed = new SetOnce<>();
-            transformsCheckpointService.getCheckpoint(transformConfig, new LatchedActionListener<>(ActionListener.wrap(
-                    cp -> {
-                        long behind = DataFrameTransformCheckpoint.getBehind(lastCheckpoint, cp);
-                        if (behind > 0) {
-                            logger.debug("Detected changes, dest is {} operations behind the source", behind);
-                            changed.set(true);
-                        } else {
-                            changed.set(false);
-                        }
-                    }, e -> {
+
+            checkpointProvider.sourceHasChanged(getLastCheckpoint(),
+                    new LatchedActionListener<>(ActionListener.wrap(changed::set, e -> {
                         changed.set(false);
                         logger.warn(
                                 "Failed to detect changes for data frame transform [" + transformId + "], skipping update till next check",

+ 4 - 4
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformCheckpointServiceNodeTests.java

@@ -6,10 +6,10 @@
 
 package org.elasticsearch.xpack.dataframe.checkpoint;
 
-import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
 import org.elasticsearch.action.admin.indices.get.GetIndexResponse;
 import org.elasticsearch.action.admin.indices.stats.CommonStats;
@@ -203,7 +203,7 @@ public class DataFrameTransformCheckpointServiceNodeTests extends DataFrameSingl
                 30L);
 
         assertAsync(listener ->
-                transformsCheckpointService.getCheckpointStats(transformId, 1, 2, IndexerState.STARTED, position, progress, listener),
+                transformsCheckpointService.getCheckpointingInfo(transformId, 1, IndexerState.STARTED, position, progress, listener),
             checkpointInfo, null, null);
 
         mockClientForCheckpointing.setShardStats(createShardStats(createCheckPointMap(transformId, 10, 50, 33)));
@@ -212,7 +212,7 @@ public class DataFrameTransformCheckpointServiceNodeTests extends DataFrameSingl
                 new DataFrameTransformCheckpointStats(2, IndexerState.INDEXING, position, progress, timestamp + 100L, 0L),
                 63L);
         assertAsync(listener ->
-                transformsCheckpointService.getCheckpointStats(transformId, 1, 2, IndexerState.INDEXING, position, progress, listener),
+                transformsCheckpointService.getCheckpointingInfo(transformId, 1, IndexerState.INDEXING, position, progress, listener),
             checkpointInfo, null, null);
 
         // same as current
@@ -222,7 +222,7 @@ public class DataFrameTransformCheckpointServiceNodeTests extends DataFrameSingl
                 new DataFrameTransformCheckpointStats(2, IndexerState.STOPPING, position, progress, timestamp + 100L, 0L),
                 0L);
         assertAsync(listener ->
-                transformsCheckpointService.getCheckpointStats(transformId, 1, 2, IndexerState.STOPPING, position, progress, listener),
+                transformsCheckpointService.getCheckpointingInfo(transformId, 1, IndexerState.STOPPING, position, progress, listener),
             checkpointInfo, null, null);
     }
 

+ 4 - 4
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointServiceTests.java

@@ -53,7 +53,7 @@ public class DataFrameTransformsCheckpointServiceTests extends ESTestCase {
 
         ShardStats[] shardStatsArray = createRandomShardStats(expectedCheckpoints, indices, false, false, false);
 
-        Map<String, long[]> checkpoints = DataFrameTransformsCheckpointService.extractIndexCheckPoints(shardStatsArray, indices);
+        Map<String, long[]> checkpoints = DefaultCheckpointProvider.extractIndexCheckPoints(shardStatsArray, indices);
 
         assertEquals(expectedCheckpoints.size(), checkpoints.size());
         assertEquals(expectedCheckpoints.keySet(), checkpoints.keySet());
@@ -70,7 +70,7 @@ public class DataFrameTransformsCheckpointServiceTests extends ESTestCase {
 
         ShardStats[] shardStatsArray = createRandomShardStats(expectedCheckpoints, indices, false, false, true);
 
-        Map<String, long[]> checkpoints = DataFrameTransformsCheckpointService.extractIndexCheckPoints(shardStatsArray, indices);
+        Map<String, long[]> checkpoints = DefaultCheckpointProvider.extractIndexCheckPoints(shardStatsArray, indices);
 
         assertEquals(expectedCheckpoints.size(), checkpoints.size());
         assertEquals(expectedCheckpoints.keySet(), checkpoints.keySet());
@@ -87,7 +87,7 @@ public class DataFrameTransformsCheckpointServiceTests extends ESTestCase {
 
         ShardStats[] shardStatsArray = createRandomShardStats(expectedCheckpoints, indices, true, false, false);
 
-        Map<String, long[]> checkpoints = DataFrameTransformsCheckpointService.extractIndexCheckPoints(shardStatsArray, indices);
+        Map<String, long[]> checkpoints = DefaultCheckpointProvider.extractIndexCheckPoints(shardStatsArray, indices);
 
         assertEquals(expectedCheckpoints.size(), checkpoints.size());
         assertEquals(expectedCheckpoints.keySet(), checkpoints.keySet());
@@ -106,7 +106,7 @@ public class DataFrameTransformsCheckpointServiceTests extends ESTestCase {
 
         // fail
         CheckpointException e = expectThrows(CheckpointException.class,
-                () -> DataFrameTransformsCheckpointService.extractIndexCheckPoints(shardStatsArray, indices));
+                () -> DefaultCheckpointProvider.extractIndexCheckPoints(shardStatsArray, indices));
 
         assertThat(e.getMessage(), containsString("Global checkpoints mismatch"));
     }

+ 2 - 2
x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/ClientDataFrameIndexerTests.java

@@ -16,7 +16,7 @@ import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheck
 import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig;
 import org.elasticsearch.xpack.core.indexing.IndexerState;
 import org.elasticsearch.xpack.core.scheduler.SchedulerEngine;
-import org.elasticsearch.xpack.dataframe.checkpoint.DataFrameTransformsCheckpointService;
+import org.elasticsearch.xpack.dataframe.checkpoint.CheckpointProvider;
 import org.elasticsearch.xpack.dataframe.notifications.DataFrameAuditor;
 import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
 
@@ -48,7 +48,7 @@ public class ClientDataFrameIndexerTests extends ESTestCase {
             Collections.emptyMap());
         DataFrameTransformTask.ClientDataFrameIndexer indexer = new DataFrameTransformTask.ClientDataFrameIndexer(randomAlphaOfLength(10),
             mock(DataFrameTransformsConfigManager.class),
-            mock(DataFrameTransformsCheckpointService.class),
+            mock(CheckpointProvider.class),
             new AtomicReference<>(IndexerState.STOPPED),
             null,
             mock(Client.class),