Browse Source

[ML] Only report complete writing_results progress after completion (#49551)

We depend on the number of data frame rows in order to report progress
for the writing of results, the last phase of a job run. However, results
include objects other than the data frame rows (e.g., progress objects and the inference model).

The problem this commit fixes is that upon receiving the last data frame row result
we would report that progress is complete even though we may still have more results
to process. If the job gets stopped for any reason at this point, we will not be able
to restart the job properly as we'll think that the job was completed.

This commit addresses this by limiting the max progress we can report for the
writing_results phase before the results processor completes to 98.
At the end, when the process is done we set the progress to 100.

The commit also improves failure capturing and reporting in the results processor.
Dimitris Athanasiou 5 years ago
parent
commit
7f943028b9

+ 37 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

@@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -37,6 +38,18 @@ public class AnalyticsResultProcessor {
 
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);
 
+    /**
+     * While we report progress as we read row results, there are other things we need to account
+     * for before reporting completion. There are result types whose number we cannot predict, such
+     * as progress objects and the inference model. Thus, we cap the progress we report at a maximum
+     * until we know we have completed processing all results.
+     *
+     * It is critical to ensure we do not report complete progress too soon as restarting a job
+     * uses the progress to determine which state to restart from. If we report full progress too soon
+     * we cannot restart a job as we will think the job was finished.
+     */
+    private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
+
     private final DataFrameAnalyticsConfig analytics;
     private final DataFrameRowsJoiner dataFrameRowsJoiner;
     private final ProgressTracker progressTracker;
@@ -68,7 +81,7 @@ public class AnalyticsResultProcessor {
             completionLatch.await();
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
-            LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e);
+            setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for results processor to complete", e));
         }
     }
 
@@ -91,27 +104,32 @@ public class AnalyticsResultProcessor {
                 processResult(result, resultsJoiner);
                 if (result.getRowResults() != null) {
                     processedRows++;
-                    progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
+                    updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
                 }
             }
-            if (isCancelled == false) {
-                // This means we completed successfully so we need to set the progress to 100.
-                // This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
-                progressTracker.writingResultsPercent.set(100);
-            }
         } catch (Exception e) {
             if (isCancelled) {
                 // No need to log error as it's due to stopping
             } else {
-                LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
-                failure = "error parsing data frame analytics output: [" + e.getMessage() + "]";
+                setAndReportFailure(e);
             }
         } finally {
+            if (isCancelled == false && failure == null) {
+                completeResultsProgress();
+            }
             completionLatch.countDown();
             process.consumeAndCloseOutputStream();
         }
     }
 
+    private void updateResultsProgress(int progress) {
+        progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
+    }
+
+    private void completeResultsProgress() {
+        progressTracker.writingResultsPercent.set(100);
+    }
+
     private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
         RowResults rowResults = result.getRowResults();
         if (rowResults != null) {
@@ -137,7 +155,7 @@ public class AnalyticsResultProcessor {
             }
         } catch (InterruptedException e) {
             Thread.currentThread().interrupt();
-            LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e);
+            setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
         }
     }
 
@@ -168,19 +186,22 @@ public class AnalyticsResultProcessor {
             aBoolean -> {
                 if (aBoolean == false) {
                     LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
+                    setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
                 } else {
                     LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
                     auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
                 }
             },
-            e -> {
-                LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(),
-                    trainedModelConfig.getModelId()), e);
-                auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId()
-                    + "]; error message [" + e.getMessage() + "]");
-            }
+            e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
+                trainedModelConfig.getModelId()))
         );
         trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
         return latch;
     }
+
+    private void setAndReportFailure(Exception e) {
+        LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
+        failure = "error processing results; " + e.getMessage();
+        auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
+    }
 }

+ 28 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java

@@ -39,9 +39,11 @@ import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasKey;
+import static org.hamcrest.Matchers.startsWith;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -117,6 +119,28 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
     }
 
+    public void testProcess_GivenDataFrameRowsJoinerFails() {
+        givenDataFrameRows(2);
+        RowResults rowResults1 = mock(RowResults.class);
+        RowResults rowResults2 = mock(RowResults.class);
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));
+
+        doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
+
+        AnalyticsResultProcessor resultProcessor = createResultProcessor();
+
+        resultProcessor.process(process);
+        resultProcessor.awaitForCompletion();
+
+        assertThat(resultProcessor.getFailure(), equalTo("error processing results; some failure"));
+
+        ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
+        verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
+        assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));
+
+        assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
+    }
+
     @SuppressWarnings("unchecked")
     public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
         givenDataFrameRows(0);
@@ -182,9 +206,11 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         // This test verifies the processor knows how to handle a failure on storing the model and completes normally
         ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
         verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
-        assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
-        assertThat(auditCaptor.getValue(), containsString("[some failure]"));
+        assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
         Mockito.verifyNoMoreInteractions(auditor);
+
+        assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
+        assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
     }
 
     private void givenProcessResults(List<AnalyticsResult> results) {