|
@@ -3,7 +3,7 @@
|
|
|
* or more contributor license agreements. Licensed under the Elastic License;
|
|
|
* you may not use this file except in compliance with the Elastic License.
|
|
|
*/
|
|
|
-package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
|
|
+package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
|
|
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
|
@@ -20,7 +20,7 @@ import static org.hamcrest.Matchers.greaterThan;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.lessThan;
|
|
|
|
|
|
-public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
+public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|
|
|
|
|
private List<String> fields;
|
|
|
private int dependentVariableIndex;
|
|
@@ -40,7 +40,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
|
|
- CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
|
|
|
+ CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed);
|
|
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
|
String[] row = new String[fields.size()];
|
|
@@ -50,7 +50,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
|
|
- customProcessor.process(processedRow);
|
|
|
+ crossValidationSplitter.process(processedRow);
|
|
|
|
|
|
// As all these rows have no dependent variable value, they're not for training and should be unaffected
|
|
|
assertThat(Arrays.equals(processedRow, row), is(true));
|
|
@@ -58,7 +58,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
|
|
- CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
|
|
|
+ CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
|
|
|
+ fields, dependentVariable, 100.0, randomizeSeed);
|
|
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
|
String[] row = new String[fields.size()];
|
|
@@ -68,7 +69,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
|
|
- customProcessor.process(processedRow);
|
|
|
+ crossValidationSplitter.process(processedRow);
|
|
|
|
|
|
// We should pick them all as training percent is 100
|
|
|
assertThat(Arrays.equals(processedRow, row), is(true));
|
|
@@ -78,7 +79,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
|
|
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
|
|
|
double trainingFraction = trainingPercent / 100;
|
|
|
- CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
|
|
|
+ CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
|
|
|
+ fields, dependentVariable, trainingPercent, randomizeSeed);
|
|
|
|
|
|
int runCount = 20;
|
|
|
int rowsCount = 1000;
|
|
@@ -92,7 +94,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
|
|
- customProcessor.process(processedRow);
|
|
|
+ crossValidationSplitter.process(processedRow);
|
|
|
|
|
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
|
|
if (fieldIndex != dependentVariableIndex) {
|
|
@@ -124,7 +126,8 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
|
|
|
- CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
|
|
|
+ CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
|
|
|
+ fields, dependentVariable, 1.0, randomizeSeed);
|
|
|
|
|
|
// We have some non-training rows and then a training row to check
|
|
|
// we maintain the first training row and not just the first row
|
|
@@ -139,7 +142,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
|
|
- customProcessor.process(processedRow);
|
|
|
+ crossValidationSplitter.process(processedRow);
|
|
|
|
|
|
assertThat(Arrays.equals(processedRow, row), is(true));
|
|
|
}
|