Browse Source

[ML][Data Frame] fix progress measurement for continuous transforms (#43838)

* [ML][Data Frame] fix progress measurement for continuous transforms

* Update DataFrameIndexer.java
Benjamin Trent 6 years ago
parent
commit
fcbf4aa88c

+ 84 - 0
x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameGetAndGetStatsIT.java

@@ -13,9 +13,11 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.time.Instant;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 import static org.hamcrest.Matchers.equalTo;
@@ -202,4 +204,86 @@ public class DataFrameGetAndGetStatsIT extends DataFrameRestTestCase {
             assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
         }
     }
+
+    @SuppressWarnings("unchecked")
+    public void testGetProgressResetWithContinuous() throws Exception {
+        String transformId = "pivot_progress_continuous";
+        String transformDest = transformId + "_idx";
+        String transformSrc = "reviews_cont_pivot_test";
+        createReviewsIndex(transformSrc);
+        final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId, null);
+        String config = "{ \"dest\": {\"index\":\"" + transformDest + "\"},"
+            + " \"source\": {\"index\":\"" + transformSrc + "\"},"
+            + " \"sync\": {\"time\":{\"field\": \"timestamp\", \"delay\": \"1s\"}},"
+            + " \"pivot\": {"
+            + "   \"group_by\": {"
+            + "     \"reviewer\": {"
+            + "       \"terms\": {"
+            + "         \"field\": \"user_id\""
+            + " } } },"
+            + "   \"aggregations\": {"
+            + "     \"avg_rating\": {"
+            + "       \"avg\": {"
+            + "         \"field\": \"stars\""
+            + " } } } }"
+            + "}";
+
+        createDataframeTransformRequest.setJsonEntity(config);
+
+        Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
+        assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
+        startAndWaitForContinuousTransform(transformId, transformDest, null);
+
+        Request getRequest = createRequestWithAuth("GET", DATAFRAME_ENDPOINT + transformId + "/_stats", null);
+        Map<String, Object> stats = entityAsMap(client().performRequest(getRequest));
+        List<Map<String, Object>> transformsStats = (List<Map<String, Object>>)XContentMapValues.extractValue("transforms", stats);
+        assertEquals(1, transformsStats.size());
+        // Verify that the transform's progress
+        for (Map<String, Object> transformStats : transformsStats) {
+            Map<String, Object> progress = (Map<String, Object>)XContentMapValues.extractValue("state.progress", transformStats);
+            assertThat("total_docs is not 1000", progress.get("total_docs"), equalTo(1000));
+            assertThat("docs_remaining is not 0", progress.get("docs_remaining"), equalTo(0));
+            assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
+        }
+
+        // add more docs to verify total_docs gets updated with continuous
+        int numDocs = 10;
+        final StringBuilder bulk = new StringBuilder();
+        long now = Instant.now().toEpochMilli() - 1_000;
+        for (int i = 0; i < numDocs; i++) {
+            bulk.append("{\"index\":{\"_index\":\"" + transformSrc + "\"}}\n")
+                .append("{\"user_id\":\"")
+                .append("user_")
+                // Doing only new users so that there is a deterministic number of docs for progress
+                .append(randomFrom(42, 47, 113))
+                .append("\",\"business_id\":\"")
+                .append("business_")
+                .append(10)
+                .append("\",\"stars\":")
+                .append(5)
+                .append(",\"timestamp\":")
+                .append(now)
+                .append("}\n");
+        }
+        bulk.append("\r\n");
+        final Request bulkRequest = new Request("POST", "/_bulk");
+        bulkRequest.addParameter("refresh", "true");
+        bulkRequest.setJsonEntity(bulk.toString());
+        client().performRequest(bulkRequest);
+
+        waitForDataFrameCheckpoint(transformId, 2L);
+
+        assertBusy(() -> {
+            Map<String, Object> statsResponse = entityAsMap(client().performRequest(getRequest));
+            List<Map<String, Object>> contStats = (List<Map<String, Object>>)XContentMapValues.extractValue("transforms", statsResponse);
+            assertEquals(1, contStats.size());
+            // add more docs to verify total_docs is the number of new docs added to the index
+            for (Map<String, Object> transformStats : contStats) {
+                Map<String, Object> progress = (Map<String, Object>)XContentMapValues.extractValue("state.progress", transformStats);
+                assertThat("total_docs is not 10", progress.get("total_docs"), equalTo(numDocs));
+                assertThat("docs_remaining is not 0", progress.get("docs_remaining"), equalTo(0));
+                assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
+            }
+        }, 60, TimeUnit.SECONDS);
+    }
 }

+ 7 - 3
x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java

@@ -135,7 +135,9 @@ public class DataFrameTransformProgressIT extends ESRestTestCase {
             null);
 
         final RestHighLevelClient restClient = new TestRestHighLevelClient();
-        SearchResponse response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
+        SearchResponse response = restClient.search(
+            TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
+            RequestOptions.DEFAULT);
 
         DataFrameTransformProgress progress =
             TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
@@ -156,7 +158,8 @@ public class DataFrameTransformProgressIT extends ESRestTestCase {
             pivotConfig,
             null);
 
-        response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
+        response = restClient.search(TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
+            RequestOptions.DEFAULT);
         progress = TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
 
         assertThat(progress.getTotalDocs(), equalTo(35L));
@@ -174,7 +177,8 @@ public class DataFrameTransformProgressIT extends ESRestTestCase {
             pivotConfig,
             null);
 
-        response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
+        response = restClient.search(TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
+            RequestOptions.DEFAULT);
         progress = TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
 
         assertThat(progress.getTotalDocs(), equalTo(0L));

+ 33 - 44
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java

@@ -117,29 +117,7 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
             if (pageSize == 0) {
                 pageSize = pivot.getInitialPageSize();
             }
-
-            // if run for the 1st time, create checkpoint
-            if (initialRun()) {
-                createCheckpoint(ActionListener.wrap(cp -> {
-                    DataFrameTransformCheckpoint oldCheckpoint = inProgressOrLastCheckpoint;
-
-                    if (oldCheckpoint.isEmpty()) {
-                        // this is the 1st run, accept the new in progress checkpoint and go on
-                        inProgressOrLastCheckpoint = cp;
-                        listener.onResponse(null);
-                    } else {
-                        logger.debug ("Getting changes from {} to {}", oldCheckpoint.getTimeUpperBound(), cp.getTimeUpperBound());
-
-                        getChangedBuckets(oldCheckpoint, cp, ActionListener.wrap(changedBuckets -> {
-                            inProgressOrLastCheckpoint = cp;
-                            this.changedBuckets = changedBuckets;
-                            listener.onResponse(null);
-                        }, listener::onFailure));
-                    }
-                }, listener::onFailure));
-            } else {
-                listener.onResponse(null);
-            }
+            listener.onResponse(null);
         } catch (Exception e) {
             listener.onFailure(e);
         }
@@ -151,8 +129,8 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
 
     @Override
     protected void onFinish(ActionListener<Void> listener) {
-        // reset the page size, so we do not memorize a low page size forever, the pagesize will be re-calculated on start
-        pageSize = 0;
+        // reset the page size, so we do not memorize a low page size forever
+        pageSize = pivot.getInitialPageSize();
         // reset the changed bucket to free memory
         changedBuckets = null;
     }
@@ -218,13 +196,7 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
         });
     }
 
-    @Override
-    protected SearchRequest buildSearchRequest() {
-        SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex());
-        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
-        sourceBuilder.aggregation(pivot.buildAggregation(getPosition(), pageSize));
-        sourceBuilder.size(0);
-
+    protected QueryBuilder buildFilterQuery() {
         QueryBuilder pivotQueryBuilder = getConfig().getSource().getQueryConfig().getQuery();
 
         DataFrameTransformConfig config = getConfig();
@@ -233,9 +205,9 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
                 throw new RuntimeException("in progress checkpoint not found");
             }
 
-            BoolQueryBuilder filteredQuery = new BoolQueryBuilder().
-                    filter(pivotQueryBuilder).
-                    filter(config.getSyncConfig().getRangeQuery(inProgressOrLastCheckpoint));
+            BoolQueryBuilder filteredQuery = new BoolQueryBuilder()
+                .filter(pivotQueryBuilder)
+                .filter(config.getSyncConfig().getRangeQuery(inProgressOrLastCheckpoint));
 
             if (changedBuckets != null && changedBuckets.isEmpty() == false) {
                 QueryBuilder pivotFilter = pivot.filterBuckets(changedBuckets);
@@ -245,11 +217,19 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
             }
 
             logger.trace("running filtered query: {}", filteredQuery);
-            sourceBuilder.query(filteredQuery);
+            return filteredQuery;
         } else {
-            sourceBuilder.query(pivotQueryBuilder);
+            return pivotQueryBuilder;
         }
+    }
 
+    @Override
+    protected SearchRequest buildSearchRequest() {
+        SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex());
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
+            .aggregation(pivot.buildAggregation(getPosition(), pageSize))
+            .size(0)
+            .query(buildFilterQuery());
         searchRequest.source(sourceBuilder);
         return searchRequest;
     }
@@ -292,15 +272,24 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
         return true;
     }
 
-    private void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint,
-            ActionListener<Map<String, Set<String>>> listener) {
-
+    protected void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint,
+                                     DataFrameTransformCheckpoint newCheckpoint,
+                                     ActionListener<Map<String, Set<String>>> listener) {
+
+        ActionListener<Map<String, Set<String>>> wrappedListener = ActionListener.wrap(
+            r -> {
+                this.inProgressOrLastCheckpoint = newCheckpoint;
+                this.changedBuckets = r;
+                listener.onResponse(r);
+            },
+            listener::onFailure
+        );
         // initialize the map of changed buckets, the map might be empty if source do not require/implement
         // changed bucket detection
         Map<String, Set<String>> keys = pivot.initialIncrementalBucketUpdateMap();
         if (keys.isEmpty()) {
             logger.trace("This data frame does not implement changed bucket detection, returning");
-            listener.onResponse(null);
+            wrappedListener.onResponse(null);
             return;
         }
 
@@ -324,17 +313,17 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer<Map<String,
             sourceBuilder.query(filteredQuery);
         } else {
             logger.trace("No sync configured");
-            listener.onResponse(null);
+            wrappedListener.onResponse(null);
             return;
         }
 
         searchRequest.source(sourceBuilder);
         searchRequest.allowPartialSearchResults(false);
 
