浏览代码

[ML] Refactor DFA custom processor to cross validation splitter (#53915)

While `CustomProcessor` is generic and allows for flexibility, there
are new requirements that make cross validation a concept that is hard
to abstract behind a custom processor. In particular, we would like to
add data_counts to the DFA job stats. Counting training vs. test
docs would be a useful statistic. We would also want to add a
different cross validation strategy for multiclass classification.

This commit renames custom processors to cross validation splitters,
which allow for those enhancements without cryptically doing
things as a side effect of the abstract custom processing.
Dimitris Athanasiou 5 年之前
父节点
当前提交
adcf25e093

+ 5 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

@@ -31,8 +31,8 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
-import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
-import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
+import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
+import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
 import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@@ -208,7 +208,8 @@ public class AnalyticsProcessManager {
     private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
     private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
                                DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {
                                DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {
 
 
-        CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
+        CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
+            .create(analysis);
 
 
         // The extra fields are for the doc hash and the control field (should be an empty string)
         // The extra fields are for the doc hash and the control field (should be an empty string)
         String[] record = new String[dataExtractor.getFieldNames().size() + 2];
         String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@@ -226,7 +227,7 @@ public class AnalyticsProcessManager {
                         String[] rowValues = row.getValues();
                         String[] rowValues = row.getValues();
                         System.arraycopy(rowValues, 0, record, 0, rowValues.length);
                         System.arraycopy(rowValues, 0, record, 0, rowValues.length);
                         record[record.length - 2] = String.valueOf(row.getChecksum());
                         record[record.length - 2] = String.valueOf(row.getChecksum());
-                        customProcessor.process(record);
+                        crossValidationSplitter.process(record);
                         process.writeRecord(record);
                         process.writeRecord(record);
                     }
                     }
                 }
                 }

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessor.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java

@@ -3,12 +3,12 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  * you may not use this file except in compliance with the Elastic License.
  */
  */
-package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
+package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
 
 
 /**
 /**
- * A processor to manipulate rows before writing them to the process
+ * Processes rows in order to split the dataset in training and test subsets
  */
  */
-public interface CustomProcessor {
+public interface CrossValidationSplitter {
 
 
     void process(String[] row);
     void process(String[] row);
 }
 }

+ 6 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java

@@ -3,7 +3,7 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  * you may not use this file except in compliance with the Elastic License.
  */
  */
-package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
+package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
 
 
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
@@ -12,23 +12,23 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import java.util.List;
 import java.util.List;
 import java.util.Objects;
 import java.util.Objects;
 
 
-public class CustomProcessorFactory {
+public class CrossValidationSplitterFactory {
 
 
     private final List<String> fieldNames;
     private final List<String> fieldNames;
 
 
-    public CustomProcessorFactory(List<String> fieldNames) {
+    public CrossValidationSplitterFactory(List<String> fieldNames) {
         this.fieldNames = Objects.requireNonNull(fieldNames);
         this.fieldNames = Objects.requireNonNull(fieldNames);
     }
     }
 
 
-    public CustomProcessor create(DataFrameAnalysis analysis) {
+    public CrossValidationSplitter create(DataFrameAnalysis analysis) {
         if (analysis instanceof Regression) {
         if (analysis instanceof Regression) {
             Regression regression = (Regression) analysis;
             Regression regression = (Regression) analysis;
-            return new DatasetSplittingCustomProcessor(
+            return new RandomCrossValidationSplitter(
                 fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
                 fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
         }
         }
         if (analysis instanceof Classification) {
         if (analysis instanceof Classification) {
             Classification classification = (Classification) analysis;
             Classification classification = (Classification) analysis;
-            return new DatasetSplittingCustomProcessor(
+            return new RandomCrossValidationSplitter(
                 fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
                 fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
         }
         }
         return row -> {};
         return row -> {};

+ 5 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java

@@ -3,7 +3,7 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  * you may not use this file except in compliance with the Elastic License.
  */
  */
-package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
+package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
 
 
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -12,19 +12,19 @@ import java.util.List;
 import java.util.Random;
 import java.util.Random;
 
 
 /**
 /**
- * A processor that randomly clears the dependent variable value
- * in order to split the dataset in training and validation data.
+ * A cross validation splitter that randomly clears the dependent variable value
+ * in order to split the dataset in training and test data.
  * This relies on the fact that when the dependent variable field
  * This relies on the fact that when the dependent variable field
  * is empty, then the row is not used for training but only to make predictions.
  * is empty, then the row is not used for training but only to make predictions.
  */
  */
-class DatasetSplittingCustomProcessor implements CustomProcessor {
+class RandomCrossValidationSplitter implements CrossValidationSplitter {
 
 
     private final int dependentVariableIndex;
     private final int dependentVariableIndex;
     private final double trainingPercent;
     private final double trainingPercent;
     private final Random random;
     private final Random random;
     private boolean isFirstRow = true;
     private boolean isFirstRow = true;
 
 
-    DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
+    RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
         this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
         this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
         this.trainingPercent = trainingPercent;
         this.trainingPercent = trainingPercent;
         this.random = new Random(randomizeSeed);
         this.random = new Random(randomizeSeed);

+ 13 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java

@@ -3,7 +3,7 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  * you may not use this file except in compliance with the Elastic License.
  */
  */
-package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
+package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
 
 
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -20,7 +20,7 @@ import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThan;
 
 
-public class DatasetSplittingCustomProcessorTests extends ESTestCase {
+public class RandomCrossValidationSplitterTests extends ESTestCase {
 
 
     private List<String> fields;
     private List<String> fields;
     private int dependentVariableIndex;
     private int dependentVariableIndex;
@@ -40,7 +40,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
     }
     }
 
 
     public void testProcess_GivenRowsWithoutDependentVariableValue() {
     public void testProcess_GivenRowsWithoutDependentVariableValue() {
-        CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed);
 
 
         for (int i = 0; i < 100; i++) {
         for (int i = 0; i < 100; i++) {
             String[] row = new String[fields.size()];
             String[] row = new String[fields.size()];
@@ -50,7 +50,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
             }
             }
 
 
             String[] processedRow = Arrays.copyOf(row, row.length);
             String[] processedRow = Arrays.copyOf(row, row.length);
-            customProcessor.process(processedRow);
+            crossValidationSplitter.process(processedRow);
 
 
             // As all these rows have no dependent variable value, they're not for training and should be unaffected
             // As all these rows have no dependent variable value, they're not for training and should be unaffected
             assertThat(Arrays.equals(processedRow, row), is(true));
             assertThat(Arrays.equals(processedRow, row), is(true));
@@ -58,7 +58,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
     }
     }
 
 
     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
-        CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
+            fields, dependentVariable, 100.0, randomizeSeed);
 
 
         for (int i = 0; i < 100; i++) {
         for (int i = 0; i < 100; i++) {
             String[] row = new String[fields.size()];
             String[] row = new String[fields.size()];
@@ -68,7 +69,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
             }
             }
 
 
             String[] processedRow = Arrays.copyOf(row, row.length);
             String[] processedRow = Arrays.copyOf(row, row.length);
-            customProcessor.process(processedRow);
+            crossValidationSplitter.process(processedRow);
 
 
             // We should pick them all as training percent is 100
             // We should pick them all as training percent is 100
             assertThat(Arrays.equals(processedRow, row), is(true));
             assertThat(Arrays.equals(processedRow, row), is(true));
@@ -78,7 +79,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
     public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
         double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
         double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
         double trainingFraction = trainingPercent / 100;
         double trainingFraction = trainingPercent / 100;
-        CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
+            fields, dependentVariable, trainingPercent, randomizeSeed);
 
 
         int runCount = 20;
         int runCount = 20;
         int rowsCount = 1000;
         int rowsCount = 1000;
@@ -92,7 +94,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
                 }
                 }
 
 
                 String[] processedRow = Arrays.copyOf(row, row.length);
                 String[] processedRow = Arrays.copyOf(row, row.length);
-                customProcessor.process(processedRow);
+                crossValidationSplitter.process(processedRow);
 
 
                 for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
                 for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
                     if (fieldIndex != dependentVariableIndex) {
                     if (fieldIndex != dependentVariableIndex) {
@@ -124,7 +126,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
     }
     }
 
 
     public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
     public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
-        CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
+        CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
+            fields, dependentVariable, 1.0, randomizeSeed);
 
 
         // We have some non-training rows and then a training row to check
         // We have some non-training rows and then a training row to check
         // we maintain the first training row and not just the first row
         // we maintain the first training row and not just the first row
@@ -139,7 +142,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
             }
             }
 
 
             String[] processedRow = Arrays.copyOf(row, row.length);
             String[] processedRow = Arrays.copyOf(row, row.length);
-            customProcessor.process(processedRow);
+            crossValidationSplitter.process(processedRow);
 
 
             assertThat(Arrays.equals(processedRow, row), is(true));
             assertThat(Arrays.equals(processedRow, row), is(true));
         }
         }