Browse Source

[ML] Recover data frame extraction search from latest sort key (#61544)

If a search failure occurs during data frame extraction we catch
the error and retry once. However, the retried search is
identical to the first one. This means we will re-fetch any docs
that were already processed. This may result either in training
a model using duplicate data or, in the case of outlier detection,
in an error message that the process received more records than it
expected.

This commit fixes this issue by tracking the latest doc's sort key
and then using that in a range query in case we restart the search
due to a failure.
Dimitris Athanasiou 5 years ago
parent
commit
8fb18b6c88

+ 22 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

@@ -67,6 +67,7 @@ public class DataFrameDataExtractor {
     private final Client client;
     private final DataFrameDataExtractorContext context;
     private String scrollId;
+    private String lastSortKey;
     private boolean isCancelled;
     private boolean hasNext;
     private boolean searchHasShardFailure;
@@ -122,7 +123,9 @@ public class DataFrameDataExtractor {
         }
 
         Optional<List<Row>> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll());
-        if (!hits.isPresent()) {
+        if (hits.isPresent() && hits.get().isEmpty() == false) {
+            lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey();
+        } else {
             hasNext = false;
         }
         return hits;
@@ -135,6 +138,7 @@ public class DataFrameDataExtractor {
 
     private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
         try {
+
             // We've set allow_partial_search_results to false which means if something
             // goes wrong the request will throw.
             SearchResponse searchResponse = request.get();
@@ -165,8 +169,19 @@ public class DataFrameDataExtractor {
                 .setAllowPartialSearchResults(false)
                 .addSort(DestinationIndex.ID_COPY, SortOrder.ASC)
                 .setIndices(context.indices)
-                .setSize(context.scrollSize)
-                .setQuery(context.query);
+                .setSize(context.scrollSize);
+
+        if (lastSortKey == null) {
+            searchRequestBuilder.setQuery(context.query);
+        } else {
+            LOGGER.debug(() -> new ParameterizedMessage("[{}] Searching docs with [{}] greater than [{}]",
+                context.jobId, DestinationIndex.ID_COPY, lastSortKey));
+            QueryBuilder queryPlusLastSortKey = QueryBuilders.boolQuery()
+                .filter(context.query)
+                .filter(QueryBuilders.rangeQuery(DestinationIndex.ID_COPY).gt(lastSortKey));
+            searchRequestBuilder.setQuery(queryPlusLastSortKey);
+        }
+
         setFetchSource(searchRequestBuilder);
 
         for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
@@ -426,5 +441,9 @@ public class DataFrameDataExtractor {
         public int getChecksum() {
             return Arrays.hashCode(values);
         }
+
+        public String getSortKey() {
+            return (String) hit.getSortValues()[0];
+        }
     }
 }

+ 21 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java

@@ -58,6 +58,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Matchers.same;
@@ -77,6 +78,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
     private TrainTestSplitterFactory trainTestSplitterFactory;
     private ArgumentCaptor<ClearScrollRequest> capturedClearScrollRequests;
     private ActionFuture<ClearScrollResponse> clearScrollFuture;
+    private int searchHitCounter;
 
     @Before
     @SuppressWarnings("unchecked")
@@ -196,6 +198,13 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         List<String> capturedClearScrollRequests = getCapturedClearScrollIds();
         assertThat(capturedClearScrollRequests.size(), equalTo(1));
         assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId()));
+
+        // Notice we've done two searches here
+        assertThat(dataExtractor.capturedSearchRequests, hasSize(2));
+
+        // Assert the second search did not include a range query as the failure happened on the very first search
+        String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", "");
+        assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}"));
     }
 
     public void testErrorOnSearchTwiceLeadsToFailure() {
@@ -215,14 +224,14 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         TestExtractor dataExtractor = createExtractor(true, false);
 
         // Search will succeed
-        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1));
+        SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2));
         dataExtractor.setNextResponse(response1);
 
         // But the first continue scroll fails
         dataExtractor.setNextResponse(createResponseWithShardFailures());
 
         // The next one succeeds and we shall recover
-        SearchResponse response2 = createSearchResponse(Arrays.asList(1_2), Arrays.asList(2_2));
+        SearchResponse response2 = createSearchResponse(Arrays.asList(1_3), Arrays.asList(2_3));
         dataExtractor.setNextResponse(response2);
 
         // Last one
@@ -234,15 +243,16 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         // First batch expected as normally since we'll retry after the error
         Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
         assertThat(rows.isPresent(), is(true));
-        assertThat(rows.get().size(), equalTo(1));
+        assertThat(rows.get().size(), equalTo(2));
         assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"}));
+        assertThat(rows.get().get(1).getValues(), equalTo(new String[] {"12", "22"}));
         assertThat(dataExtractor.hasNext(), is(true));
 
         // We get second batch as we retried after the error
         rows = dataExtractor.next();
         assertThat(rows.isPresent(), is(true));
         assertThat(rows.get().size(), equalTo(1));
-        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"12", "22"}));
+        assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"13", "23"}));
         assertThat(dataExtractor.hasNext(), is(true));
 
         // Next batch should return empty
@@ -254,6 +264,12 @@ public class DataFrameDataExtractorTests extends ESTestCase {
         assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2));
         assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2));
 
+        // Assert the second search continued from the latest successfully processed doc
+        String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", "");
+        assertThat(searchRequest, containsString("\"query\":{\"bool\":{"));
+        assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}"));
+        assertThat(searchRequest, containsString("{\"range\":{\"ml__id_copy\":{\"from\":\"1\",\"to\":null,\"include_lower\":false"));
+
         // Check we cleared the scroll with the latest scroll id
         List<String> capturedClearScrollRequests = getCapturedClearScrollIds();
         assertThat(capturedClearScrollRequests.size(), equalTo(1));
@@ -583,6 +599,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
             addField(searchHitBuilder, "field_1", field1Values.get(i));
             addField(searchHitBuilder, "field_2", field2Values.get(i));
             searchHitBuilder.setSource("{\"field_1\":" + field1Values.get(i) + ",\"field_2\":" + field2Values.get(i) + "}");
+            searchHitBuilder.setStringSortValue(String.valueOf(searchHitCounter++));
             hits.add(searchHitBuilder.build());
         }
         SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1);

+ 6 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.test;
 
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 
 import java.util.Arrays;
@@ -41,6 +42,11 @@ public class SearchHitBuilder {
         return this;
     }
 
+    public SearchHitBuilder setStringSortValue(String sortValue) {
+        hit.sortValues(new String[] { sortValue }, new DocValueFormat[] { DocValueFormat.RAW });
+        return this;
+    }
+
     public SearchHit build() {
         return hit;
     }