|
@@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
|
|
|
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
|
|
|
|
|
|
-public class TreeInferenceModel implements InferenceModel {
|
|
|
+public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel {
|
|
|
|
|
|
private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
|
|
|
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
|
|
@@ -90,7 +90,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
private String[] featureNames;
|
|
|
private final TargetType targetType;
|
|
|
private List<String> classificationLabels;
|
|
|
- private final double highOrderCategory;
|
|
|
+ private final double[] leafBoundaries;
|
|
|
private final int maxDepth;
|
|
|
private final int leafSize;
|
|
|
private volatile boolean preparedForInference = false;
|
|
@@ -108,7 +108,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
|
|
|
this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
|
|
|
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
|
|
- this.highOrderCategory = maxLeafValue();
|
|
|
+ this.leafBoundaries = getLeafBoundaries();
|
|
|
int leafSize = 1;
|
|
|
for (Node node : this.nodes) {
|
|
|
if (node instanceof LeafNode leafNode) {
|
|
@@ -218,7 +218,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
}
|
|
|
// If we are classification, we should assume that the inference return value is whole.
|
|
|
assert inferenceValue[0] == Math.rint(inferenceValue[0]);
|
|
|
- double maxCategory = this.highOrderCategory;
|
|
|
+ double maxCategory = getHighOrderCategory();
|
|
|
// If we are classification, we should assume that the largest leaf value is whole.
|
|
|
assert maxCategory == Math.rint(maxCategory);
|
|
|
double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
|
|
@@ -366,21 +366,20 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
return size;
|
|
|
}
|
|
|
|
|
|
- private double maxLeafValue() {
|
|
|
- if (targetType != TargetType.CLASSIFICATION) {
|
|
|
- return Double.NaN;
|
|
|
- }
|
|
|
- double max = 0.0;
|
|
|
+ private double[] getLeafBoundaries() {
|
|
|
+ double[] bounds = new double[] { Double.MAX_VALUE, Double.MIN_VALUE };
|
|
|
+
|
|
|
for (Node node : this.nodes) {
|
|
|
if (node instanceof LeafNode leafNode) {
|
|
|
if (leafNode.leafValue.length > 1) {
|
|
|
- return leafNode.leafValue.length;
|
|
|
+ return new double[] { 0, leafNode.leafValue.length };
|
|
|
} else {
|
|
|
- max = Math.max(leafNode.leafValue[0], max);
|
|
|
+ bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]);
|
|
|
+ bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- return max;
|
|
|
+ return bounds;
|
|
|
}
|
|
|
|
|
|
public Node[] getNodes() {
|
|
@@ -389,24 +388,35 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
|
|
|
@Override
|
|
|
public String toString() {
|
|
|
- return "TreeInferenceModel{"
|
|
|
- + "nodes="
|
|
|
- + Arrays.toString(nodes)
|
|
|
- + ", featureNames="
|
|
|
- + Arrays.toString(featureNames)
|
|
|
- + ", targetType="
|
|
|
- + targetType
|
|
|
- + ", classificationLabels="
|
|
|
- + classificationLabels
|
|
|
- + ", highOrderCategory="
|
|
|
- + highOrderCategory
|
|
|
- + ", maxDepth="
|
|
|
- + maxDepth
|
|
|
- + ", leafSize="
|
|
|
- + leafSize
|
|
|
- + ", preparedForInference="
|
|
|
- + preparedForInference
|
|
|
- + '}';
|
|
|
+ StringBuilder builder = new StringBuilder("TreeInferenceModel{");
|
|
|
+
|
|
|
+ builder.append("nodes=")
|
|
|
+ .append(Arrays.toString(nodes))
|
|
|
+ .append(", featureNames=")
|
|
|
+ .append(Arrays.toString(featureNames))
|
|
|
+ .append(", targetType=")
|
|
|
+ .append(targetType);
|
|
|
+
|
|
|
+ if (targetType == TargetType.CLASSIFICATION) {
|
|
|
+ builder.append(", classificationLabels=")
|
|
|
+ .append(classificationLabels)
|
|
|
+ .append(", highOrderCategory=")
|
|
|
+ .append(getHighOrderCategory());
|
|
|
+ } else if (targetType == TargetType.REGRESSION) {
|
|
|
+ builder.append(", minPredictedValue=")
|
|
|
+ .append(getMinPredictedValue())
|
|
|
+ .append(", maxPredictedValue=")
|
|
|
+ .append(getMaxPredictedValue());
|
|
|
+ }
|
|
|
+
|
|
|
+ builder.append(", maxDepth=")
|
|
|
+ .append(maxDepth)
|
|
|
+ .append(", leafSize=")
|
|
|
+ .append(leafSize)
|
|
|
+ .append(", preparedForInference=")
|
|
|
+ .append(preparedForInference);
|
|
|
+
|
|
|
+ return builder.append('}').toString();
|
|
|
}
|
|
|
|
|
|
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
|
|
@@ -420,6 +430,20 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
return Math.max(depthLeft, depthRight) + 1;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public double getMinPredictedValue() {
|
|
|
+ return leafBoundaries[0];
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public double getMaxPredictedValue() {
|
|
|
+ return leafBoundaries[1];
|
|
|
+ }
|
|
|
+
|
|
|
+ private double getHighOrderCategory() {
|
|
|
+ return getMaxPredictedValue();
|
|
|
+ }
|
|
|
+
|
|
|
static class NodeBuilder {
|
|
|
|
|
|
private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
|