|  | @@ -91,8 +91,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |      private final List<String> classificationLabels;
 | 
	
		
			
				|  |  |      private final CachedSupplier<Double> highestOrderCategory;
 | 
	
		
			
				|  |  |      // populated lazily when feature importance is calculated
 | 
	
		
			
				|  |  | -    private double[] nodeEstimates;
 | 
	
		
			
				|  |  |      private Integer maxDepth;
 | 
	
		
			
				|  |  | +    private Integer leafSize;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
 | 
	
		
			
				|  |  |          this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
 | 
	
	
		
			
				|  | @@ -137,7 +137,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |              .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
 | 
	
		
			
				|  |  |              .collect(Collectors.toList());
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        Map<String, Double> featureImportance = config.requestingImportance() ?
 | 
	
		
			
				|  |  | +        Map<String, double[]> featureImportance = config.requestingImportance() ?
 | 
	
		
			
				|  |  |              featureImportance(features, featureDecoderMap) :
 | 
	
		
			
				|  |  |              Collections.emptyMap();
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -149,7 +149,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |          return buildResult(node.getLeafValue(), featureImportance, config);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
 | 
	
		
			
				|  |  | +    private InferenceResults buildResult(double[] value, Map<String, double[]> featureImportance, InferenceConfig config) {
 | 
	
		
			
				|  |  |          assert value != null && value.length > 0;
 | 
	
		
			
				|  |  |          // Indicates that the config is useless and the caller just wants the raw value
 | 
	
		
			
				|  |  |          if (config instanceof NullInferenceConfig) {
 | 
	
	
		
			
				|  | @@ -166,10 +166,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |                  return new ClassificationInferenceResults(topClasses.v1(),
 | 
	
		
			
				|  |  |                      classificationLabel(topClasses.v1(), classificationLabels),
 | 
	
		
			
				|  |  |                      topClasses.v2(),
 | 
	
		
			
				|  |  | -                    featureImportance,
 | 
	
		
			
				|  |  | +                    InferenceHelpers.transformFeatureImportance(featureImportance, classificationLabels),
 | 
	
		
			
				|  |  |                      config);
 | 
	
		
			
				|  |  |              case REGRESSION:
 | 
	
		
			
				|  |  | -                return new RegressionInferenceResults(value[0], config, featureImportance);
 | 
	
		
			
				|  |  | +                return new RegressionInferenceResults(value[0],
 | 
	
		
			
				|  |  | +                    config,
 | 
	
		
			
				|  |  | +                    InferenceHelpers.transformFeatureImportance(featureImportance, null));
 | 
	
		
			
				|  |  |              default:
 | 
	
		
			
				|  |  |                  throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -283,7 +285,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  | -    public Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
 | 
	
		
			
				|  |  | +    public Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder) {
 | 
	
		
			
				|  |  |          if (nodes.stream().allMatch(n -> n.getNumberSamples() == 0)) {
 | 
	
		
			
				|  |  |              throw ExceptionsHelper.badRequestException("[tree_structure.number_samples] must be greater than zero for feature importance");
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -293,9 +295,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |          return featureImportance(features, featureDecoder);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private Map<String, Double> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
 | 
	
		
			
				|  |  | -        calculateNodeEstimatesIfNeeded();
 | 
	
		
			
				|  |  | -        double[] featureImportance = new double[fieldValues.size()];
 | 
	
		
			
				|  |  | +    private Map<String, double[]> featureImportance(List<Double> fieldValues, Map<String, String> featureDecoder) {
 | 
	
		
			
				|  |  | +        calculateDepthAndLeafValueSize();
 | 
	
		
			
				|  |  | +        double[][] featureImportance = new double[fieldValues.size()][leafSize];
 | 
	
		
			
				|  |  | +        for (int i = 0; i < fieldValues.size(); i++) {
 | 
	
		
			
				|  |  | +            featureImportance[i] = new double[leafSize];
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |          int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2))/2;
 | 
	
		
			
				|  |  |          ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
 | 
	
		
			
				|  |  |          for (int i = 0; i < arrSize; i++) {
 | 
	
	
		
			
				|  | @@ -303,24 +308,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          double[] scale = new double[arrSize];
 | 
	
		
			
				|  |  |          ShapPath initialPath = new ShapPath(elements, scale);
 | 
	
		
			
				|  |  | -        shapRecursive(fieldValues, this.nodeEstimates, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
 | 
	
		
			
				|  |  | +        shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
 | 
	
		
			
				|  |  |          return InferenceHelpers.decodeFeatureImportances(featureDecoder,
 | 
	
		
			
				|  |  |              IntStream.range(0, featureImportance.length)
 | 
	
		
			
				|  |  |                  .boxed()
 | 
	
		
			
				|  |  |                  .collect(Collectors.toMap(featureNames::get, i -> featureImportance[i])));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private void calculateNodeEstimatesIfNeeded() {
 | 
	
		
			
				|  |  | -        if (this.nodeEstimates != null && this.maxDepth != null) {
 | 
	
		
			
				|  |  | +    private void calculateDepthAndLeafValueSize() {
 | 
	
		
			
				|  |  | +        if (this.maxDepth != null && this.leafSize != null) {
 | 
	
		
			
				|  |  |              return;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |          synchronized (this) {
 | 
	
		
			
				|  |  | -            if (this.nodeEstimates != null && this.maxDepth != null) {
 | 
	
		
			
				|  |  | +            if (this.maxDepth != null && this.leafSize != null) {
 | 
	
		
			
				|  |  |                  return;
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  | -            double[] estimates = new double[nodes.size()];
 | 
	
		
			
				|  |  | -            this.maxDepth = fillNodeEstimates(estimates, 0, 0);
 | 
	
		
			
				|  |  | -            this.nodeEstimates = estimates;
 | 
	
		
			
				|  |  | +            this.maxDepth = getDepth(0, 0);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -331,23 +334,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |       * side first and then ported to the Java side.
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  |      private void shapRecursive(List<Double> processedFeatures,
 | 
	
		
			
				|  |  | -                               double[] nodeValues,
 | 
	
		
			
				|  |  |                                 ShapPath parentSplitPath,
 | 
	
		
			
				|  |  |                                 int nodeIndex,
 | 
	
		
			
				|  |  |                                 double parentFractionZero,
 | 
	
		
			
				|  |  |                                 double parentFractionOne,
 | 
	
		
			
				|  |  |                                 int parentFeatureIndex,
 | 
	
		
			
				|  |  | -                               double[] featureImportance,
 | 
	
		
			
				|  |  | +                               double[][] featureImportance,
 | 
	
		
			
				|  |  |                                 int nextIndex) {
 | 
	
		
			
				|  |  |          ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
 | 
	
		
			
				|  |  |          TreeNode currNode = nodes.get(nodeIndex);
 | 
	
		
			
				|  |  |          nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
 | 
	
		
			
				|  |  |          if (currNode.isLeaf()) {
 | 
	
		
			
				|  |  | -            double leafValue = nodeValues[nodeIndex];
 | 
	
		
			
				|  |  | +            double[] leafValue = currNode.getLeafValue();
 | 
	
		
			
				|  |  |              for (int i = 1; i < nextIndex; ++i) {
 | 
	
		
			
				|  |  | -                double scale = splitPath.sumUnwoundPath(i, nextIndex);
 | 
	
		
			
				|  |  |                  int inputColumnIndex = splitPath.featureIndex(i);
 | 
	
		
			
				|  |  | -                featureImportance[inputColumnIndex] += scale * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i)) * leafValue;
 | 
	
		
			
				|  |  | +                double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
 | 
	
		
			
				|  |  | +                for (int j = 0; j < leafValue.length; j++) {
 | 
	
		
			
				|  |  | +                    featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  |          } else {
 | 
	
		
			
				|  |  |              int hotIndex = currNode.compare(processedFeatures);
 | 
	
	
		
			
				|  | @@ -365,41 +369,32 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |              double hotFractionZero = nodes.get(hotIndex).getNumberSamples() / (double)currNode.getNumberSamples();
 | 
	
		
			
				|  |  |              double coldFractionZero = nodes.get(coldIndex).getNumberSamples() / (double)currNode.getNumberSamples();
 | 
	
		
			
				|  |  | -            shapRecursive(processedFeatures, nodeValues, splitPath,
 | 
	
		
			
				|  |  | +            shapRecursive(processedFeatures, splitPath,
 | 
	
		
			
				|  |  |                  hotIndex, incomingFractionZero * hotFractionZero,
 | 
	
		
			
				|  |  |                  incomingFractionOne, splitFeature, featureImportance, nextIndex);
 | 
	
		
			
				|  |  | -            shapRecursive(processedFeatures, nodeValues, splitPath,
 | 
	
		
			
				|  |  | +            shapRecursive(processedFeatures, splitPath,
 | 
	
		
			
				|  |  |                  coldIndex, incomingFractionZero * coldFractionZero,
 | 
	
		
			
				|  |  |                  0.0, splitFeature, featureImportance, nextIndex);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      /**
 | 
	
		
			
				|  |  | -     * This recursively populates the provided {@code double[]} with the node estimated values
 | 
	
		
			
				|  |  | +     * Get the depth of the tree and sets leafSize if it is null
 | 
	
		
			
				|  |  |       *
 | 
	
		
			
				|  |  | -     * Used when calculating feature importance.
 | 
	
		
			
				|  |  | -     * @param nodeEstimates Array to update in place with the node estimated values
 | 
	
		
			
				|  |  |       * @param nodeIndex Current node index
 | 
	
		
			
				|  |  |       * @param depth Current depth
 | 
	
		
			
				|  |  |       * @return The current max depth
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  | -    private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
 | 
	
		
			
				|  |  | +    private int getDepth(int nodeIndex, int depth) {
 | 
	
		
			
				|  |  |          TreeNode node = nodes.get(nodeIndex);
 | 
	
		
			
				|  |  |          if (node.isLeaf()) {
 | 
	
		
			
				|  |  | -            // TODO multi-value????
 | 
	
		
			
				|  |  | -            nodeEstimates[nodeIndex] = node.getLeafValue()[0];
 | 
	
		
			
				|  |  | +            if (leafSize == null) {
 | 
	
		
			
				|  |  | +                this.leafSize = node.getLeafValue().length;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  |              return 0;
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        int depthLeft = fillNodeEstimates(nodeEstimates, node.getLeftChild(), depth + 1);
 | 
	
		
			
				|  |  | -        int depthRight = fillNodeEstimates(nodeEstimates, node.getRightChild(), depth + 1);
 | 
	
		
			
				|  |  | -        long leftWeight = nodes.get(node.getLeftChild()).getNumberSamples();
 | 
	
		
			
				|  |  | -        long rightWeight = nodes.get(node.getRightChild()).getNumberSamples();
 | 
	
		
			
				|  |  | -        long divisor = leftWeight + rightWeight;
 | 
	
		
			
				|  |  | -        double averageValue = divisor == 0 ?
 | 
	
		
			
				|  |  | -            0.0 :
 | 
	
		
			
				|  |  | -            (leftWeight * nodeEstimates[node.getLeftChild()] + rightWeight * nodeEstimates[node.getRightChild()]) / divisor;
 | 
	
		
			
				|  |  | -        nodeEstimates[nodeIndex] = averageValue;
 | 
	
		
			
				|  |  | +        int depthLeft = getDepth(node.getLeftChild(), depth + 1);
 | 
	
		
			
				|  |  | +        int depthRight = getDepth(node.getRightChild(), depth + 1);
 | 
	
		
			
				|  |  |          return Math.max(depthLeft, depthRight) + 1;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 |