|
|
@@ -6,6 +6,8 @@
|
|
|
|
|
|
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
|
|
|
|
|
|
+import org.apache.logging.log4j.LogManager;
|
|
|
+import org.apache.logging.log4j.Logger;
|
|
|
import org.apache.lucene.util.Accountable;
|
|
|
import org.elasticsearch.common.Nullable;
|
|
|
import org.elasticsearch.common.Numbers;
|
|
|
@@ -28,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
|
|
|
import org.elasticsearch.xpack.core.ml.job.config.Operator;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
@@ -56,6 +59,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNo
|
|
|
|
|
|
public class TreeInferenceModel implements InferenceModel {
|
|
|
|
|
|
+ private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
|
|
|
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
@@ -304,6 +308,7 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
|
|
|
@Override
|
|
|
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
|
|
|
+ LOGGER.debug("rewriting features {}", newFeatureIndexMapping);
|
|
|
if (preparedForInference) {
|
|
|
return;
|
|
|
}
|
|
|
@@ -358,6 +363,20 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
return nodes;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public String toString() {
|
|
|
+ return "TreeInferenceModel{" +
|
|
|
+ "nodes=" + Arrays.toString(nodes) +
|
|
|
+ ", featureNames=" + Arrays.toString(featureNames) +
|
|
|
+ ", targetType=" + targetType +
|
|
|
+ ", classificationLabels=" + classificationLabels +
|
|
|
+ ", highOrderCategory=" + highOrderCategory +
|
|
|
+ ", maxDepth=" + maxDepth +
|
|
|
+ ", leafSize=" + leafSize +
|
|
|
+ ", preparedForInference=" + preparedForInference +
|
|
|
+ '}';
|
|
|
+ }
|
|
|
+
|
|
|
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
|
|
|
Node node = nodes[nodeIndex];
|
|
|
if (node instanceof LeafNode) {
|
|
|
@@ -519,6 +538,19 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
public long ramBytesUsed() {
|
|
|
return SHALLOW_SIZE;
|
|
|
}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String toString() {
|
|
|
+ return "InnerNode{" +
|
|
|
+ "operator=" + operator +
|
|
|
+ ", threshold=" + threshold +
|
|
|
+ ", splitFeature=" + splitFeature +
|
|
|
+ ", defaultLeft=" + defaultLeft +
|
|
|
+ ", leftChild=" + leftChild +
|
|
|
+ ", rightChild=" + rightChild +
|
|
|
+ ", numberSamples=" + numberSamples +
|
|
|
+ '}';
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
public static class LeafNode extends Node {
|
|
|
@@ -544,5 +576,13 @@ public class TreeInferenceModel implements InferenceModel {
|
|
|
public double[] getLeafValue() {
|
|
|
return leafValue;
|
|
|
}
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String toString() {
|
|
|
+ return "LeafNode{" +
|
|
|
+ "leafValue=" + Arrays.toString(leafValue) +
|
|
|
+ ", numberSamples=" + numberSamples +
|
|
|
+ '}';
|
|
|
+ }
|
|
|
}
|
|
|
}
|