|
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
|
|
|
import org.apache.lucene.util.Accountable;
|
|
|
import org.apache.lucene.util.Accountables;
|
|
|
import org.apache.lucene.util.RamUsageEstimator;
|
|
|
+import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
@@ -29,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
|
|
|
|
|
@@ -100,7 +102,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
this.nodes = Collections.unmodifiableList(nodes);
|
|
|
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
|
|
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
|
|
- this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
|
|
|
+ this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
|
|
|
}
|
|
|
|
|
|
public Tree(StreamInput in) throws IOException {
|
|
@@ -112,7 +114,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
} else {
|
|
|
this.classificationLabels = null;
|
|
|
}
|
|
|
- this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
|
|
|
+ this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -147,7 +149,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return buildResult(node.getLeafValue(), featureImportance, config);
|
|
|
}
|
|
|
|
|
|
- private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
|
|
|
+ private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
|
|
|
+ assert value != null && value.length > 0;
|
|
|
// Indicates that the config is useless and the caller just wants the raw value
|
|
|
if (config instanceof NullInferenceConfig) {
|
|
|
return new RawInferenceResults(value, featureImportance);
|
|
@@ -160,13 +163,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
classificationLabels,
|
|
|
null,
|
|
|
classificationConfig.getNumTopClasses());
|
|
|
- return new ClassificationInferenceResults(value,
|
|
|
+ return new ClassificationInferenceResults(topClasses.v1(),
|
|
|
classificationLabel(topClasses.v1(), classificationLabels),
|
|
|
topClasses.v2(),
|
|
|
featureImportance,
|
|
|
config);
|
|
|
case REGRESSION:
|
|
|
- return new RegressionInferenceResults(value, config, featureImportance);
|
|
|
+ return new RegressionInferenceResults(value[0], config, featureImportance);
|
|
|
default:
|
|
|
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
|
|
|
}
|
|
@@ -193,14 +196,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return targetType;
|
|
|
}
|
|
|
|
|
|
- private List<Double> classificationProbability(double inferenceValue) {
|
|
|
+ private double[] classificationProbability(double[] inferenceValue) {
|
|
|
+ // Multi-value leaves, indicates that the leaves contain an array of values.
|
|
|
+ // The index of which corresponds to classification values
|
|
|
+ if (inferenceValue.length > 1) {
|
|
|
+ return Statistics.softMax(inferenceValue);
|
|
|
+ }
|
|
|
// If we are classification, we should assume that the inference return value is whole.
|
|
|
- assert inferenceValue == Math.rint(inferenceValue);
|
|
|
+ assert inferenceValue[0] == Math.rint(inferenceValue[0]);
|
|
|
double maxCategory = this.highestOrderCategory.get();
|
|
|
// If we are classification, we should assume that the largest leaf value is whole.
|
|
|
assert maxCategory == Math.rint(maxCategory);
|
|
|
- List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
|
|
|
- list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
|
|
|
+ double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
|
|
|
+ .stream()
|
|
|
+ .mapToDouble(Double::doubleValue)
|
|
|
+ .toArray();
|
|
|
+ list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
|
|
|
return list;
|
|
|
}
|
|
|
|
|
@@ -268,6 +279,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
checkTargetType();
|
|
|
detectMissingNodes();
|
|
|
detectCycle();
|
|
|
+ verifyLeafNodeUniformity();
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -331,7 +343,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
TreeNode currNode = nodes.get(nodeIndex);
|
|
|
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
|
|
|
if (currNode.isLeaf()) {
|
|
|
- // TODO multi-value????
|
|
|
double leafValue = nodeValues[nodeIndex];
|
|
|
for (int i = 1; i < nextIndex; ++i) {
|
|
|
double scale = splitPath.sumUnwoundPath(i, nextIndex);
|
|
@@ -375,7 +386,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
|
|
|
TreeNode node = nodes.get(nodeIndex);
|
|
|
if (node.isLeaf()) {
|
|
|
- nodeEstimates[nodeIndex] = node.getLeafValue();
|
|
|
+ // TODO multi-value????
|
|
|
+ nodeEstimates[nodeIndex] = node.getLeafValue()[0];
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
@@ -424,6 +436,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
throw ExceptionsHelper.badRequestException(
|
|
|
"[target_type] should be [classification] if [classification_labels] are provided");
|
|
|
}
|
|
|
+ if (this.targetType != TargetType.CLASSIFICATION && this.nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "[target_type] should be [classification] if leaf nodes have multiple values");
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
private void detectCycle() {
|
|
@@ -465,14 +481,39 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void verifyLeafNodeUniformity() {
|
|
|
+ Integer leafValueLengths = null;
|
|
|
+ for (TreeNode node : nodes) {
|
|
|
+ if (node.isLeaf()) {
|
|
|
+ if (leafValueLengths == null) {
|
|
|
+ leafValueLengths = node.getLeafValue().length;
|
|
|
+ } else if (leafValueLengths != node.getLeafValue().length) {
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "[tree.tree_structure] all leaf nodes must have the same number of values");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
|
|
|
return nodeIdx >= nodes.size();
|
|
|
}
|
|
|
|
|
|
private Double maxLeafValue() {
|
|
|
- return targetType == TargetType.CLASSIFICATION ?
|
|
|
- this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
|
|
|
- null;
|
|
|
+ if (targetType != TargetType.CLASSIFICATION) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ double max = 0.0;
|
|
|
+ for (TreeNode node : this.nodes) {
|
|
|
+ if (node.isLeaf()) {
|
|
|
+ if (node.getLeafValue().length > 1) {
|
|
|
+ return (double)node.getLeafValue().length;
|
|
|
+ } else {
|
|
|
+ max = Math.max(node.getLeafValue()[0], max);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return max;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -493,6 +534,14 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return Collections.unmodifiableCollection(accountables);
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public Version getMinimalCompatibilityVersion() {
|
|
|
+ if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
|
|
|
+ return Version.V_7_7_0;
|
|
|
+ }
|
|
|
+ return Version.V_7_6_0;
|
|
|
+ }
|
|
|
+
|
|
|
public static class Builder {
|
|
|
private List<String> featureNames;
|
|
|
private ArrayList<TreeNode.Builder> nodes;
|
|
@@ -586,6 +635,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
* @return this
|
|
|
*/
|
|
|
Tree.Builder addLeaf(int nodeIndex, double value) {
|
|
|
+ return addLeaf(nodeIndex, Arrays.asList(value));
|
|
|
+ }
|
|
|
+
|
|
|
+ Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
|
|
|
for (int i = nodes.size(); i < nodeIndex + 1; i++) {
|
|
|
nodes.add(null);
|
|
|
}
|