@@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+import org.elasticsearch.xpack.core.ml.job.config.Operator;
 import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
@@ -39,6 +40,7 @@ import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;

 public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
+    private final double eps = 1.0E-8;

     private boolean lenient;

@@ -267,7 +269,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> scores = Arrays.asList(0.230557435, 0.162032651);
         double eps = 0.000001;
         List<ClassificationInferenceResults.TopClassEntry> probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -278,7 +281,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.310025518, 0.6899744811);
         scores = Arrays.asList(0.217017863, 0.2069923443);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -289,7 +293,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.768524783, 0.231475216);
         scores = Arrays.asList(0.230557435, 0.162032651);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -303,7 +308,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         expected = Arrays.asList(0.6899744811, 0.3100255188);
         scores = Arrays.asList(0.482982136, 0.0930076556);
         probabilities =
-            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
+            ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
+                .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {
             assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
             assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps));
@@ -361,24 +367,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureMap = new HashMap<>(2) {{
             put("foo", 0.3);
             put("bar", null);
         }};
         assertThat(0.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
     }

     public void testMultiClassClassificationInference() {
@@ -432,24 +442,28 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(2.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureVector = Arrays.asList(0.0, 1.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));

         featureMap = new HashMap<>(2) {{
             put("foo", 0.6);
             put("bar", null);
         }};
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0), Collections.emptyMap())).value(),
+                0.00001));
     }

     public void testRegressionInference() {
@@ -489,12 +503,16 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         List<Double> featureVector = Arrays.asList(0.4, 0.0);
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.9,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));

         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(0.5,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));

         // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
         ensemble = Ensemble.builder()
@@ -506,19 +524,25 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         featureVector = Arrays.asList(0.4, 0.0);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.8,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));

         featureVector = Arrays.asList(2.0, 0.7);
         featureMap = zipObjMap(featureNames, featureVector);
         assertThat(1.0,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));

         featureMap = new HashMap<>(2) {{
             put("foo", 0.3);
             put("bar", null);
         }};
         assertThat(1.8,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
     }

     public void testInferNestedFields() {
@@ -564,7 +588,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             }});
         }};
         assertThat(0.9,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));

         featureMap = new HashMap<>() {{
             put("foo", new HashMap<>(){{
@@ -575,7 +601,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             }});
         }};
         assertThat(0.5,
-            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS, Collections.emptyMap()))
+                .value(),
+                0.00001));
     }

     public void testOperationsEstimations() {
@@ -590,6 +618,114 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
         assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
     }

+    public void testFeatureImportance() {
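+        // Two small regression trees over the features "foo" and "bar", combined with a weighted-sum aggregator.
+        // Only tree1's right branch (reached when foo >= 0.55) splits on "bar"; every other split is on "foo".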
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(
+                TreeNode.builder(0)
+                    .setSplitFeature(0)
+                    .setOperator(Operator.LT)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(0.55)
+                    .setNumberSamples(10L),
+                TreeNode.builder(1)
+                    .setSplitFeature(0)
+                    .setLeftChild(3)
+                    .setRightChild(4)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.41)
+                    .setNumberSamples(6L),
+                TreeNode.builder(2)
+                    .setSplitFeature(1)
+                    .setLeftChild(5)
+                    .setRightChild(6)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.25)
+                    .setNumberSamples(4L),
+                TreeNode.builder(3).setLeafValue(1.18230136).setNumberSamples(5L),
+                TreeNode.builder(4).setLeafValue(1.98006658).setNumberSamples(1L),
+                TreeNode.builder(5).setLeafValue(3.25350885).setNumberSamples(3L),
+                TreeNode.builder(6).setLeafValue(2.42384369).setNumberSamples(1L)).build();
+
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(
+                TreeNode.builder(0)
+                    .setSplitFeature(0)
+                    .setOperator(Operator.LT)
+                    .setLeftChild(1)
+                    .setRightChild(2)
+                    .setThreshold(0.45)
+                    .setNumberSamples(10L),
+                TreeNode.builder(1)
+                    .setSplitFeature(0)
+                    .setLeftChild(3)
+                    .setRightChild(4)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.25)
+                    .setNumberSamples(5L),
+                TreeNode.builder(2)
+                    .setSplitFeature(0)
+                    .setLeftChild(5)
+                    .setRightChild(6)
+                    .setOperator(Operator.LT)
+                    .setThreshold(0.59)
+                    .setNumberSamples(5L),
+                TreeNode.builder(3).setLeafValue(1.04476388).setNumberSamples(3L),
+                TreeNode.builder(4).setLeafValue(1.52799228).setNumberSamples(2L),
+                TreeNode.builder(5).setLeafValue(1.98006658).setNumberSamples(1L),
+                TreeNode.builder(6).setLeafValue(2.950216).setNumberSamples(4L)).build();
+
+        Ensemble ensemble = Ensemble.builder().setOutputAggregator(new WeightedSum())
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setFeatureNames(featureNames)
+            .build();
+
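+        // Sweep (foo, bar) from (0.0, 0.9) to (0.9, 0.0): the importance attributed to "foo" grows as foo
+        // crosses the trees' split thresholds, while "bar" importance only shifts once foo reaches tree1's
+        // 0.55 split, since that is the only branch that splits on "bar".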
+        Map<String, Double> featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.0, 0.9)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.1, 0.8)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.2, 0.7)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.653200025, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.3, 0.6)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.4, 0.5)));
+        assertThat(featureImportance.get("foo"), closeTo(-1.16997162, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.5, 0.4)));
+        assertThat(featureImportance.get("foo"), closeTo(0.0798679, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.12444978, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.6, 0.3)));
+        assertThat(featureImportance.get("foo"), closeTo(1.80491886, eps));
+        assertThat(featureImportance.get("bar"), closeTo(-0.4355742, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.7, 0.2)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.8, 0.1)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+
+        featureImportance = ensemble.featureImportance(zipObjMap(featureNames, Arrays.asList(0.9, 0.0)));
+        assertThat(featureImportance.get("foo"), closeTo(2.0538184, eps));
+        assertThat(featureImportance.get("bar"), closeTo(0.1451914, eps));
+    }
+
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }