|  | @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import org.apache.lucene.util.Accountable;
 | 
	
		
			
				|  |  |  import org.apache.lucene.util.RamUsageEstimator;
 | 
	
		
			
				|  |  | +import org.elasticsearch.Version;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.Numbers;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.ParseField;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.Strings;
 | 
	
	
		
			
				|  | @@ -38,6 +39,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |      public static final ParseField NODE_INDEX = new ParseField("node_index");
 | 
	
		
			
				|  |  |      public static final ParseField SPLIT_GAIN = new ParseField("split_gain");
 | 
	
		
			
				|  |  |      public static final ParseField LEAF_VALUE = new ParseField("leaf_value");
 | 
	
		
			
				|  |  | +    public static final ParseField NUMBER_SAMPLES = new ParseField("number_samples");
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private static final ObjectParser<TreeNode.Builder, Void> LENIENT_PARSER = createParser(true);
 | 
	
		
			
				|  |  |      private static final ObjectParser<TreeNode.Builder, Void> STRICT_PARSER = createParser(false);
 | 
	
	
		
			
				|  | @@ -59,6 +61,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
 | 
	
		
			
				|  |  |          parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
 | 
	
		
			
				|  |  |          parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
 | 
	
		
			
				|  |  | +        parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
 | 
	
		
			
				|  |  |          return parser;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -75,6 +78,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |      private final boolean defaultLeft;
 | 
	
		
			
				|  |  |      private final int leftChild;
 | 
	
		
			
				|  |  |      private final int rightChild;
 | 
	
		
			
				|  |  | +    private final long numberSamples;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private TreeNode(Operator operator,
 | 
	
	
		
			
				|  | @@ -85,7 +89,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |                       Double leafValue,
 | 
	
		
			
				|  |  |                       Boolean defaultLeft,
 | 
	
		
			
				|  |  |                       Integer leftChild,
 | 
	
		
			
				|  |  | -                     Integer rightChild) {
 | 
	
		
			
				|  |  | +                     Integer rightChild,
 | 
	
		
			
				|  |  | +                     long numberSamples) {
 | 
	
		
			
				|  |  |          this.operator = operator == null ? Operator.LTE : operator;
 | 
	
		
			
				|  |  |          this.threshold  = threshold == null ? Double.NaN : threshold;
 | 
	
		
			
				|  |  |          this.splitFeature = splitFeature == null ? -1 : splitFeature;
 | 
	
	
		
			
				|  | @@ -95,6 +100,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          this.defaultLeft = defaultLeft == null ? false : defaultLeft;
 | 
	
		
			
				|  |  |          this.leftChild  = leftChild == null ? -1 : leftChild;
 | 
	
		
			
				|  |  |          this.rightChild = rightChild == null ? -1 : rightChild;
 | 
	
		
			
				|  |  | +        if (numberSamples < 0) {
 | 
	
		
			
				|  |  | +            throw new IllegalArgumentException("[" + NUMBER_SAMPLES.getPreferredName() + "] must be greater than or equal to 0");
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        this.numberSamples = numberSamples;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public TreeNode(StreamInput in) throws IOException {
 | 
	
	
		
			
				|  | @@ -107,6 +116,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          defaultLeft = in.readBoolean();
 | 
	
		
			
				|  |  |          leftChild = in.readInt();
 | 
	
		
			
				|  |  |          rightChild = in.readInt();
 | 
	
		
			
				|  |  | +        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
 | 
	
		
			
				|  |  | +            this.numberSamples = in.readVLong();
 | 
	
		
			
				|  |  | +        } else {
 | 
	
		
			
				|  |  | +            this.numberSamples = 0L;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -150,6 +164,10 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          return leftChild < 0;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    public long getNumberSamples() {
 | 
	
		
			
				|  |  | +        return numberSamples;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      public int compare(List<Double> features) {
 | 
	
		
			
				|  |  |          if (isLeaf()) {
 | 
	
		
			
				|  |  |              throw new IllegalArgumentException("cannot call compare against a leaf node.");
 | 
	
	
		
			
				|  | @@ -176,6 +194,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          out.writeBoolean(defaultLeft);
 | 
	
		
			
				|  |  |          out.writeInt(leftChild);
 | 
	
		
			
				|  |  |          out.writeInt(rightChild);
 | 
	
		
			
				|  |  | +        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
 | 
	
		
			
				|  |  | +            out.writeVLong(numberSamples);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
	
		
			
				|  | @@ -196,6 +217,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          if (rightChild >= 0) {
 | 
	
		
			
				|  |  |              builder.field(RIGHT_CHILD.getPreferredName(), rightChild);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | +        builder.field(NUMBER_SAMPLES.getPreferredName(), numberSamples);
 | 
	
		
			
				|  |  |          builder.endObject();
 | 
	
		
			
				|  |  |          return builder;
 | 
	
		
			
				|  |  |      }
 | 
	
	
		
			
				|  | @@ -219,7 +241,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |              && Objects.equals(leafValue, that.leafValue)
 | 
	
		
			
				|  |  |              && Objects.equals(defaultLeft, that.defaultLeft)
 | 
	
		
			
				|  |  |              && Objects.equals(leftChild, that.leftChild)
 | 
	
		
			
				|  |  | -            && Objects.equals(rightChild, that.rightChild);
 | 
	
		
			
				|  |  | +            && Objects.equals(rightChild, that.rightChild)
 | 
	
		
			
				|  |  | +            && Objects.equals(numberSamples, that.numberSamples);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
	
		
			
				|  | @@ -232,7 +255,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |              leafValue,
 | 
	
		
			
				|  |  |              defaultLeft,
 | 
	
		
			
				|  |  |              leftChild,
 | 
	
		
			
				|  |  | -            rightChild);
 | 
	
		
			
				|  |  | +            rightChild,
 | 
	
		
			
				|  |  | +            numberSamples);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
	
		
			
				|  | @@ -259,6 +283,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |          private Boolean defaultLeft;
 | 
	
		
			
				|  |  |          private Integer leftChild;
 | 
	
		
			
				|  |  |          private Integer rightChild;
 | 
	
		
			
				|  |  | +        private long numberSamples;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          public Builder(int nodeIndex) {
 | 
	
		
			
				|  |  |              this.nodeIndex = nodeIndex;
 | 
	
	
		
			
				|  | @@ -320,6 +345,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |              return rightChild;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        public Builder setNumberSamples(long numberSamples) {
 | 
	
		
			
				|  |  | +            this.numberSamples = numberSamples;
 | 
	
		
			
				|  |  | +            return this;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          public void validate() {
 | 
	
		
			
				|  |  |              if (nodeIndex < 0) {
 | 
	
		
			
				|  |  |                  throw new IllegalArgumentException("[node_index] must be a non-negative integer.");
 | 
	
	
		
			
				|  | @@ -351,7 +381,8 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 | 
	
		
			
				|  |  |                  leafValue, 
 | 
	
		
			
				|  |  |                  defaultLeft, 
 | 
	
		
			
				|  |  |                  leftChild, 
 | 
	
		
			
				|  |  | -                rightChild);
 | 
	
		
			
				|  |  | +                rightChild,
 | 
	
		
			
				|  |  | +                numberSamples);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 |