浏览代码

[ML][Inference] adding number_samples to TreeNode (#51937)

In preparation for feature importance and split information gain, this commit adds a `number_samples` field to the `TreeNode` definition.
Benjamin Trent 5 年之前
父节点
当前提交
dd8b497e68

+ 23 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java

@@ -42,6 +42,7 @@ public class TreeNode implements ToXContentObject {
     public static final ParseField NODE_INDEX = new ParseField("node_index");
     public static final ParseField SPLIT_GAIN = new ParseField("split_gain");
     public static final ParseField LEAF_VALUE = new ParseField("leaf_value");
+    public static final ParseField NUMBER_SAMPLES = new ParseField("number_samples");
 
 
     private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
@@ -61,6 +62,7 @@ public class TreeNode implements ToXContentObject {
         PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
         PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
         PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
+        PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
     }
 
     public static Builder fromXContent(XContentParser parser) {
@@ -76,6 +78,7 @@ public class TreeNode implements ToXContentObject {
     private final Boolean defaultLeft;
     private final Integer leftChild;
     private final Integer rightChild;
+    private final Long numberSamples;
 
 
     TreeNode(Operator operator,
@@ -86,7 +89,8 @@ public class TreeNode implements ToXContentObject {
              Double leafValue,
              Boolean defaultLeft,
              Integer leftChild,
-             Integer rightChild) {
+             Integer rightChild,
+             Long numberSamples) {
         this.operator = operator;
         this.threshold  = threshold;
         this.splitFeature = splitFeature;
@@ -96,6 +100,7 @@ public class TreeNode implements ToXContentObject {
         this.defaultLeft = defaultLeft;
         this.leftChild  = leftChild;
         this.rightChild = rightChild;
+        this.numberSamples = numberSamples;
     }
 
     public Operator getOperator() {
@@ -134,6 +139,10 @@ public class TreeNode implements ToXContentObject {
         return rightChild;
     }
 
+    public Long getNumberSamples() {
+        return numberSamples;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -146,6 +155,7 @@ public class TreeNode implements ToXContentObject {
         addOptionalField(builder, DEFAULT_LEFT, defaultLeft );
         addOptionalField(builder, LEFT_CHILD, leftChild);
         addOptionalField(builder, RIGHT_CHILD, rightChild);
+        addOptionalField(builder, NUMBER_SAMPLES, numberSamples);
         builder.endObject();
         return builder;
     }
@@ -169,7 +179,8 @@ public class TreeNode implements ToXContentObject {
             && Objects.equals(leafValue, that.leafValue)
             && Objects.equals(defaultLeft, that.defaultLeft)
             && Objects.equals(leftChild, that.leftChild)
-            && Objects.equals(rightChild, that.rightChild);
+            && Objects.equals(rightChild, that.rightChild)
+            && Objects.equals(numberSamples, that.numberSamples);
     }
 
     @Override
@@ -182,7 +193,8 @@ public class TreeNode implements ToXContentObject {
             leafValue,
             defaultLeft,
             leftChild,
-            rightChild);
+            rightChild,
+            numberSamples);
     }
 
     @Override
@@ -204,6 +216,7 @@ public class TreeNode implements ToXContentObject {
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
+        private Long numberSamples;
 
         public Builder(int nodeIndex) {
             this.nodeIndex = nodeIndex;
@@ -265,6 +278,11 @@ public class TreeNode implements ToXContentObject {
             return rightChild;
         }
 
+        public Builder setNumberSamples(Long numberSamples) {
+            this.numberSamples = numberSamples;
+            return this;
+        }
+
         public TreeNode build() {
             return new TreeNode(operator,
                 threshold, 
@@ -274,7 +292,8 @@ public class TreeNode implements ToXContentObject {
                 leafValue, 
                 defaultLeft, 
                 leftChild, 
-                rightChild);
+                rightChild,
+                numberSamples);
         }
     }
 }

+ 2 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java

@@ -49,6 +49,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeafValue(internalValue)
+            .setNumberSamples(randomNonNegativeLong())
             .build();
     }
 
@@ -63,6 +64,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)
+            .setNumberSamples(randomBoolean() ? null : randomNonNegativeLong())
             .setThreshold(threshold)
             .setOperator(operator)
             .setSplitFeature(featureIndex)

+ 35 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
 
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Numbers;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
@@ -38,6 +39,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
     public static final ParseField NODE_INDEX = new ParseField("node_index");
     public static final ParseField SPLIT_GAIN = new ParseField("split_gain");
     public static final ParseField LEAF_VALUE = new ParseField("leaf_value");
+    public static final ParseField NUMBER_SAMPLES = new ParseField("number_samples");
 
     private static final ObjectParser<TreeNode.Builder, Void> LENIENT_PARSER = createParser(true);
     private static final ObjectParser<TreeNode.Builder, Void> STRICT_PARSER = createParser(false);
@@ -59,6 +61,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
         parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
         parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
+        parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
         return parser;
     }
 
@@ -75,6 +78,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
     private final boolean defaultLeft;
     private final int leftChild;
     private final int rightChild;
+    private final long numberSamples;
 
 
     private TreeNode(Operator operator,
@@ -85,7 +89,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
                      Double leafValue,
                      Boolean defaultLeft,
                      Integer leftChild,
-                     Integer rightChild) {
+                     Integer rightChild,
+                     long numberSamples) {
         this.operator = operator == null ? Operator.LTE : operator;
         this.threshold  = threshold == null ? Double.NaN : threshold;
         this.splitFeature = splitFeature == null ? -1 : splitFeature;
@@ -95,6 +100,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         this.defaultLeft = defaultLeft == null ? false : defaultLeft;
         this.leftChild  = leftChild == null ? -1 : leftChild;
         this.rightChild = rightChild == null ? -1 : rightChild;
+        if (numberSamples < 0) {
+            throw new IllegalArgumentException("[" + NUMBER_SAMPLES.getPreferredName() + "] must be greater than or equal to 0");
+        }
+        this.numberSamples = numberSamples;
     }
 
     public TreeNode(StreamInput in) throws IOException {
@@ -107,6 +116,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         defaultLeft = in.readBoolean();
         leftChild = in.readInt();
         rightChild = in.readInt();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            this.numberSamples = in.readVLong();
+        } else {
+            this.numberSamples = 0L;
+        }
     }
 
 
@@ -150,6 +164,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         return leftChild < 0;
     }
 
+    public long getNumberSamples() {
+        return numberSamples;
+    }
+
     public int compare(List<Double> features) {
         if (isLeaf()) {
             throw new IllegalArgumentException("cannot call compare against a leaf node.");
@@ -176,6 +194,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         out.writeBoolean(defaultLeft);
         out.writeInt(leftChild);
         out.writeInt(rightChild);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeVLong(numberSamples);
+        }
     }
 
     @Override
@@ -196,6 +217,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         if (rightChild >= 0) {
             builder.field(RIGHT_CHILD.getPreferredName(), rightChild);
         }
+        builder.field(NUMBER_SAMPLES.getPreferredName(), numberSamples);
         builder.endObject();
         return builder;
     }
