|
@@ -12,8 +12,17 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
|
|
+import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
|
|
@@ -26,12 +35,15 @@ import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.InOrder;
|
|
|
import org.mockito.Mockito;
|
|
|
|
|
|
+import java.time.Instant;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
+import java.util.Optional;
|
|
|
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.is;
|
|
|
import static org.mockito.Matchers.any;
|
|
|
import static org.mockito.Matchers.eq;
|
|
|
import static org.mockito.Mockito.doThrow;
|
|
@@ -85,8 +97,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|
|
public void testProcess_GivenEmptyResults() {
|
|
|
givenDataFrameRows(2);
|
|
|
givenProcessResults(Arrays.asList(
|
|
|
- new AnalyticsResult(null, null, null,null, null, null, null, null),
|
|
|
- new AnalyticsResult(null, null, null, null, null, null, null, null)));
|
|
|
+ AnalyticsResult.builder().build(),
|
|
|
+ AnalyticsResult.builder().build()));
|
|
|
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
|
|
|
resultProcessor.process(process);
|
|
@@ -101,8 +113,9 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|
|
givenDataFrameRows(2);
|
|
|
RowResults rowResults1 = mock(RowResults.class);
|
|
|
RowResults rowResults2 = mock(RowResults.class);
|
|
|
- givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null),
|
|
|
- new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
|
|
|
+ givenProcessResults(Arrays.asList(
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults1).build(),
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults2).build()));
|
|
|
AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
|
|
|
resultProcessor.process(process);
|
|
@@ -119,8 +132,9 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|
|
givenDataFrameRows(2);
|
|
|
RowResults rowResults1 = mock(RowResults.class);
|
|
|
RowResults rowResults2 = mock(RowResults.class);
|
|
|
- givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null),
|
|
|
- new AnalyticsResult(rowResults2, null, null, null, null, null, null, null)));
|
|
|
+ givenProcessResults(Arrays.asList(
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults1).build(),
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults2).build()));
|
|
|
|
|
|
doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
|
|
|
|
|
@@ -138,6 +152,146 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
|
|
|
assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
}
|
|
|
|
|
|
+ public void testCancel_GivenRowResults() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ RowResults rowResults1 = mock(RowResults.class);
|
|
|
+ RowResults rowResults2 = mock(RowResults.class);
|
|
|
+ givenProcessResults(Arrays.asList(
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults1).build(),
|
|
|
+ AnalyticsResult.builder().setRowResults(rowResults2).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenModelChunk() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ TrainedModelDefinitionChunk modelChunk = mock(TrainedModelDefinitionChunk.class);
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setTrainedModelDefinitionChunk(modelChunk).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenPhaseProgress() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ PhaseProgress phaseProgress = new PhaseProgress("analyzing", 18);
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setPhaseProgress(phaseProgress).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+
|
|
|
+ Optional<PhaseProgress> testPhaseProgress = statsHolder.getProgressTracker().report().stream()
|
|
|
+ .filter(p -> p.getPhase().equals(phaseProgress.getPhase()))
|
|
|
+ .findAny();
|
|
|
+ assertThat(testPhaseProgress.isPresent(), is(true));
|
|
|
+ assertThat(testPhaseProgress.get().getProgressPercent(), equalTo(18));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenMemoryUsage() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ MemoryUsage memoryUsage = new MemoryUsage(analyticsConfig.getId(), Instant.now(), 1000L, MemoryUsage.Status.HARD_LIMIT, null);
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setMemoryUsage(memoryUsage).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+
|
|
|
+ assertThat(statsHolder.getMemoryUsage(), equalTo(memoryUsage));
|
|
|
+ verify(statsPersister).persistWithRetry(eq(memoryUsage), any());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenOutlierDetectionStats() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ OutlierDetectionStats outlierDetectionStats = OutlierDetectionStatsTests.createRandom();
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setOutlierDetectionStats(outlierDetectionStats).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+
|
|
|
+ assertThat(statsHolder.getAnalysisStats(), equalTo(outlierDetectionStats));
|
|
|
+ verify(statsPersister).persistWithRetry(eq(outlierDetectionStats), any());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenClassificationStats() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ ClassificationStats classificationStats = ClassificationStatsTests.createRandom();
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setClassificationStats(classificationStats).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+
|
|
|
+ assertThat(statsHolder.getAnalysisStats(), equalTo(classificationStats));
|
|
|
+ verify(statsPersister).persistWithRetry(eq(classificationStats), any());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCancel_GivenRegressionStats() {
|
|
|
+ givenDataFrameRows(2);
|
|
|
+ RegressionStats regressionStats = RegressionStatsTests.createRandom();
|
|
|
+ givenProcessResults(Arrays.asList(AnalyticsResult.builder().setRegressionStats(regressionStats).build()));
|
|
|
+ AnalyticsResultProcessor resultProcessor = createResultProcessor();
|
|
|
+
|
|
|
+ resultProcessor.cancel();
|
|
|
+
|
|
|
+ resultProcessor.process(process);
|
|
|
+ resultProcessor.awaitForCompletion();
|
|
|
+
|
|
|
+ verify(dataFrameRowsJoiner).cancel();
|
|
|
+ verify(dataFrameRowsJoiner).close();
|
|
|
+ Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider);
|
|
|
+ assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0));
|
|
|
+
|
|
|
+ assertThat(statsHolder.getAnalysisStats(), equalTo(regressionStats));
|
|
|
+ verify(statsPersister).persistWithRetry(eq(regressionStats), any());
|
|
|
+ }
|
|
|
+
|
|
|
private void givenProcessResults(List<AnalyticsResult> results) {
|
|
|
when(process.readAnalyticsResults()).thenReturn(results.iterator());
|
|
|
}
|