Browse Source

[ML][Inference] Add support for multi-value leaves to the tree model (#52531)

This adds support for multi-value leaves. This is a prerequisite for multi-class boosted tree classification.
Benjamin Trent 5 years ago
parent
commit
e39eade20d
26 changed files with 575 additions and 197 deletions
  1. 2 1
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java
  2. 7 6
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java
  3. 3 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java
  4. 18 10
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
  5. 12 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
  6. 5 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
  7. 17 11
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
  8. 24 16
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java
  9. 5 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
  10. 36 20
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java
  11. 8 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
  12. 1 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
  13. 67 14
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
  14. 43 13
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
  15. 19 19
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
  16. 24 11
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java
  17. 38 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java
  18. 2 5
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java
  19. 52 12
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
  20. 7 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
  21. 2 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java
  22. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
  23. 6 6
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
  24. 7 7
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
  25. 12 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
  26. 157 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

+ 2 - 1
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java

@@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -225,7 +226,7 @@ public class Tree implements TrainedModel {
             for (int i = nodes.size(); i < nodeIndex + 1; i++) {
                 nodes.add(null);
             }
-            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
+            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(Collections.singletonList(value)));
             return this;
         }
 

+ 7 - 6
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNode.java

@@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 
 public class TreeNode implements ToXContentObject {
@@ -61,7 +62,7 @@ public class TreeNode implements ToXContentObject {
         PARSER.declareInt(Builder::setSplitFeature, SPLIT_FEATURE);
         PARSER.declareInt(Builder::setNodeIndex, NODE_INDEX);
         PARSER.declareDouble(Builder::setSplitGain, SPLIT_GAIN);
-        PARSER.declareDouble(Builder::setLeafValue, LEAF_VALUE);
+        PARSER.declareDoubleArray(Builder::setLeafValue, LEAF_VALUE);
         PARSER.declareLong(Builder::setNumberSamples, NUMBER_SAMPLES);
     }
 
@@ -74,7 +75,7 @@ public class TreeNode implements ToXContentObject {
     private final Integer splitFeature;
     private final int nodeIndex;
     private final Double splitGain;
-    private final Double leafValue;
+    private final List<Double> leafValue;
     private final Boolean defaultLeft;
     private final Integer leftChild;
     private final Integer rightChild;
@@ -86,7 +87,7 @@ public class TreeNode implements ToXContentObject {
              Integer splitFeature,
              int nodeIndex,
              Double splitGain,
-             Double leafValue,
+             List<Double> leafValue,
              Boolean defaultLeft,
              Integer leftChild,
              Integer rightChild,
@@ -123,7 +124,7 @@ public class TreeNode implements ToXContentObject {
         return splitGain;
     }
 
-    public Double getLeafValue() {
+    public List<Double> getLeafValue() {
         return leafValue;
     }
 
@@ -212,7 +213,7 @@ public class TreeNode implements ToXContentObject {
         private Integer splitFeature;
         private int nodeIndex;
         private Double splitGain;
-        private Double leafValue;
+        private List<Double> leafValue;
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
@@ -250,7 +251,7 @@ public class TreeNode implements ToXContentObject {
             return this;
         }
 
-        public Builder setLeafValue(Double leafValue) {
+        public Builder setLeafValue(List<Double> leafValue) {
             this.leafValue = leafValue;
             return this;
         }

+ 3 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeNodeTests.java

@@ -23,6 +23,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
 
 import java.io.IOException;
+import java.util.Collections;
 
 public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
 
@@ -48,7 +49,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
     public static TreeNode createRandomLeafNode(double internalValue) {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
-            .setLeafValue(internalValue)
+            .setLeafValue(Collections.singletonList(internalValue))
             .setNumberSamples(randomNonNegativeLong())
             .build();
     }
@@ -60,7 +61,7 @@ public class TreeNodeTests extends AbstractXContentTestCase<TreeNode> {
                                                 Integer featureIndex,
                                                 Operator operator) {
         return TreeNode.builder(nodeIndex)
-            .setLeafValue(left == null ? randomDouble() : null)
+            .setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)

+ 18 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

@@ -5,29 +5,37 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
-import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.ingest.IngestDocument;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Map;
 import java.util.Objects;
 
-public class RawInferenceResults extends SingleValueInferenceResults {
+public class RawInferenceResults implements InferenceResults {
 
     public static final String NAME = "raw";
 
-    public RawInferenceResults(double value, Map<String, Double> featureImportance) {
-        super(value, featureImportance);
+    private final double[] value;
+    private final Map<String, Double> featureImportance;
+
+    public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
+        this.value = value;
+        this.featureImportance = featureImportance;
+    }
+
+    public double[] getValue() {
+        return value;
     }
 
-    public RawInferenceResults(StreamInput in) throws IOException {
-        super(in);
+    public Map<String, Double> getFeatureImportance() {
+        return featureImportance;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
+        throw new UnsupportedOperationException("[raw] does not support wire serialization");
     }
 
     @Override
@@ -35,13 +43,13 @@ public class RawInferenceResults extends SingleValueInferenceResults {
         if (object == this) { return true; }
         if (object == null || getClass() != object.getClass()) { return false; }
         RawInferenceResults that = (RawInferenceResults) object;
-        return Objects.equals(value(), that.value())
-            && Objects.equals(getFeatureImportance(), that.getFeatureImportance());
+        return Arrays.equals(value, that.value)
+            && Objects.equals(featureImportance, that.featureImportance);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(value(), getFeatureImportance());
+        return Objects.hash(Arrays.hashCode(value), featureImportance);
     }
 
     @Override

+ 12 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -26,30 +26,29 @@ public final class InferenceHelpers {
     /**
      * @return Tuple of the highest scored index and the top classes
      */
-    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
+    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
                                                                                                 List<String> classificationLabels,
                                                                                                 @Nullable double[] classificationWeights,
                                                                                                 int numToInclude) {
 
-        if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
+        if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
             throw ExceptionsHelper
                 .serverError(
                     "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
                     null,
-                    probabilities.size(),
+                    probabilities.length,
                     classificationLabels.size());
         }
 
-        List<Double> scores = classificationWeights == null ?
+        double[] scores = classificationWeights == null ?
             probabilities :
-            IntStream.range(0, probabilities.size())
-                .mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
-                .boxed()
-                .collect(Collectors.toList());
+            IntStream.range(0, probabilities.length)
+                .mapToDouble(i -> probabilities[i] * classificationWeights[i])
+                .toArray();
 
-        int[] sortedIndices = IntStream.range(0, probabilities.size())
+        int[] sortedIndices = IntStream.range(0, scores.length)
             .boxed()
-            .sorted(Comparator.comparing(scores::get).reversed())
+            .sorted(Comparator.comparing(i -> scores[(Integer)i]).reversed())
             .mapToInt(i -> i)
             .toArray();
 
@@ -59,14 +58,14 @@ public final class InferenceHelpers {
 
         List<String> labels = classificationLabels == null ?
             // If we don't have the labels we should return the top classification values anyways, they will just be numeric
-            IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
+            IntStream.range(0, probabilities.length).boxed().map(String::valueOf).collect(Collectors.toList()) :
             classificationLabels;
 
-        int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
+        int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
         List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
         for(int i = 0; i < count; i++) {
             int idx = sortedIndices[i];
-            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
+            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities[idx], scores[idx]));
         }
 
         return Tuple.tuple(sortedIndices[0], topClassEntries);

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.apache.lucene.util.Accountable;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -62,4 +63,8 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
      * @return A {@code Map<String, Double>} mapping each featureName to its importance
      */
     Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
+
+    default Version getMinimalCompatibilityVersion() {
+        return Version.V_7_6_0;
+    }
 }

+ 17 - 11
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.collect.Tuple;
@@ -20,7 +21,6 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -139,19 +139,20 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
             throw ExceptionsHelper.badRequestException(
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
-        List<Double> inferenceResults = new ArrayList<>(this.models.size());
+        double[][] inferenceResults = new double[this.models.size()][];
         List<Map<String, Double>> featureInfluence = new ArrayList<>();
+        int i = 0;
         NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
-        this.models.forEach(model -> {
+        for (TrainedModel model : models) {
             InferenceResults result = model.infer(fields, subModelInferenceConfig, Collections.emptyMap());
-            assert result instanceof SingleValueInferenceResults;
-            SingleValueInferenceResults inferenceResult = (SingleValueInferenceResults) result;
-            inferenceResults.add(inferenceResult.value());
+            assert result instanceof RawInferenceResults;
+            RawInferenceResults inferenceResult = (RawInferenceResults) result;
+            inferenceResults[i++] = inferenceResult.getValue();
             if (config.requestingImportance()) {
                 featureInfluence.add(inferenceResult.getFeatureImportance());
             }
-        });
-        List<Double> processed = outputAggregator.processValues(inferenceResults);
+        }
+        double[] processed = outputAggregator.processValues(inferenceResults);
         return buildResults(processed, featureInfluence, config, featureDecoderMap);
     }
 
@@ -160,13 +161,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return targetType;
     }
 
-    private InferenceResults buildResults(List<Double> processedInferences,
+    private InferenceResults buildResults(double[] processedInferences,
                                           List<Map<String, Double>> featureInfluence,
                                           InferenceConfig config,
                                           Map<String, String> featureDecoderMap) {
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
-            return new RawInferenceResults(outputAggregator.aggregate(processedInferences),
+            return new RawInferenceResults(new double[] {outputAggregator.aggregate(processedInferences)},
                 InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
         }
         switch(targetType) {
@@ -176,7 +177,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
                     InferenceHelpers.decodeFeatureImportances(featureDecoderMap, mergeFeatureImportances(featureInfluence)));
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
-                assert classificationWeights == null || processedInferences.size() == classificationWeights.length;
+                assert classificationWeights == null || processedInferences.length == classificationWeights.length;
                 // Adjust the probabilities according to the thresholds
                 Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     processedInferences,
@@ -356,6 +357,11 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
         return Collections.unmodifiableCollection(accountables);
     }
 
+    @Override
+    public Version getMinimalCompatibilityVersion() {
+        return models.stream().map(TrainedModel::getMinimalCompatibilityVersion).max(Version::compareTo).orElse(Version.V_7_6_0);
+    }
+
     public static class Builder {
         private List<String> featureNames;
         private List<TrainedModel> trainedModels;

+ 24 - 16
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java

@@ -19,9 +19,9 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.stream.IntStream;
 
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid;
+import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
 
 public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
 
@@ -78,31 +78,39 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
     }
 
     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (weights != null && values.size() != weights.length) {
+        if (weights != null && values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
-        double summation = weights == null ?
-            values.stream().mapToDouble(Double::valueOf).sum() :
-            IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
-        double probOfClassOne = sigmoid(summation);
+        double[] sumOnAxis1 = new double[values[0].length];
+        for (int j = 0; j < values.length; j++) {
+            double[] value = values[j];
+            double weight = weights == null ? 1.0 : weights[j];
+            for(int i = 0; i < value.length; i++) {
+                if (i >= sumOnAxis1.length) {
+                    throw new IllegalArgumentException("value entries must have the same dimensions");
+                }
+                sumOnAxis1[i] += (value[i] * weight);
+            }
+        }
+        if (sumOnAxis1.length > 1) {
+            return softMax(sumOnAxis1);
+        }
+
+        double probOfClassOne = sigmoid(sumOnAxis1[0]);
         assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
-        return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
+        return new double[] {1.0 - probOfClassOne, probOfClassOne};
     }
 
     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
-        assert values.size() == 2;
         int bestValue = 0;
         double bestProb = Double.NEGATIVE_INFINITY;
-        for (int i = 0; i < values.size(); i++) {
-            if (values.get(i) == null) {
-                throw new IllegalArgumentException("values must not contain null values");
-            }
-            if (values.get(i) > bestProb) {
-                bestProb = values.get(i);
+        for (int i = 0; i < values.length; i++) {
+            if (values[i] > bestProb) {
+                bestProb = values[i];
                 bestValue = i;
             }
         }

+ 5 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java

@@ -10,8 +10,6 @@ import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
-import java.util.List;
-
 public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {
 
     /**
@@ -20,15 +18,15 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
     Integer expectedValueSize();
 
     /**
-     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
+     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(double[])} method.
      *
      * Two major types of pre-processed values could be returned:
-     *   - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
-     *   - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
+     *   - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(double[][])}
+     *   - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(double[][])}
      * @param values the values to process
      * @return A new list containing the processed values or the same list if no processing is required
      */
-    List<Double> processValues(List<Double> values);
+    double[] processValues(double[][] values);
 
     /**
      * Function to aggregate the processed values into a single double
@@ -40,7 +38,7 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable, A
      * @param processedValues The values to aggregate
      * @return the aggregated value.
      */
-    double aggregate(List<Double> processedValues);
+    double aggregate(double[] processedValues);
 
     /**
      * @return The name of the output aggregator

+ 36 - 20
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java

@@ -89,21 +89,37 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
     }
 
     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (weights != null && values.size() != weights.length) {
+        if (weights != null && values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
+        // Multiple leaf values
+        if (values[0].length > 1) {
+            double[] sumOnAxis1 = new double[values[0].length];
+            for (int j = 0; j < values.length; j++) {
+                double[] value = values[j];
+                double weight = weights == null ? 1.0 : weights[j];
+                for(int i = 0; i < value.length; i++) {
+                    if (i >= sumOnAxis1.length) {
+                        throw new IllegalArgumentException("value entries must have the same dimensions");
+                    }
+                    sumOnAxis1[i] += (value[i] * weight);
+                }
+            }
+            return softMax(sumOnAxis1);
+        }
+        // Singular leaf values
         List<Integer> freqArray = new ArrayList<>();
-        Integer maxVal = 0;
-        for (Double value : values) {
-            if (value == null) {
-                throw new IllegalArgumentException("values must not contain null values");
+        int maxVal = 0;
+        for (double[] value : values) {
+            if (value.length != 1) {
+                throw new IllegalArgumentException("value entries must have the same dimensions");
             }
-            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
+            if (Double.isNaN(value[0]) || Double.isInfinite(value[0]) || value[0] < 0.0 || value[0] != Math.rint(value[0])) {
                 throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
             }
-            Integer integerValue = value.intValue();
+            int integerValue = Double.valueOf(value[0]).intValue();
             freqArray.add(integerValue);
             if (integerValue > maxVal) {
                 maxVal = integerValue;
@@ -112,27 +128,27 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
         if (maxVal >= numClasses) {
             throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
         }
-        List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
+        double[] frequencies = Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY)
+            .stream()
+            .mapToDouble(Double::doubleValue)
+            .toArray();
         for (int i = 0; i < freqArray.size(); i++) {
-            Double weight = weights == null ? 1.0 : weights[i];
-            Integer value = freqArray.get(i);
-            Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
-            frequencies.set(value, frequency);
+            double weight = weights == null ? 1.0 : weights[i];
+            int value = freqArray.get(i);
+            double frequency = frequencies[value] == Double.NEGATIVE_INFINITY ? weight : frequencies[value] + weight;
+            frequencies[value] = frequency;
         }
         return softMax(frequencies);
     }
 
     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
         int bestValue = 0;
         double bestFreq = Double.NEGATIVE_INFINITY;
-        for (int i = 0; i < values.size(); i++) {
-            if (values.get(i) == null) {
-                throw new IllegalArgumentException("values must not contain null values");
-            }
-            if (values.get(i) > bestFreq) {
-                bestFreq = values.get(i);
+        for (int i = 0; i < values.length; i++) {
+            if (values[i] > bestFreq) {
+                bestFreq = values[i];
                 bestValue = i;
             }
         }

+ 8 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java

@@ -19,8 +19,6 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
@@ -73,28 +71,25 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
     }
 
     @Override
-    public List<Double> processValues(List<Double> values) {
+    public double[] processValues(double[][] values) {
         Objects.requireNonNull(values, "values must not be null");
+        assert values[0].length == 1;
         if (weights == null) {
-            return values;
+            return Arrays.stream(values).mapToDouble(v -> v[0]).toArray();
         }
-        if (values.size() != weights.length) {
+        if (values.length != weights.length) {
             throw new IllegalArgumentException("values must be the same length as weights.");
         }
-        return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
+        return IntStream.range(0, weights.length).mapToDouble(i -> values[i][0] * weights[i]).toArray();
     }
 
     @Override
-    public double aggregate(List<Double> values) {
+    public double aggregate(double[] values) {
         Objects.requireNonNull(values, "values must not be null");
-        if (values.isEmpty()) {
+        if (values.length == 0) {
             throw new IllegalArgumentException("values must not be empty");
         }
-        Optional<Double> summation = values.stream().reduce(Double::sum);
-        if (summation.isPresent()) {
-            return summation.get();
-        }
-        throw new IllegalArgumentException("values must not contain null values");
+        return Arrays.stream(values).sum();
     }
 
     @Override

+ 1 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -30,7 +30,6 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.stream.Collectors;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
@@ -130,7 +129,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         double[] h0 = hiddenLayer.productPlusBias(false, embeddedVector);
         double[] scores = softmaxLayer.productPlusBias(true, h0);
 
-        List<Double> probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));
+        double[] probabilities = softMax(scores);
 
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
         Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(

+ 67 - 14
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.Accountables;
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.Tuple;
@@ -29,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
@@ -100,7 +102,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         this.nodes = Collections.unmodifiableList(nodes);
         this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
         this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
-        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
+        this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
     }
 
     public Tree(StreamInput in) throws IOException {
@@ -112,7 +114,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         } else {
             this.classificationLabels = null;
         }
-        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
+        this.highestOrderCategory = new CachedSupplier<>(this::maxLeafValue);
     }
 
     @Override
@@ -147,7 +149,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return buildResult(node.getLeafValue(), featureImportance, config);
     }
 
-    private InferenceResults buildResult(Double value, Map<String, Double> featureImportance, InferenceConfig config) {
+    private InferenceResults buildResult(double[] value, Map<String, Double> featureImportance, InferenceConfig config) {
+        assert value != null && value.length > 0;
         // Indicates that the config is useless and the caller just wants the raw value
         if (config instanceof NullInferenceConfig) {
             return new RawInferenceResults(value, featureImportance);
@@ -160,13 +163,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                     classificationLabels,
                     null,
                     classificationConfig.getNumTopClasses());
-                return new ClassificationInferenceResults(value,
+                return new ClassificationInferenceResults(topClasses.v1(),
                     classificationLabel(topClasses.v1(), classificationLabels),
                     topClasses.v2(),
                     featureImportance,
                     config);
             case REGRESSION:
-                return new RegressionInferenceResults(value, config, featureImportance);
+                return new RegressionInferenceResults(value[0], config, featureImportance);
             default:
                 throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
         }
@@ -193,14 +196,22 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return targetType;
     }
 
-    private List<Double> classificationProbability(double inferenceValue) {
+    private double[] classificationProbability(double[] inferenceValue) {
+        // Multi-value leaves, indicates that the leaves contain an array of values.
+        // The index of which corresponds to classification values
+        if (inferenceValue.length > 1) {
+            return Statistics.softMax(inferenceValue);
+        }
         // If we are classification, we should assume that the inference return value is whole.
-        assert inferenceValue == Math.rint(inferenceValue);
+        assert inferenceValue[0] == Math.rint(inferenceValue[0]);
         double maxCategory = this.highestOrderCategory.get();
         // If we are classification, we should assume that the largest leaf value is whole.
         assert maxCategory == Math.rint(maxCategory);
-        List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
-        list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
+        double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)
+            .stream()
+            .mapToDouble(Double::doubleValue)
+            .toArray();
+        list[Double.valueOf(inferenceValue[0]).intValue()] = 1.0;
         return list;
     }
 
@@ -268,6 +279,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         checkTargetType();
         detectMissingNodes();
         detectCycle();
+        verifyLeafNodeUniformity();
     }
 
     @Override
@@ -331,7 +343,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         TreeNode currNode = nodes.get(nodeIndex);
         nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
         if (currNode.isLeaf()) {
-            // TODO multi-value????
             double leafValue = nodeValues[nodeIndex];
             for (int i = 1; i < nextIndex; ++i) {
                 double scale = splitPath.sumUnwoundPath(i, nextIndex);
@@ -375,7 +386,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     private int fillNodeEstimates(double[] nodeEstimates, int nodeIndex, int depth) {
         TreeNode node = nodes.get(nodeIndex);
         if (node.isLeaf()) {
-            nodeEstimates[nodeIndex] = node.getLeafValue();
+            // TODO multi-value????
+            nodeEstimates[nodeIndex] = node.getLeafValue()[0];
             return 0;
         }
 
@@ -424,6 +436,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             throw ExceptionsHelper.badRequestException(
                 "[target_type] should be [classification] if [classification_labels] are provided");
         }
+        if (this.targetType != TargetType.CLASSIFICATION && this.nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
+            throw ExceptionsHelper.badRequestException(
+                "[target_type] should be [classification] if leaf nodes have multiple values");
+        }
     }
 
     private void detectCycle() {
@@ -465,14 +481,39 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         }
     }
 
+    private void verifyLeafNodeUniformity() {
+        Integer leafValueLengths = null;
+        for (TreeNode node : nodes) {
+            if (node.isLeaf()) {
+                if (leafValueLengths == null) {
+                    leafValueLengths = node.getLeafValue().length;
+                } else if (leafValueLengths != node.getLeafValue().length) {
+                    throw ExceptionsHelper.badRequestException(
+                        "[tree.tree_structure] all leaf nodes must have the same number of values");
+                }
+            }
+        }
+    }
+
     private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
         return nodeIdx >= nodes.size();
     }
 
     private Double maxLeafValue() {
-        return targetType == TargetType.CLASSIFICATION ?
-            this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
-            null;
+        if (targetType != TargetType.CLASSIFICATION) {
+            return null;
+        }
+        double max = 0.0;
+        for (TreeNode node : this.nodes) {
+            if (node.isLeaf()) {
+                if (node.getLeafValue().length > 1) {
+                    return (double)node.getLeafValue().length;
+                } else {
+                    max = Math.max(node.getLeafValue()[0], max);
+                }
+            }
+        }
+        return max;
     }
 
     @Override
@@ -493,6 +534,14 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return Collections.unmodifiableCollection(accountables);
     }
 
+    @Override
+    public Version getMinimalCompatibilityVersion() {
+        if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
+            return Version.V_7_7_0;
+        }
+        return Version.V_7_6_0;
+    }
+
     public static class Builder {
         private List<String> featureNames;
         private ArrayList<TreeNode.Builder> nodes;
@@ -586,6 +635,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
          * @return this
          */
         Tree.Builder addLeaf(int nodeIndex, double value) {
+            return addLeaf(nodeIndex, Arrays.asList(value));
+        }
+
+        Tree.Builder addLeaf(int nodeIndex, List<Double> value) {
             for (int i = nodes.size(); i < nodeIndex + 1; i++) {
                 nodes.add(null);
             }

+ 43 - 13
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java

@@ -21,6 +21,8 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
@@ -60,7 +62,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         parser.declareInt(TreeNode.Builder::setSplitFeature, SPLIT_FEATURE);
         parser.declareInt(TreeNode.Builder::setNodeIndex, NODE_INDEX);
         parser.declareDouble(TreeNode.Builder::setSplitGain, SPLIT_GAIN);
-        parser.declareDouble(TreeNode.Builder::setLeafValue, LEAF_VALUE);
+        parser.declareDoubleArray(TreeNode.Builder::setLeafValue, LEAF_VALUE);
         parser.declareLong(TreeNode.Builder::setNumberSamples, NUMBER_SAMPLES);
         return parser;
     }
@@ -74,7 +76,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
     private final int splitFeature;
     private final int nodeIndex;
     private final double splitGain;
-    private final double leafValue;
+    private final double[] leafValue;
     private final boolean defaultLeft;
     private final int leftChild;
     private final int rightChild;
@@ -86,7 +88,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
                      Integer splitFeature,
                      int nodeIndex,
                      Double splitGain,
-                     Double leafValue,
+                     List<Double> leafValue,
                      Boolean defaultLeft,
                      Integer leftChild,
                      Integer rightChild,
@@ -96,7 +98,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         this.splitFeature = splitFeature == null ? -1 : splitFeature;
         this.nodeIndex = nodeIndex;
         this.splitGain  = splitGain == null ? Double.NaN : splitGain;
-        this.leafValue = leafValue == null ? Double.NaN : leafValue;
+        this.leafValue = leafValue == null ? new double[0] : leafValue.stream().mapToDouble(Double::doubleValue).toArray();
         this.defaultLeft = defaultLeft == null ? false : defaultLeft;
         this.leftChild  = leftChild == null ? -1 : leftChild;
         this.rightChild = rightChild == null ? -1 : rightChild;
@@ -112,7 +114,11 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         splitFeature = in.readInt();
         splitGain = in.readDouble();
         nodeIndex = in.readVInt();
-        leafValue = in.readDouble();
+        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
+            leafValue = in.readDoubleArray();
+        } else {
+            leafValue = new double[]{in.readDouble()};
+        }
         defaultLeft = in.readBoolean();
         leftChild = in.readInt();
         rightChild = in.readInt();
@@ -144,7 +150,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         return splitGain;
     }
 
-    public double getLeafValue() {
+    public double[] getLeafValue() {
         return leafValue;
     }
 
@@ -190,7 +196,18 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         out.writeInt(splitFeature);
         out.writeDouble(splitGain);
         out.writeVInt(nodeIndex);
-        out.writeDouble(leafValue);
+        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
+            out.writeDoubleArray(leafValue);
+        } else {
+            if (leafValue.length > 1) {
+                throw new IOException("Multi-class classification models require that all nodes are at least version 7.7.0.");
+            }
+            if (leafValue.length == 0) {
+                out.writeDouble(Double.NaN);
+            } else {
+                out.writeDouble(leafValue[0]);
+            }
+        }
         out.writeBoolean(defaultLeft);
         out.writeInt(leftChild);
         out.writeInt(rightChild);
@@ -209,7 +226,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         }
         addOptionalDouble(builder, SPLIT_GAIN, splitGain);
         builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
-        addOptionalDouble(builder, LEAF_VALUE, leafValue);
+        if (leafValue.length > 0) {
+            builder.field(LEAF_VALUE.getPreferredName(), leafValue);
+        }
         builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
         if (leftChild >= 0) {
             builder.field(LEFT_CHILD.getPreferredName(), leftChild);
@@ -238,7 +257,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             && Objects.equals(splitFeature, that.splitFeature)
             && Objects.equals(nodeIndex, that.nodeIndex)
             && Objects.equals(splitGain, that.splitGain)
-            && Objects.equals(leafValue, that.leafValue)
+            && Arrays.equals(leafValue, that.leafValue)
             && Objects.equals(defaultLeft, that.defaultLeft)
             && Objects.equals(leftChild, that.leftChild)
             && Objects.equals(rightChild, that.rightChild)
@@ -252,7 +271,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             splitFeature,
             splitGain,
             nodeIndex,
-            leafValue,
+            Arrays.hashCode(leafValue),
             defaultLeft,
             leftChild,
             rightChild,
@@ -270,7 +289,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
 
     @Override
     public long ramBytesUsed() {
-        return SHALLOW_SIZE;
+        return SHALLOW_SIZE + this.leafValue.length * Double.BYTES;
     }
 
     public static class Builder {
@@ -279,7 +298,7 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
         private Integer splitFeature;
         private int nodeIndex;
         private Double splitGain;
-        private Double leafValue;
+        private List<Double> leafValue;
         private Boolean defaultLeft;
         private Integer leftChild;
         private Integer rightChild;
@@ -317,11 +336,19 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
             return this;
         }
 
-        public Builder setLeafValue(Double leafValue) {
+        public Builder setLeafValue(double leafValue) {
+            return this.setLeafValue(Collections.singletonList(leafValue));
+        }
+
+        public Builder setLeafValue(List<Double> leafValue) {
             this.leafValue = leafValue;
             return this;
         }
 
+        List<Double> getLeafValue() {
+            return this.leafValue;
+        }
+
         public Builder setDefaultLeft(Boolean defaultLeft) {
             this.defaultLeft = defaultLeft;
             return this;
@@ -358,6 +385,9 @@ public class TreeNode implements ToXContentObject, Writeable, Accountable {
                 if (leafValue == null) {
                     throw new IllegalArgumentException("[leaf_value] is required for a leaf node.");
                 }
+                if (leafValue.stream().anyMatch(Objects::isNull)) {
+                    throw new IllegalArgumentException("[leaf_value] cannot have null values.");
+                }
             } else {
                 if (leftChild < 0) {
                     throw new IllegalArgumentException("[left_child] must be a non-negative integer.");

+ 19 - 19
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java

@@ -7,8 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.utils;
 
 import org.elasticsearch.common.Numbers;
 
-import java.util.List;
-import java.util.stream.Collectors;
+import java.util.Arrays;
 
 public final class Statistics {
 
@@ -20,28 +19,29 @@ public final class Statistics {
      * Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in calculation and returned as 0.0 in the
      * softMax.
      * @param values Values on which to run SoftMax.
-     * @return A new list containing the softmax of the passed values
+     * @return A new array containing the softmax of the passed values
      */
-    public static List<Double> softMax(List<Double> values) {
-        Double expSum = 0.0;
-        Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
-        if (max == null) {
+    public static double[] softMax(double[] values) {
+        double expSum = 0.0;
+        double max = Arrays.stream(values).filter(Statistics::isValid).max().orElse(Double.NaN);
+        if (isValid(max) == false) {
             throw new IllegalArgumentException("no valid values present");
         }
-        List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
-            .collect(Collectors.toList());
-        for (int i = 0; i < exps.size(); i++) {
-            if (isValid(exps.get(i))) {
-                Double exp = Math.exp(exps.get(i));
+        double[] exps = new double[values.length];
+        for (int i = 0; i < exps.length; i++) {
+            if (isValid(values[i])) {
+                double exp = Math.exp(values[i] - max);
                 expSum += exp;
-                exps.set(i, exp);
+                exps[i] = exp;
+            } else {
+                exps[i] = Double.NaN;
             }
         }
-        for (int i = 0; i < exps.size(); i++) {
-            if (isValid(exps.get(i))) {
-                exps.set(i, exps.get(i)/expSum);
+        for (int i = 0; i < exps.length; i++) {
+            if (isValid(exps[i])) {
+                exps[i] /= expSum;
             } else {
-                exps.set(i, 0.0);
+                exps[i] = 0.0;
             }
         }
         return exps;
@@ -51,8 +51,8 @@ public final class Statistics {
         return 1/(1 + Math.exp(-value));
     }
 
-    private static boolean isValid(Double v) {
-        return v != null && Numbers.isValidDouble(v);
+    private static boolean isValid(double v) {
+        return Numbers.isValidDouble(v);
     }
 
 }

+ 24 - 11
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java

@@ -5,24 +5,37 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;
 
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
-public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
+import static org.hamcrest.CoreMatchers.equalTo;
+
+public class RawInferenceResultsTests extends ESTestCase {
 
     public static RawInferenceResults createRandomResults() {
-        return new RawInferenceResults(randomDouble(), randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
+        int n = randomIntBetween(1, 10);
+        double[] results = new double[n];
+        for (int i = 0; i < n; i++) {
+            results[i] = randomDouble();
+        }
+        return new RawInferenceResults(results, randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08));
     }
 
-    @Override
-    protected RawInferenceResults createTestInstance() {
-        return createRandomResults();
+    public void testEqualityAndHashcode() {
+        int n = randomIntBetween(1, 10);
+        double[] results = new double[n];
+        for (int i = 0; i < n; i++) {
+            results[i] = randomDouble();
+        }
+        Map<String, Double> importance = randomBoolean() ? Collections.emptyMap() : Collections.singletonMap("foo", 1.08);
+        RawInferenceResults lft = new RawInferenceResults(results, new HashMap<>(importance));
+        RawInferenceResults rgt = new RawInferenceResults(Arrays.copyOf(results, n), new HashMap<>(importance));
+        assertThat(lft, equalTo(rgt));
+        assertThat(lft.hashCode(), equalTo(rgt.hashCode()));
     }
 
-    @Override
-    protected Writeable.Reader<RawInferenceResults> instanceReader() {
-        return RawInferenceResults::new;
-    }
 }

+ 38 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegressionTests.java

@@ -11,11 +11,10 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 
 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;
 
 import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
 public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {
@@ -43,7 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
 
     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };
 
         LogisticRegression logisticRegression = new LogisticRegression(ones);
         assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
@@ -57,6 +62,36 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
         assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
     }
 
+    public void testAggregateMultiValueArrays() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {1.0, 0.0, 1.0},
+            new double[] {2.0, 0.0, 0.0},
+            new double[] {2.0, 3.0, 1.0},
+            new double[] {3.0, 3.0, 1.0},
+            new double[] {1.0, 1.0, 5.0}
+        };
+
+        LogisticRegression logisticRegression = new LogisticRegression(ones);
+        double[] processedValues = logisticRegression.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
+        assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
+        assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
+        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));
+
+        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
+
+        logisticRegression = new LogisticRegression(variedWeights);
+        processedValues = logisticRegression.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.0, 0.00001));
+        assertThat(processedValues[1], closeTo(0.0, 0.00001));
+        assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
+        assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(2.0));
+
+    }
+
     public void testCompatibleWith() {
         LogisticRegression logisticRegression = createTestInstance();
         assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));

+ 2 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java

@@ -8,9 +8,6 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.junit.Before;
 
-import java.util.ArrayList;
-import java.util.List;
-
 import static org.hamcrest.Matchers.equalTo;
 
 public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {
@@ -35,9 +32,9 @@ public abstract class WeightedAggregatorTests<T extends OutputAggregator> extend
 
     public void testWithValuesOfWrongLength() {
         int numberOfValues = randomIntBetween(5, 10);
-        List<Double> values = new ArrayList<>(numberOfValues);
+        double[][] values = new double[numberOfValues][];
         for (int i = 0; i < numberOfValues; i++) {
-            values.add(randomDouble());
+            values[i] = new double[] {randomDouble()};
         }
 
         OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));

+ 52 - 12
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java

@@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 
 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;
 
 import static org.hamcrest.CoreMatchers.is;
@@ -44,7 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
 
     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };
 
         WeightedMode weightedMode = new WeightedMode(ones, 6);
         assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
@@ -57,19 +61,55 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
         weightedMode = new WeightedMode(6);
         assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
 
-        values = Arrays.asList(1.0, 1.0, 1.0, 1.0, 2.0);
+        values = new double[][]{
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {1.0},
+            new double[] {2.0}
+        };
         weightedMode = new WeightedMode(6);
-        List<Double> processedValues = weightedMode.processValues(values);
-        assertThat(processedValues.size(), equalTo(6));
-        assertThat(processedValues.get(0), equalTo(0.0));
-        assertThat(processedValues.get(1), closeTo(0.95257412, 0.00001));
-        assertThat(processedValues.get(2), closeTo((1.0 - 0.95257412), 0.00001));
-        assertThat(processedValues.get(3), equalTo(0.0));
-        assertThat(processedValues.get(4), equalTo(0.0));
-        assertThat(processedValues.get(5), equalTo(0.0));
+        double[] processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(6));
+        assertThat(processedValues[0], equalTo(0.0));
+        assertThat(processedValues[1], closeTo(0.95257412, 0.00001));
+        assertThat(processedValues[2], closeTo((1.0 - 0.95257412), 0.00001));
+        assertThat(processedValues[3], equalTo(0.0));
+        assertThat(processedValues[4], equalTo(0.0));
+        assertThat(processedValues[5], equalTo(0.0));
         assertThat(weightedMode.aggregate(processedValues), equalTo(1.0));
     }
 
+    public void testAggregateMultiValueArrays() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {1.0, 0.0, 1.0},
+            new double[] {2.0, 0.0, 0.0},
+            new double[] {2.0, 3.0, 1.0},
+            new double[] {3.0, 3.0, 1.0},
+            new double[] {1.0, 1.0, 5.0}
+        };
+
+        WeightedMode weightedMode = new WeightedMode(ones, 3);
+        double[] processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.665240955, 0.00001));
+        assertThat(processedValues[1], closeTo(0.090030573, 0.00001));
+        assertThat(processedValues[2], closeTo(0.244728471, 0.00001));
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(0.0));
+
+        double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
+
+        weightedMode = new WeightedMode(variedWeights, 3);
+        processedValues = weightedMode.processValues(values);
+        assertThat(processedValues.length, equalTo(3));
+        assertThat(processedValues[0], closeTo(0.0, 0.00001));
+        assertThat(processedValues[1], closeTo(0.0, 0.00001));
+        assertThat(processedValues[2], closeTo(0.9999999, 0.00001));
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
+
+    }
+
     public void testCompatibleWith() {
         WeightedMode weightedMode = createTestInstance();
         assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));

+ 7 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java

@@ -11,8 +11,6 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 
 import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
 import java.util.stream.Stream;
 
 import static org.hamcrest.CoreMatchers.is;
@@ -43,7 +41,13 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
 
     public void testAggregate() {
         double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
-        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+        double[][] values = new double[][]{
+            new double[] {1.0},
+            new double[] {2.0},
+            new double[] {2.0},
+            new double[] {3.0},
+            new double[] {5.0}
+        };
 
         WeightedSum weightedSum = new WeightedSum(ones);
         assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));

+ 2 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNodeTests.java

@@ -55,7 +55,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
         return TreeNode.builder(randomInt(100))
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setNumberSamples(randomNonNegativeLong())
-            .setLeafValue(internalValue)
+            .setLeafValue(Collections.singletonList(internalValue))
             .build();
     }
 
@@ -66,7 +66,7 @@ public class TreeNodeTests extends AbstractSerializingTestCase<TreeNode> {
                                                 Integer featureIndex,
                                                 Operator operator) {
         return TreeNode.builder(nodeId)
-            .setLeafValue(left == null ? randomDouble() : null)
+            .setLeafValue(left == null ? Collections.singletonList(randomDouble()) : null)
             .setDefaultLeft(randomBoolean() ? null : randomBoolean())
             .setLeftChild(left)
             .setRightChild(right)

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -112,7 +112,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
 
     public void testInferWithStump() {
         Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
-        builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
+        builder.setRoot(TreeNode.builder(0).setLeafValue(Collections.singletonList(42.0)));
         builder.setFeatureNames(Collections.emptyList());
 
         Tree tree = builder.build();

+ 6 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java

@@ -16,18 +16,18 @@ import static org.hamcrest.Matchers.closeTo;
 public class StatisticsTests extends ESTestCase {
 
     public void testSoftMax() {
-        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
-        List<Double> softMax = Statistics.softMax(values);
+        double[] values = new double[] {Double.NEGATIVE_INFINITY, 1.0, -0.5, Double.NaN, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0};
+        double[] softMax = Statistics.softMax(values);
 
-        List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
+        double[] expected = new double[] {0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042};
 
-        for(int i = 0; i < expected.size(); i++) {
-            assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
+        for(int i = 0; i < expected.length; i++) {
+            assertThat(softMax[i], closeTo(expected[i], 0.000001));
         }
     }
 
     public void testSoftMaxWithNoValidValues() {
-        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
+        double[] values = new double[] {Double.NEGATIVE_INFINITY, Double.NaN, Double.POSITIVE_INFINITY};
         expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
     }
 

+ 7 - 7
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -211,14 +211,14 @@ public class TrainedModelIT extends ESRestTestCase {
                 .setRightChild(2)
                 .setSplitFeature(0)
                 .setThreshold(0.5),
-                TreeNode.builder(1).setLeafValue(0.3),
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(0.3)),
                 TreeNode.builder(2)
                 .setThreshold(0.0)
                 .setSplitFeature(3)
                 .setLeftChild(3)
                 .setRightChild(4),
-                TreeNode.builder(3).setLeafValue(0.1),
-                TreeNode.builder(4).setLeafValue(0.2))
+                TreeNode.builder(3).setLeafValue(Collections.singletonList(0.1)),
+                TreeNode.builder(4).setLeafValue(Collections.singletonList(0.2)))
             .build();
         Tree tree2 = Tree.builder()
             .setFeatureNames(featureNames)
@@ -227,8 +227,8 @@ public class TrainedModelIT extends ESRestTestCase {
                 .setRightChild(2)
                 .setSplitFeature(2)
                 .setThreshold(1.0),
-                TreeNode.builder(1).setLeafValue(1.5),
-                TreeNode.builder(2).setLeafValue(0.9))
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
+                TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
             .build();
         Tree tree3 = Tree.builder()
             .setFeatureNames(featureNames)
@@ -237,8 +237,8 @@ public class TrainedModelIT extends ESRestTestCase {
                 .setRightChild(2)
                 .setSplitFeature(1)
                 .setThreshold(0.2),
-                TreeNode.builder(1).setLeafValue(1.5),
-                TreeNode.builder(2).setLeafValue(0.9))
+                TreeNode.builder(1).setLeafValue(Collections.singletonList(1.5)),
+                TreeNode.builder(2).setLeafValue(Collections.singletonList(0.9)))
             .build();
         return Ensemble.builder()
             .setTargetType(TargetType.REGRESSION)

+ 12 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

@@ -97,6 +97,18 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
             return;
         }
 
+        Version minCompatibilityVersion = request.getTrainedModelConfig()
+            .getModelDefinition()
+            .getTrainedModel()
+            .getMinimalCompatibilityVersion();
+        if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
+            listener.onFailure(ExceptionsHelper.badRequestException(
+                "Definition for [{}] requires that all nodes are at least version [{}]",
+                request.getTrainedModelConfig().getModelId(),
+                minCompatibilityVersion.toString()));
+            return;
+        }
+
         TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
             .setVersion(Version.CURRENT)
             .setCreateTime(Instant.now())

+ 157 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

@@ -22,6 +22,12 @@ import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceRes
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
@@ -189,6 +195,109 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
     }
 
+    public void testInferModelMultiClassModel() throws Exception {
+        String modelId = "test-load-models-classification-multi";
+        Map<String, String> oneHotEncoding = new HashMap<>();
+        oneHotEncoding.put("cat", "animal_cat");
+        oneHotEncoding.put("dog", "animal_dog");
+        TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId)
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
+            .setParsedDefinition(new TrainedModelDefinition.Builder()
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
+                .setTrainedModel(buildMultiClassClassification()))
+            .setVersion(Version.CURRENT)
+            .setLicenseLevel(License.OperationMode.PLATINUM.description())
+            .setCreateTime(Instant.now())
+            .setEstimatedOperations(0)
+            .setEstimatedHeapMemory(0)
+            .build();
+        AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
+        assertThat(putConfigHolder.get(), is(true));
+        assertThat(exceptionHolder.get(), is(nullValue()));
+
+
+        List<Map<String, Object>> toInfer = new ArrayList<>();
+        toInfer.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
+        }});
+        toInfer.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 0.9);
+                put("bar", 1.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
+        }});
+
+        List<Map<String, Object>> toInfer2 = new ArrayList<>();
+        toInfer2.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 0.0);
+                put("bar", 0.01);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
+        }});
+        toInfer2.add(new HashMap<>() {{
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.0);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
+        }});
+
+        // Test regression
+        InternalInferModelAction.Request request = new InternalInferModelAction.Request(modelId,
+            toInfer,
+            ClassificationConfig.EMPTY_PARAMS,
+            true);
+        InternalInferModelAction.Response response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults()
+                .stream()
+                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
+                .collect(Collectors.toList()),
+            contains("option_0", "option_2"));
+
+        request = new InternalInferModelAction.Request(modelId, toInfer2, ClassificationConfig.EMPTY_PARAMS, true);
+        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+        assertThat(response.getInferenceResults()
+                .stream()
+                .map(i -> ((SingleValueInferenceResults)i).valueAsString())
+                .collect(Collectors.toList()),
+            contains("option_2", "option_0"));
+
+
+        // Get top classes
+        request = new InternalInferModelAction.Request(modelId, toInfer, new ClassificationConfig(3, null, null), true);
+        response = client().execute(InternalInferModelAction.INSTANCE, request).actionGet();
+
+        ClassificationInferenceResults classificationInferenceResults =
+            (ClassificationInferenceResults)response.getInferenceResults().get(0);
+
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_0"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_2"));
+        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
+
+        classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
+        assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("option_2"));
+        assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("option_0"));
+        assertThat(classificationInferenceResults.getTopClasses().get(2).getClassification(), equalTo("option_1"));
+    }
+
+
     public void testInferMissingModel() {
         String model = "test-infer-missing-model";
         InternalInferModelAction.Request request = new InternalInferModelAction.Request(
@@ -256,6 +365,54 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
             .setModelId(modelId);
     }
 
+    public static TrainedModel buildMultiClassClassification() {
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
+
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(1.0, 0.0, 2.0)))
+            .addNode(TreeNode.builder(2)
+                .setThreshold(0.8)
+                .setSplitFeature(1)
+                .setLeftChild(3)
+                .setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
+            .addNode(TreeNode.builder(4).setLeafValue(Arrays.asList(0.0, 0.0, 1.0))).build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(3)
+                .setThreshold(1.0))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(2.0, 0.0, 0.0)))
+            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 2.0, 0.0)))
+            .build();
+        Tree tree3 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(1.0))
+            .addNode(TreeNode.builder(1).setLeafValue(Arrays.asList(0.0, 0.0, 1.0)))
+            .addNode(TreeNode.builder(2).setLeafValue(Arrays.asList(0.0, 1.0, 0.0)))
+            .build();
+        return Ensemble.builder()
+            .setClassificationLabels(Arrays.asList("option_0", "option_1", "option_2"))
+            .setTargetType(TargetType.CLASSIFICATION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
+            .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}, 3))
+            .build();
+    }
+
+
     @Override
     public NamedXContentRegistry xContentRegistry() {
         List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();