@@ -219,7 +241,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             && Objects.equals(leafValue, that.leafValue)
             && Objects.equals(defaultLeft, that.defaultLeft)
             && Objects.equals(leftChild, that.leftChild)
-            && Objects.equals(rightChild, that.rightChild);
+            && Objects.equals(rightChild, that.rightChild)
+            && Objects.equals(numberSamples, that.numberSamples);
     }
 
     @Override
@@ -232,7 +255,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             leafValue,
             defaultLeft,
             leftChild,
-            rightChild);
+            rightChild,
+            numberSamples);
     }
 
     @Override
@@ -259,6 +283,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
+        private long numberSamples;
 
         public Builder(int nodeIndex) {
             this.nodeIndex = nodeIndex;
@@ -320,6 +345,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             return rightChild;
         }
 
+        public Builder setNumberSamples(long numberSamples) {
+            this.numberSamples = numberSamples;
+            return this;
+        }
+
         public void validate() {
             if (nodeIndex < 0) {
                 throw new IllegalArgumentException("[node_index] must be a non-negative integer.");
@@ -351,7 +381,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
                 leafValue, 
                 defaultLeft, 
                 leftChild, 
-                rightChild);
+                rightChild,
+                numberSamples);
         }
     }
 }

+ 2 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java

@@ -54,6 +54,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
     public static TreeNode createRandomLeafNode(double internalValue) {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
+            .setNumberSamples(randomNonNegativeLong())
             .setLeafValue(internalValue)
             .build();
     }
@@ -69,6 +70,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)
+            .setNumberSamples(randomNonNegativeLong())
             .setThreshold(threshold)
             .setOperator(operator)
             .setSplitFeature(randomBoolean() ? null : randomInt())