|
@@ -8,25 +8,41 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
+import org.elasticsearch.common.bytes.BytesArray;
|
|
|
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|
|
import org.elasticsearch.common.io.stream.Writeable;
|
|
|
+import org.elasticsearch.common.settings.Settings;
|
|
|
+import org.elasticsearch.common.xcontent.DeprecationHandler;
|
|
|
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|
|
import org.elasticsearch.common.xcontent.ToXContent;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
+import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
+import org.elasticsearch.common.xcontent.XContentType;
|
|
|
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
|
|
import org.elasticsearch.index.mapper.BooleanFieldMapper;
|
|
|
import org.elasticsearch.index.mapper.KeywordFieldMapper;
|
|
|
import org.elasticsearch.index.mapper.NumberFieldMapper;
|
|
|
+import org.elasticsearch.search.SearchModule;
|
|
|
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
+import java.util.stream.Collectors;
|
|
|
+import java.util.stream.Stream;
|
|
|
|
|
|
import static org.hamcrest.Matchers.allOf;
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
@@ -55,6 +71,21 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
return createRandom();
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ protected NamedXContentRegistry xContentRegistry() {
|
|
|
+ List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
|
|
+ namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
|
|
+ namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
|
|
|
+ return new NamedXContentRegistry(namedXContent);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ protected NamedWriteableRegistry getNamedWriteableRegistry() {
|
|
|
+ List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
|
|
|
+ entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
|
|
|
+ return new NamedWriteableRegistry(entries);
|
|
|
+ }
|
|
|
+
|
|
|
public static Classification createRandom() {
|
|
|
String dependentVariableName = randomAlphaOfLength(10);
|
|
|
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
|
|
@@ -65,7 +96,14 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
|
|
|
Long randomizeSeed = randomBoolean() ? null : randomLong();
|
|
|
return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective,
|
|
|
- numTopClasses, trainingPercent, randomizeSeed);
|
|
|
+ numTopClasses, trainingPercent, randomizeSeed,
|
|
|
+ randomBoolean() ?
|
|
|
+ null :
|
|
|
+ Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(true),
|
|
|
+ OneHotEncodingTests.createRandom(true),
|
|
|
+ TargetMeanEncodingTests.createRandom(true)))
|
|
|
+ .limit(randomIntBetween(0, 5))
|
|
|
+ .collect(Collectors.toList()));
|
|
|
}
|
|
|
|
|
|
public static Classification mutateForVersion(Classification instance, Version version) {
|
|
@@ -75,7 +113,8 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null,
|
|
|
instance.getNumTopClasses(),
|
|
|
instance.getTrainingPercent(),
|
|
|
- instance.getRandomizeSeed());
|
|
|
+ instance.getRandomizeSeed(),
|
|
|
+ version.onOrAfter(Version.V_8_0_0) ? instance.getFeatureProcessors() : Collections.emptyList());
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -91,14 +130,16 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
bwcSerializedObject.getClassAssignmentObjective(),
|
|
|
bwcSerializedObject.getNumTopClasses(),
|
|
|
bwcSerializedObject.getTrainingPercent(),
|
|
|
- 42L);
|
|
|
+ 42L,
|
|
|
+ bwcSerializedObject.getFeatureProcessors());
|
|
|
Classification newInstance = new Classification(testInstance.getDependentVariable(),
|
|
|
testInstance.getBoostedTreeParams(),
|
|
|
testInstance.getPredictionFieldName(),
|
|
|
testInstance.getClassAssignmentObjective(),
|
|
|
testInstance.getNumTopClasses(),
|
|
|
testInstance.getTrainingPercent(),
|
|
|
- 42L);
|
|
|
+ 42L,
|
|
|
+ testInstance.getFeatureProcessors());
|
|
|
super.assertOnBWCObject(newBwc, newInstance, version);
|
|
|
}
|
|
|
|
|
@@ -107,87 +148,138 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
return Classification::new;
|
|
|
}
|
|
|
|
|
|
+ public void testDeserialization() throws IOException {
|
|
|
+ String toDeserialize = "{\n" +
|
|
|
+ " \"dependent_variable\": \"FlightDelayMin\",\n" +
|
|
|
+ " \"feature_processors\": [\n" +
|
|
|
+ " {\n" +
|
|
|
+ " \"one_hot_encoding\": {\n" +
|
|
|
+ " \"field\": \"OriginWeather\",\n" +
|
|
|
+ " \"hot_map\": {\n" +
|
|
|
+ " \"sunny_col\": \"Sunny\",\n" +
|
|
|
+ " \"clear_col\": \"Clear\",\n" +
|
|
|
+ " \"rainy_col\": \"Rain\"\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " },\n" +
|
|
|
+ " {\n" +
|
|
|
+ " \"one_hot_encoding\": {\n" +
|
|
|
+ " \"field\": \"DestWeather\",\n" +
|
|
|
+ " \"hot_map\": {\n" +
|
|
|
+ " \"dest_sunny_col\": \"Sunny\",\n" +
|
|
|
+ " \"dest_clear_col\": \"Clear\",\n" +
|
|
|
+ " \"dest_rainy_col\": \"Rain\"\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " },\n" +
|
|
|
+ " {\n" +
|
|
|
+ " \"frequency_encoding\": {\n" +
|
|
|
+ " \"field\": \"OriginWeather\",\n" +
|
|
|
+ " \"feature_name\": \"mean\",\n" +
|
|
|
+ " \"frequency_map\": {\n" +
|
|
|
+ " \"Sunny\": 0.8,\n" +
|
|
|
+ " \"Rain\": 0.2\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " }\n" +
|
|
|
+ " ]\n" +
|
|
|
+ " }" +
|
|
|
+ "";
|
|
|
+
|
|
|
+ try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
|
|
|
+ DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
|
|
+ new BytesArray(toDeserialize),
|
|
|
+ XContentType.JSON)) {
|
|
|
+ Classification parsed = Classification.fromXContent(parser, false);
|
|
|
+ assertThat(parsed.getDependentVariable(), equalTo("FlightDelayMin"));
|
|
|
+ for (PreProcessor preProcessor : parsed.getFeatureProcessors()) {
|
|
|
+ assertThat(preProcessor.isCustom(), is(true));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong(), null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong(), null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsLessThanZero() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong(), null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
}
|
|
|
|
|
|
public void testConstructor_GivenNumTopClassesIsGreaterThan1000() {
|
|
|
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
|
|
- () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong()));
|
|
|
+ () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong(), null));
|
|
|
|
|
|
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
|
|
|
}
|
|
|
|
|
|
public void testGetPredictionFieldName() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("result"));
|
|
|
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong(), null);
|
|
|
assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction"));
|
|
|
}
|
|
|
|
|
|
public void testClassAssignmentObjective() {
|
|
|
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
- Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong());
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY));
|
|
|
|
|
|
classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result",
|
|
|
- Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong());
|
|
|
+ Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
|
|
|
// class_assignment_objective == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL));
|
|
|
}
|
|
|
|
|
|
public void testGetNumTopClasses() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(7));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(0));
|
|
|
|
|
|
// Boundary condition: num_top_classes == 1000
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(1000));
|
|
|
|
|
|
// num_top_classes == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getNumTopClasses(), equalTo(2));
|
|
|
}
|
|
|
|
|
|
public void testGetTrainingPercent() {
|
|
|
- Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong());
|
|
|
+ Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong(), null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(50.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 1.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong(), null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(1.0));
|
|
|
|
|
|
// Boundary condition: training_percent == 100.0
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong(), null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
|
|
|
// training_percent == null, default applied
|
|
|
- classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong());
|
|
|
+ classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong(), null);
|
|
|
assertThat(classification.getTrainingPercent(), equalTo(100.0));
|
|
|
}
|
|
|
|
|
@@ -231,6 +323,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
|
|
|
null,
|
|
|
null,
|
|
|
50.0,
|
|
|
+ null,
|
|
|
null).getParams(fieldInfo),
|
|
|
equalTo(
|
|
|
Map.of(
|