-        collectChangedBuckets(searchRequest, changesAgg, keys, ActionListener.wrap(listener::onResponse, e -> {
+        collectChangedBuckets(searchRequest, changesAgg, keys, ActionListener.wrap(wrappedListener::onResponse, e -> {
             // fall back if bucket collection failed
             logger.error("Failed to retrieve changed buckets, fall back to complete retrieval", e);
-            listener.onResponse(null);
+            wrappedListener.onResponse(null);
         }));
     }
 

+ 28 - 9
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.dataframe.transforms.pivot.AggregationResultUtils
 
 import java.util.Arrays;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -523,17 +524,35 @@ public class DataFrameTransformTask extends AllocatedPersistentTask implements S
             // Since multiple checkpoints can be executed in the task while it is running on the same node, we need to gather
             // the progress here, and not in the executor.
             if (initialRun()) {
-                TransformProgressGatherer.getInitialProgress(this.client, getConfig(), ActionListener.wrap(
-                    newProgress -> {
-                        progress = newProgress;
-                        super.onStart(now, listener);
+                ActionListener<Map<String, Set<String>>> changedBucketsListener = ActionListener.wrap(
+                    r -> {
+                        TransformProgressGatherer.getInitialProgress(this.client, buildFilterQuery(), getConfig(), ActionListener.wrap(
+                            newProgress -> {
+                                logger.trace("[{}] reset the progress from [{}] to [{}]", transformId, progress, newProgress);
+                                progress = newProgress;
+                                super.onStart(now, listener);
+                            },
+                            failure -> {
+                                progress = null;
+                                logger.warn("Unable to load progress information for task [" + transformId + "]", failure);
+                                super.onStart(now, listener);
+                            }
+                        ));
                     },
-                    failure -> {
-                        progress = null;
-                        logger.warn("Unable to load progress information for task [" + transformId + "]", failure);
-                        super.onStart(now, listener);
+                    listener::onFailure
+                );
+
+                createCheckpoint(ActionListener.wrap(cp -> {
+                    DataFrameTransformCheckpoint oldCheckpoint = inProgressOrLastCheckpoint;
+                    if (oldCheckpoint.isEmpty()) {
+                        // this is the 1st run, accept the new in progress checkpoint and go on
+                        inProgressOrLastCheckpoint = cp;
+                        changedBucketsListener.onResponse(null);
+                    } else {
+                        logger.debug ("Getting changes from {} to {}", oldCheckpoint.getTimeUpperBound(), cp.getTimeUpperBound());
+                        getChangedBuckets(oldCheckpoint, cp, changedBucketsListener);
                     }
-                ));
+                }, listener::onFailure));
             } else {
                 super.onStart(now, listener);
             }

+ 9 - 5
x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/TransformProgressGatherer.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ClientHelper;
@@ -28,13 +29,16 @@ public final class TransformProgressGatherer {
     /**
      * This gathers the total docs given the config and search
      *
-     * TODO: Support checkpointing logic to restrict the query
-     * @param progressListener The listener to alert on completion
+     * @param client ES Client to make queries
+     * @param filterQuery The adapted filter that can optionally take into account checkpoint information
+     * @param config The transform config containing headers, source, pivot, etc. information
+     * @param progressListener The listener to notify when progress object has been created
      */
     public static void getInitialProgress(Client client,
+                                          QueryBuilder filterQuery,
                                           DataFrameTransformConfig config,
                                           ActionListener<DataFrameTransformProgress> progressListener) {
-        SearchRequest request = getSearchRequest(config);
+        SearchRequest request = getSearchRequest(config, filterQuery);
 
         ActionListener<SearchResponse> searchResponseActionListener = ActionListener.wrap(
             searchResponse -> progressListener.onResponse(searchResponseToDataFrameTransformProgressFunction().apply(searchResponse)),
@@ -48,7 +52,7 @@ public final class TransformProgressGatherer {
             searchResponseActionListener);
     }
 
-    public static SearchRequest getSearchRequest(DataFrameTransformConfig config) {
+    public static SearchRequest getSearchRequest(DataFrameTransformConfig config, QueryBuilder filteredQuery) {
         SearchRequest request = new SearchRequest(config.getSource().getIndex());
         request.allowPartialSearchResults(false);
         BoolQueryBuilder existsClauses = QueryBuilders.boolQuery();
@@ -63,7 +67,7 @@ public final class TransformProgressGatherer {
             .size(0)
             .trackTotalHits(true)
             .query(QueryBuilders.boolQuery()
-                .filter(config.getSource().getQueryConfig().getQuery())
+                .filter(filteredQuery)
                 .filter(existsClauses)));
         return request;
     }