|
@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
|
|
|
|
|
|
import org.apache.lucene.util.Accountable;
|
|
|
import org.apache.lucene.util.RamUsageEstimator;
|
|
|
+import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.Numbers;
|
|
|
import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.Strings;
|
|
@@ -38,6 +39,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
public static final ParseField NODE_INDEX = new ParseField("node_index");
|
|
|
public static final ParseField SPLIT_GAIN = new ParseField("split_gain");
|
|
|
public static final ParseField LEAF_VALUE = new ParseField("leaf_value");
|
|
|
+ public static final ParseField NUMBER_SAMPLES = new ParseField("number_samples");
|
|
|
|
|
|
private static final ObjectParser<TreeNode.Builder, Void> LENIENT_PARSER = createParser(true);
|
|
|
private static final ObjectParser<TreeNode.Builder, Void> STRICT_PARSER = createParser(false);
|
|
@@ -59,6 +61,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
|
|
|
parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
|
|
|
parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
|
|
|
+ parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
|
|
|
return parser;
|
|
|
}
|
|
|
|
|
@@ -75,6 +78,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
private final boolean defaultLeft;
|
|
|
private final int leftChild;
|
|
|
private final int rightChild;
|
|
|
+ private final long numberSamples;
|
|
|
|
|
|
|
|
|
private TreeNode(Operator operator,
|
|
@@ -85,7 +89,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
Double leafValue,
|
|
|
Boolean defaultLeft,
|
|
|
Integer leftChild,
|
|
|
- Integer rightChild) {
|
|
|
+ Integer rightChild,
|
|
|
+ long numberSamples) {
|
|
|
this.operator = operator == null ? Operator.LTE : operator;
|
|
|
this.threshold = threshold == null ? Double.NaN : threshold;
|
|
|
this.splitFeature = splitFeature == null ? -1 : splitFeature;
|
|
@@ -95,6 +100,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
this.defaultLeft = defaultLeft == null ? false : defaultLeft;
|
|
|
this.leftChild = leftChild == null ? -1 : leftChild;
|
|
|
this.rightChild = rightChild == null ? -1 : rightChild;
|
|
|
+ if (numberSamples < 0) {
|
|
|
+ throw new IllegalArgumentException("[" + NUMBER_SAMPLES.getPreferredName() + "] must be greater than or equal to 0");
|
|
|
+ }
|
|
|
+ this.numberSamples = numberSamples;
|
|
|
}
|
|
|
|
|
|
public TreeNode(StreamInput in) throws IOException {
|
|
@@ -107,6 +116,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
defaultLeft = in.readBoolean();
|
|
|
leftChild = in.readInt();
|
|
|
rightChild = in.readInt();
|
|
|
+ if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
|
|
|
+ this.numberSamples = in.readVLong();
|
|
|
+ } else {
|
|
|
+ this.numberSamples = 0L;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
|
|
@@ -150,6 +164,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
return leftChild < 0;
|
|
|
}
|
|
|
|
|
|
+ public long getNumberSamples() {
|
|
|
+ return numberSamples;
|
|
|
+ }
|
|
|
+
|
|
|
public int compare(List<Double> features) {
|
|
|
if (isLeaf()) {
|
|
|
throw new IllegalArgumentException("cannot call compare against a leaf node.");
|
|
@@ -176,6 +194,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
out.writeBoolean(defaultLeft);
|
|
|
out.writeInt(leftChild);
|
|
|
out.writeInt(rightChild);
|
|
|
+ if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
|
|
|
+ out.writeVLong(numberSamples);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -196,6 +217,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
if (rightChild >= 0) {
|
|
|
builder.field(RIGHT_CHILD.getPreferredName(), rightChild);
|
|
|
}
|
|
|
+ builder.field(NUMBER_SAMPLES.getPreferredName(), numberSamples);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
|
}
|
|
@@ -219,7 +241,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
&& Objects.equals(leafValue, that.leafValue)
|
|
|
&& Objects.equals(defaultLeft, that.defaultLeft)
|
|
|
&& Objects.equals(leftChild, that.leftChild)
|
|
|
- && Objects.equals(rightChild, that.rightChild);
|
|
|
+ && Objects.equals(rightChild, that.rightChild)
|
|
|
+ && Objects.equals(numberSamples, that.numberSamples);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -232,7 +255,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
leafValue,
|
|
|
defaultLeft,
|
|
|
leftChild,
|
|
|
- rightChild);
|
|
|
+ rightChild,
|
|
|
+ numberSamples);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -259,6 +283,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
private Boolean defaultLeft;
|
|
|
private Integer leftChild;
|
|
|
private Integer rightChild;
|
|
|
+ private long numberSamples;
|
|
|
|
|
|
public Builder(int nodeIndex) {
|
|
|
this.nodeIndex = nodeIndex;
|
|
@@ -320,6 +345,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
return rightChild;
|
|
|
}
|
|
|
|
|
|
+ public Builder setNumberSamples(long numberSamples) {
|
|
|
+ this.numberSamples = numberSamples;
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+
|
|
|
public void validate() {
|
|
|
if (nodeIndex < 0) {
|
|
|
throw new IllegalArgumentException("[node_index] must be a non-negative integer.");
|
|
@@ -351,7 +381,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
|
|
|
leafValue,
|
|
|
defaultLeft,
|
|
|
leftChild,
|
|
|
- rightChild);
|
|
|
+ rightChild,
|
|
|
+ numberSamples);
|
|
|
}
|
|
|
}
|
|
|
}
|