Browse Source

[ML] Prepare parsing phase_progress from DFA process (#55580)

Data frame analytics process currently reports progress as
an integer `progress_percent`. We parse that and report it
from the _stats API as the progress of the `analyzing` phase.
However, we want to allow the DFA process to report progress
for more than one phase. This commit prepares for this by
parsing `phase_progress` from the process, an object that
contains the `phase` name plus the `progress_percent` for that
phase.
Dimitris Athanasiou 5 years ago
parent
commit
2d55592c7a

+ 10 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

@@ -32,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.core.security.user.XPackUser;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -164,11 +165,20 @@ public class AnalyticsResultProcessor {
         if (rowResults != null) {
             resultsJoiner.processRowResults(rowResults);
         }
+        PhaseProgress phaseProgress = result.getPhaseProgress();
+        if (phaseProgress != null) {
+            LOGGER.debug("[{}] progress for phase [{}] updated to [{}]", analytics.getId(), phaseProgress.getPhase(),
+                phaseProgress.getProgressPercent());
+            statsHolder.getProgressTracker().analyzingPercent.set(phaseProgress.getProgressPercent());
+        }
+
+        // TODO remove after process is writing out phase_progress
         Integer progressPercent = result.getProgressPercent();
         if (progressPercent != null) {
             LOGGER.debug("[{}] Analyzing progress updated to [{}]", analytics.getId(), progressPercent);
             statsHolder.getProgressTracker().analyzingPercent.set(progressPercent);
         }
+
         TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
         if (inferenceModelBuilder != null) {
             createAndIndexInferenceModel(inferenceModelBuilder);

+ 27 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java

@@ -11,11 +11,12 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
+import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -28,6 +29,7 @@ public class AnalyticsResult implements ToXContentObject {
 
     public static final ParseField TYPE = new ParseField("analytics_result");
 
+    private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
     private static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
     private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
     private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
@@ -38,16 +40,18 @@ public class AnalyticsResult implements ToXContentObject {
     public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
             a -> new AnalyticsResult(
                 (RowResults) a[0],
-                (Integer) a[1],
-                (TrainedModelDefinition.Builder) a[2],
-                (MemoryUsage) a[3],
-                (OutlierDetectionStats) a[4],
-                (ClassificationStats) a[5],
-                (RegressionStats) a[6]
+                (PhaseProgress) a[1],
+                (Integer) a[2],
+                (TrainedModelDefinition.Builder) a[3],
+                (MemoryUsage) a[4],
+                (OutlierDetectionStats) a[5],
+                (ClassificationStats) a[6],
+                (RegressionStats) a[7]
             ));
 
     static {
         PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
+        PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
         PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
         // TODO change back to STRICT_PARSER once native side is aligned
         PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
@@ -58,7 +62,11 @@ public class AnalyticsResult implements ToXContentObject {
     }
 
     private final RowResults rowResults;
+    private final PhaseProgress phaseProgress;
+
+    // TODO remove after process is writing out phase_progress
     private final Integer progressPercent;
+
     private final TrainedModelDefinition.Builder inferenceModelBuilder;
     private final TrainedModelDefinition inferenceModel;
     private final MemoryUsage memoryUsage;
@@ -67,6 +75,7 @@ public class AnalyticsResult implements ToXContentObject {
     private final RegressionStats regressionStats;
 
     public AnalyticsResult(@Nullable RowResults rowResults,
+                           @Nullable PhaseProgress phaseProgress,
                            @Nullable Integer progressPercent,
                            @Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
                            @Nullable MemoryUsage memoryUsage,
@@ -74,6 +83,7 @@ public class AnalyticsResult implements ToXContentObject {
                            @Nullable ClassificationStats classificationStats,
                            @Nullable RegressionStats regressionStats) {
         this.rowResults = rowResults;
+        this.phaseProgress = phaseProgress;
         this.progressPercent = progressPercent;
         this.inferenceModelBuilder = inferenceModelBuilder;
         this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
@@ -87,6 +97,10 @@ public class AnalyticsResult implements ToXContentObject {
         return rowResults;
     }
 
+    public PhaseProgress getPhaseProgress() {
+        return phaseProgress;
+    }
+
     public Integer getProgressPercent() {
         return progressPercent;
     }
@@ -117,6 +131,9 @@ public class AnalyticsResult implements ToXContentObject {
         if (rowResults != null) {
             builder.field(RowResults.TYPE.getPreferredName(), rowResults);
         }
+        if (phaseProgress != null) {
+            builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
+        }
         if (progressPercent != null) {
             builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
         }
@@ -152,6 +169,7 @@ public class AnalyticsResult implements ToXContentObject {
 
         AnalyticsResult that = (AnalyticsResult) other;
         return Objects.equals(rowResults, that.rowResults)
+            && Objects.equals(phaseProgress, that.phaseProgress)
             && Objects.equals(progressPercent, that.progressPercent)
             && Objects.equals(inferenceModel, that.inferenceModel)
             && Objects.equals(memoryUsage, that.memoryUsage)
@@ -162,7 +180,7 @@ public class AnalyticsResult implements ToXContentObject {
 
     @Override
     public int hashCode() {
-        return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
-            regressionStats);
+        return Objects.hash(rowResults, phaseProgress, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats,
+            classificationStats, regressionStats);
     }
 }

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

@@ -58,7 +58,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
     private static final String CONFIG_ID = "config-id";
     private static final int NUM_ROWS = 100;
     private static final int NUM_COLS = 4;
-    private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null);
+    private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null);
 
     private Client client;
     private DataFrameAnalyticsAuditor auditor;

+ 8 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java

@@ -105,8 +105,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
     public void testProcess_GivenEmptyResults() {
         givenDataFrameRows(2);
         givenProcessResults(Arrays.asList(
-            new AnalyticsResult(null, 50, null, null, null, null, null),
-            new AnalyticsResult(null, 100, null, null, null, null, null)));
+            new AnalyticsResult(null, null,50, null, null, null, null, null),
+            new AnalyticsResult(null, null, 100, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -121,8 +121,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         givenDataFrameRows(2);
         RowResults rowResults1 = mock(RowResults.class);
         RowResults rowResults2 = mock(RowResults.class);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
-            new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,50, null, null, null, null, null),
+            new AnalyticsResult(rowResults2, null, 100, null, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);
@@ -139,8 +139,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         givenDataFrameRows(2);
         RowResults rowResults1 = mock(RowResults.class);
         RowResults rowResults2 = mock(RowResults.class);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
-            new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,50, null, null, null, null, null),
+            new AnalyticsResult(rowResults2, null, 100, null, null, null, null, null)));
 
         doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
 
@@ -174,7 +174,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
         extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
         TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
         TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, inferenceModel, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);
 
         resultProcessor.process(process);
@@ -238,7 +238,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
 
         TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
         TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
-        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
+        givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, inferenceModel, null, null, null, null)));
         AnalyticsResultProcessor resultProcessor = createResultProcessor();
 
         resultProcessor.process(process);

+ 7 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStat
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
 import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.util.ArrayList;
@@ -41,6 +42,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
     @Override
     protected AnalyticsResult createTestInstance() {
         RowResults rowResults = null;
+        PhaseProgress phaseProgress = null;
         Integer progressPercent = null;
         TrainedModelDefinition.Builder inferenceModel = null;
         MemoryUsage memoryUsage = null;
@@ -50,6 +52,9 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
         if (randomBoolean()) {
             rowResults = RowResultsTests.createRandom();
         }
+        if (randomBoolean()) {
+            phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
+        }
         if (randomBoolean()) {
             progressPercent = randomIntBetween(0, 100);
         }
@@ -68,8 +73,8 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
         if (randomBoolean()) {
             regressionStats = RegressionStatsTests.createRandom();
         }
-        return new AnalyticsResult(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
-            regressionStats);
+        return new AnalyticsResult(rowResults, phaseProgress, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats,
+            classificationStats, regressionStats);
     }
 
     @Override