|
@@ -9,11 +9,13 @@ import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.Strings;
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
|
|
+import org.elasticsearch.common.util.CachedSupplier;
|
|
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
|
import org.elasticsearch.common.xcontent.ObjectParser;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
|
|
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
@@ -31,10 +33,13 @@ import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
|
|
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
|
|
|
|
|
|
|
|
|
|
+ // TODO should we have regression/classification sub-classes that accept the builder?
|
|
|
public static final ParseField NAME = new ParseField("tree");
|
|
public static final ParseField NAME = new ParseField("tree");
|
|
|
|
|
|
|
|
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
|
|
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
|
|
|
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
|
|
public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
|
|
|
|
|
+ public static final ParseField TARGET_TYPE = new ParseField("target_type");
|
|
|
|
|
+ public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
|
|
|
|
|
|
|
|
private static final ObjectParser<Tree.Builder, Void> LENIENT_PARSER = createParser(true);
|
|
private static final ObjectParser<Tree.Builder, Void> LENIENT_PARSER = createParser(true);
|
|
|
private static final ObjectParser<Tree.Builder, Void> STRICT_PARSER = createParser(false);
|
|
private static final ObjectParser<Tree.Builder, Void> STRICT_PARSER = createParser(false);
|
|
@@ -46,6 +51,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
Tree.Builder::new);
|
|
Tree.Builder::new);
|
|
|
parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES);
|
|
parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES);
|
|
|
parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
|
|
parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
|
|
|
|
|
+ parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE);
|
|
|
|
|
+ parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
|
|
|
return parser;
|
|
return parser;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -59,15 +66,28 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
|
|
|
|
|
private final List<String> featureNames;
|
|
private final List<String> featureNames;
|
|
|
private final List<TreeNode> nodes;
|
|
private final List<TreeNode> nodes;
|
|
|
|
|
+ private final TargetType targetType;
|
|
|
|
|
+ private final List<String> classificationLabels;
|
|
|
|
|
+ private final CachedSupplier<Double> highestOrderCategory;
|
|
|
|
|
|
|
|
- Tree(List<String> featureNames, List<TreeNode> nodes) {
|
|
|
|
|
|
|
+ Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
|
|
|
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
|
|
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
|
|
|
this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
|
|
this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
|
|
|
|
|
+ this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
|
|
|
|
|
+ this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
|
|
|
|
|
+ this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
public Tree(StreamInput in) throws IOException {
|
|
public Tree(StreamInput in) throws IOException {
|
|
|
this.featureNames = Collections.unmodifiableList(in.readStringList());
|
|
this.featureNames = Collections.unmodifiableList(in.readStringList());
|
|
|
this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
|
|
this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
|
|
|
|
|
+ this.targetType = TargetType.fromStream(in);
|
|
|
|
|
+ if (in.readBoolean()) {
|
|
|
|
|
+ this.classificationLabels = Collections.unmodifiableList(in.readStringList());
|
|
|
|
|
+ } else {
|
|
|
|
|
+ this.classificationLabels = null;
|
|
|
|
|
+ }
|
|
|
|
|
+ this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -90,7 +110,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return infer(features);
|
|
return infer(features);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- private double infer(List<Double> features) {
|
|
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public double infer(List<Double> features) {
|
|
|
TreeNode node = nodes.get(0);
|
|
TreeNode node = nodes.get(0);
|
|
|
while(node.isLeaf() == false) {
|
|
while(node.isLeaf() == false) {
|
|
|
node = nodes.get(node.compare(features));
|
|
node = nodes.get(node.compare(features));
|
|
@@ -115,13 +136,40 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
- public boolean isClassification() {
|
|
|
|
|
- return false;
|
|
|
|
|
|
|
+ public TargetType targetType() {
|
|
|
|
|
+ return targetType;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public List<Double> classificationProbability(Map<String, Object> fields) {
|
|
|
|
|
+ if ((targetType == TargetType.CLASSIFICATION) == false) {
|
|
|
|
|
+ throw new UnsupportedOperationException(
|
|
|
|
|
+ "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
|
|
|
|
|
+ }
|
|
|
|
|
+ return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public List<Double> classificationProbability(List<Double> fields) {
|
|
|
|
|
+ if ((targetType == TargetType.CLASSIFICATION) == false) {
|
|
|
|
|
+ throw new UnsupportedOperationException(
|
|
|
|
|
+ "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
|
|
|
|
|
+ }
|
|
|
|
|
+ double label = infer(fields);
|
|
|
|
|
+ // If we are classification, we should assume that the inference return value is whole.
|
|
|
|
|
+ assert label == Math.rint(label);
|
|
|
|
|
+ double maxCategory = this.highestOrderCategory.get();
|
|
|
|
|
+ // If we are classification, we should assume that the largest leaf value is whole.
|
|
|
|
|
+ assert maxCategory == Math.rint(maxCategory);
|
|
|
|
|
+ List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
|
|
|
|
|
+ // TODO, eventually have TreeNodes contain confidence levels
|
|
|
|
|
+ list.set(Double.valueOf(label).intValue(), 1.0);
|
|
|
|
|
+ return list;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
- public List<Double> inferProbabilities(Map<String, Object> fields) {
|
|
|
|
|
- throw new UnsupportedOperationException("Cannot infer probabilities against a regression model.");
|
|
|
|
|
|
|
+ public List<String> classificationLabels() {
|
|
|
|
|
+ return classificationLabels;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -133,6 +181,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
out.writeStringCollection(featureNames);
|
|
out.writeStringCollection(featureNames);
|
|
|
out.writeCollection(nodes);
|
|
out.writeCollection(nodes);
|
|
|
|
|
+ targetType.writeTo(out);
|
|
|
|
|
+ out.writeBoolean(classificationLabels != null);
|
|
|
|
|
+ if (classificationLabels != null) {
|
|
|
|
|
+ out.writeStringCollection(classificationLabels);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -140,6 +193,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
builder.startObject();
|
|
builder.startObject();
|
|
|
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
|
|
builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
|
|
|
builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
|
|
builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
|
|
|
|
|
+ builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
|
|
|
|
|
+ if(classificationLabels != null) {
|
|
|
|
|
+ builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
|
|
|
|
|
+ }
|
|
|
builder.endObject();
|
|
builder.endObject();
|
|
|
return builder;
|
|
return builder;
|
|
|
}
|
|
}
|
|
@@ -155,22 +212,96 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
|
Tree that = (Tree) o;
|
|
Tree that = (Tree) o;
|
|
|
return Objects.equals(featureNames, that.featureNames)
|
|
return Objects.equals(featureNames, that.featureNames)
|
|
|
- && Objects.equals(nodes, that.nodes);
|
|
|
|
|
|
|
+ && Objects.equals(nodes, that.nodes)
|
|
|
|
|
+ && Objects.equals(targetType, that.targetType)
|
|
|
|
|
+ && Objects.equals(classificationLabels, that.classificationLabels);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
public int hashCode() {
|
|
public int hashCode() {
|
|
|
- return Objects.hash(featureNames, nodes);
|
|
|
|
|
|
|
+ return Objects.hash(featureNames, nodes, targetType, classificationLabels);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
public static Builder builder() {
|
|
public static Builder builder() {
|
|
|
return new Builder();
|
|
return new Builder();
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void validate() {
|
|
|
|
|
+ checkTargetType();
|
|
|
|
|
+ detectMissingNodes();
|
|
|
|
|
+ detectCycle();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void checkTargetType() {
|
|
|
|
|
+ if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
|
|
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
|
|
+ "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void detectCycle() {
|
|
|
|
|
+ if (nodes.isEmpty()) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ Set<Integer> visited = new HashSet<>(nodes.size());
|
|
|
|
|
+ Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
|
|
|
|
|
+ toVisit.add(0);
|
|
|
|
|
+ while(toVisit.isEmpty() == false) {
|
|
|
|
|
+ Integer nodeIdx = toVisit.remove();
|
|
|
|
|
+ if (visited.contains(nodeIdx)) {
|
|
|
|
|
+ throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
|
|
|
|
|
+ }
|
|
|
|
|
+ visited.add(nodeIdx);
|
|
|
|
|
+ TreeNode treeNode = nodes.get(nodeIdx);
|
|
|
|
|
+ if (treeNode.getLeftChild() >= 0) {
|
|
|
|
|
+ toVisit.add(treeNode.getLeftChild());
|
|
|
|
|
+ }
|
|
|
|
|
+ if (treeNode.getRightChild() >= 0) {
|
|
|
|
|
+ toVisit.add(treeNode.getRightChild());
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void detectMissingNodes() {
|
|
|
|
|
+ if (nodes.isEmpty()) {
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ List<Integer> missingNodes = new ArrayList<>();
|
|
|
|
|
+ for (int i = 0; i < nodes.size(); i++) {
|
|
|
|
|
+ TreeNode currentNode = nodes.get(i);
|
|
|
|
|
+ if (currentNode == null) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ if (nodeMissing(currentNode.getLeftChild(), nodes)) {
|
|
|
|
|
+ missingNodes.add(currentNode.getLeftChild());
|
|
|
|
|
+ }
|
|
|
|
|
+ if (nodeMissing(currentNode.getRightChild(), nodes)) {
|
|
|
|
|
+ missingNodes.add(currentNode.getRightChild());
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if (missingNodes.isEmpty() == false) {
|
|
|
|
|
+ throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
|
|
|
|
|
+ return nodeIdx >= nodes.size();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private Double maxLeafValue() {
|
|
|
|
|
+ return targetType == TargetType.CLASSIFICATION ?
|
|
|
|
|
+ this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
|
|
|
|
|
+ null;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
public static class Builder {
|
|
public static class Builder {
|
|
|
private List<String> featureNames;
|
|
private List<String> featureNames;
|
|
|
private ArrayList<TreeNode.Builder> nodes;
|
|
private ArrayList<TreeNode.Builder> nodes;
|
|
|
private int numNodes;
|
|
private int numNodes;
|
|
|
|
|
+ private TargetType targetType = TargetType.REGRESSION;
|
|
|
|
|
+ private List<String> classificationLabels;
|
|
|
|
|
|
|
|
public Builder() {
|
|
public Builder() {
|
|
|
nodes = new ArrayList<>();
|
|
nodes = new ArrayList<>();
|
|
@@ -185,13 +316,18 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return this;
|
|
return this;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ public Builder setRoot(TreeNode.Builder root) {
|
|
|
|
|
+ nodes.set(0, root);
|
|
|
|
|
+ return this;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
public Builder addNode(TreeNode.Builder node) {
|
|
public Builder addNode(TreeNode.Builder node) {
|
|
|
nodes.add(node);
|
|
nodes.add(node);
|
|
|
return this;
|
|
return this;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
public Builder setNodes(List<TreeNode.Builder> nodes) {
|
|
public Builder setNodes(List<TreeNode.Builder> nodes) {
|
|
|
- this.nodes = new ArrayList<>(nodes);
|
|
|
|
|
|
|
+ this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName()));
|
|
|
return this;
|
|
return this;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -199,6 +335,21 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return setNodes(Arrays.asList(nodes));
|
|
return setNodes(Arrays.asList(nodes));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
|
+ public Builder setTargetType(TargetType targetType) {
|
|
|
|
|
+ this.targetType = targetType;
|
|
|
|
|
+ return this;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ public Builder setClassificationLabels(List<String> classificationLabels) {
|
|
|
|
|
+ this.classificationLabels = classificationLabels;
|
|
|
|
|
+ return this;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void setTargetType(String targetType) {
|
|
|
|
|
+ this.targetType = TargetType.fromString(targetType);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
/**
|
|
/**
|
|
|
* Add a decision node. Space for the child nodes is allocated
|
|
* Add a decision node. Space for the child nodes is allocated
|
|
|
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
|
|
* @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index
|
|
@@ -231,61 +382,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return node;
|
|
return node;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- void detectCycle(List<TreeNode.Builder> nodes) {
|
|
|
|
|
- if (nodes.isEmpty()) {
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
- Set<Integer> visited = new HashSet<>();
|
|
|
|
|
- Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
|
|
|
|
|
- toVisit.add(0);
|
|
|
|
|
- while(toVisit.isEmpty() == false) {
|
|
|
|
|
- Integer nodeIdx = toVisit.remove();
|
|
|
|
|
- if (visited.contains(nodeIdx)) {
|
|
|
|
|
- throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx);
|
|
|
|
|
- }
|
|
|
|
|
- visited.add(nodeIdx);
|
|
|
|
|
- TreeNode.Builder treeNode = nodes.get(nodeIdx);
|
|
|
|
|
- if (treeNode.getLeftChild() != null) {
|
|
|
|
|
- toVisit.add(treeNode.getLeftChild());
|
|
|
|
|
- }
|
|
|
|
|
- if (treeNode.getRightChild() != null) {
|
|
|
|
|
- toVisit.add(treeNode.getRightChild());
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- void detectNullOrMissingNode(List<TreeNode.Builder> nodes) {
|
|
|
|
|
- if (nodes.isEmpty()) {
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
- if (nodes.get(0) == null) {
|
|
|
|
|
- throw new IllegalArgumentException("[tree] must have non-null root node.");
|
|
|
|
|
- }
|
|
|
|
|
- List<Integer> nullOrMissingNodes = new ArrayList<>();
|
|
|
|
|
- for (int i = 0; i < nodes.size(); i++) {
|
|
|
|
|
- TreeNode.Builder currentNode = nodes.get(i);
|
|
|
|
|
- if (currentNode == null) {
|
|
|
|
|
- continue;
|
|
|
|
|
- }
|
|
|
|
|
- if (nodeNullOrMissing(currentNode.getLeftChild())) {
|
|
|
|
|
- nullOrMissingNodes.add(currentNode.getLeftChild());
|
|
|
|
|
- }
|
|
|
|
|
- if (nodeNullOrMissing(currentNode.getRightChild())) {
|
|
|
|
|
- nullOrMissingNodes.add(currentNode.getRightChild());
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- if (nullOrMissingNodes.isEmpty() == false) {
|
|
|
|
|
- throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- private boolean nodeNullOrMissing(Integer nodeIdx) {
|
|
|
|
|
- if (nodeIdx == null) {
|
|
|
|
|
- return false;
|
|
|
|
|
- }
|
|
|
|
|
- return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
/**
|
|
/**
|
|
|
* Sets the node at {@code nodeIndex} to a leaf node.
|
|
* Sets the node at {@code nodeIndex} to a leaf node.
|
|
|
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
|
|
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
|
|
@@ -301,10 +397,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
public Tree build() {
|
|
public Tree build() {
|
|
|
- detectNullOrMissingNode(nodes);
|
|
|
|
|
- detectCycle(nodes);
|
|
|
|
|
|
|
+ if (nodes.stream().anyMatch(Objects::isNull)) {
|
|
|
|
|
+ throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
|
|
|
|
|
+ }
|
|
|
return new Tree(featureNames,
|
|
return new Tree(featureNames,
|
|
|
- nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
|
|
|
|
|
|
|
+ nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
|
|
|
|
|
+ targetType,
|
|
|
|
|
+ classificationLabels);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|