|
@@ -91,8 +91,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
private final List<String> classificationLabels;
|
|
|
private final CachedSupplier<Double> highestOrderCategory;
|
|
|
// populated lazily when feature importance is calculated
|
|
|
- private double[] nodeEstimates;
|
|
|
private Integer maxDepth;
|
|
|
+ private Integer leafSize;
|
|
|
|
|
|
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
|
|
|
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
|
|
@@ -137,7 +137,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
.map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
|
|
|
.collect(Collectors.toList());
|
|
|
|
|
|
- Map<String, Double> featureImportance = config.requestingImportance() ?
|
|
|
+ Map<String, double[]> featureImportance = config.requestingImportance() ?
|
|
|
featureImportance(features, featureDecoderMap) :
|
|
|
Collections.emptyMap();
|
|
|
|
|
@@ -149,7 +149,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return buildResult(node.getLeafValue(), featureImportance, config);
|
|
|
}
|
|
|
|
|
|
- private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
|
|
|
+ private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
|
|
|
assert value != null && value.length > 0;
|
|
|
// Indicates that the config is useless and the caller just wants the raw value
|
|
|
if (config instanceof NullInferenceConfig) {
|
|
@@ -166,10 +166,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return new ClassificationInferenceResults(topClasses.v1(),
|
|
|
classificationLabel(topClasses.v1(), classificationLabels),
|
|
|
topClasses.v2(),
|
|
|
- featureImportance,
|
|
|
+ InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
|
|
|
config);
|
|
|
case REGRESSION:
|
|
|
- return new RegressionInferenceResults(value[0], config, featureImportance);
|
|
|
+ return new RegressionInferenceResults(value[0],
|
|
|
+ config,
|
|
|
+ InferenceHelpers.transformFeatureImportance(featureImportance, null));
|
|
|
default:
|
|
|
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
|
|
|
}
|
|
@@ -283,7 +285,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
|
|
|
+ public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
|
|
|
if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
|
|
|
throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
|
|
|
}
|
|
@@ -293,9 +295,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
return featureImportance(features, featureDecoder);
|
|
|
}
|
|
|
|
|
|
- private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
|
|
|
- calculateNodeEstimatesIfNeeded();
|
|
|
- double[] featureImportance = new double[fieldValues.size()];
|
|
|
+ private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
|
|
|
+ calculateDepthAndLeafValueSize();
|
|
|
+ double[][] featureImportance = new double[fieldValues.size()][leafSize];
|
|
|
+ for (int i = 0; i < fieldValues.size(); i++) {
|
|
|
+ featureImportance[i] = new double[leafSize];
|
|
|
+ }
|
|
|
int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
|
|
|
ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
|
|
|
for (int i = 0; i < arrSize; i++) {
|
|
@@ -303,24 +308,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
}
|
|
|
double[] scale = new double[arrSize];
|
|
|
ShapPath initialPath = new ShapPath(elements, scale);
|
|
|
- shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
|
|
|
+ shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
|
|
|
return InferenceHelpers.decodeFeatureImportances(featureDecoder,
|
|
|
IntStream.range(0, featureImportance.length)
|
|
|
.boxed()
|
|
|
.collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
|
|
|
}
|
|
|
|
|
|
- private void calculateNodeEstimatesIfNeeded() {
|
|
|
- if (this.nodeEstimates != null && this.maxDepth != null) {
|
|
|
+ private void calculateDepthAndLeafValueSize() {
|
|
|
+ if (this.maxDepth != null && this.leafSize != null) {
|
|
|
return;
|
|
|
}
|
|
|
synchronized (this) {
|
|
|
- if (this.nodeEstimates != null && this.maxDepth != null) {
|
|
|
+ if (this.maxDepth != null && this.leafSize != null) {
|
|
|
return;
|
|
|
}
|
|
|
- double[] estimates = new double[nodes.size()];
|
|
|
- this.maxDepth = fillNodeEstimates(estimates, 0, 0);
|
|
|
- this.nodeEstimates = estimates;
|
|
|
+ this.maxDepth = getDepth(0, 0);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -331,23 +334,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
* side first and then ported to the Java side.
|
|
|
*/
|
|
|
private void shapRecursive(List<Double> processedFeatures,
|
|
|
- double[] nodeValues,
|
|
|
ShapPath parentSplitPath,
|
|
|
int nodeIndex,
|
|
|
double parentFractionZero,
|
|
|
double parentFractionOne,
|
|
|
int parentFeatureIndex,
|
|
|
- double[] featureImportance,
|
|
|
+ double[][] featureImportance,
|
|
|
int nextIndex) {
|
|
|
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
|
|
|
TreeNode currNode = nodes.get(nodeIndex);
|
|
|
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
|
|
|
if (currNode.isLeaf()) {
|
|
|
- double leafValue = nodeValues[nodeIndex];
|
|
|
+ double[] leafValue = currNode.getLeafValue();
|
|
|
for (int i = 1; i < nextIndex; ++i) {
|
|
|
- double scale = splitPath.sumUnwoundPath(i, nextIndex);
|
|
|
int inputColumnIndex = splitPath.featureIndex(i);
|
|
|
- featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
|
|
|
+ double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
|
|
|
+ for (int j = 0; j < leafValue.length; j++) {
|
|
|
+ featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
|
|
|
+ }
|
|
|
}
|
|
|
} else {
|
|
|
int hotIndex = currNode.compare(processedFeatures);
|
|
@@ -365,41 +369,32 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
|
|
|
|
|
|
double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
|
|
|
double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
|
|
|
- shapRecursive(processedFeatures, nodeValues, splitPath,
|
|
|
+ shapRecursive(processedFeatures, splitPath,
|
|
|
hotIndex, incomingFractionZero * hotFractionZero,
|
|
|
incomingFractionOne, splitFeature, featureImportance, nextIndex);
|
|
|
- shapRecursive(processedFeatures, nodeValues, splitPath,
|
|
|
+ shapRecursive(processedFeatures, splitPath,
|
|
|
coldIndex, incomingFractionZero * coldFractionZero,
|
|
|
0.0, splitFeature, featureImportance, nextIndex);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
- * This recursively populates the provided {@code double[]} with the node estimated values
|
|
|
+ * Get the depth of the tree and sets leafSize if it is null
|
|
|
*
|
|
|
- * Used when calculating feature importance.
|
|
|
- * @param nodeEstimates Array to update in place with the node estimated values
|
|
|
* @param nodeIndex Current node index
|
|
|
* @param depth Current depth
|
|
|
* @return The current max depth
|
|
|
*/
|
|
|
- private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
|
|
|
+ private int getDepth(int nodeIndex, int depth) {
|
|
|
TreeNode node = nodes.get(nodeIndex);
|
|
|
if (node.isLeaf()) {
|
|
|
- // TODO multi-value????
|
|
|
- nodeEstimates[nodeIndex] = node.getLeafValue()[0];
|
|
|
+ if (leafSize == null) {
|
|
|
+ this.leafSize = node.getLeafValue().length;
|
|
|
+ }
|
|
|
return 0;
|
|
|
}
|
|
|
-
|
|
|
- int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
|
|
|
- int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
|
|
|
- long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
|
|
|
- long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
|
|
|
- long divisor = leftWeight + rightWeight;
|
|
|
- double averageValue = divisor == 0 ?
|
|
|
- 0.0 :
|
|
|
- (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
|
|
|
- nodeEstimates[nodeIndex] = averageValue;
|
|
|
+ int depthLeft = getDepth(node.getLeftChild(), depth + 1);
|
|
|
+ int depthRight = getDepth(node.getRightChild(), depth + 1);
|
|
|
return Math.max(depthLeft, depthRight) + 1;
|
|
|
}
|
|
|
|