浏览代码

[ML][Inference] adding ensemble model objects (#47241)

* [ML][Inference] adding ensemble model objects

* addressing PR comments

* Update TreeTests.java

* addressing PR comments

* fixing test
Benjamin Trent 6 年之前
父节点
当前提交
af4e6ededa
共有 31 个文件被更改,包括 2421 次插入和 118 次删除
  1. 13 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java
  2. 35 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java
  3. 188 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java
  4. 28 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java
  5. 84 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java
  6. 84 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java
  7. 57 9
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java
  8. 10 4
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  9. 97 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  10. 51 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
  11. 51 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
  12. 17 9
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
  13. 33 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  14. 36 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java
  15. 34 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java
  16. 311 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java
  17. 10 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java
  18. 47 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java
  19. 10 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java
  20. 161 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java
  21. 138 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
  22. 166 67
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
  23. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java
  24. 52 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
  25. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java
  26. 402 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  27. 51 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java
  28. 58 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
  29. 58 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java
  30. 103 24
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
  31. 33 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java

+ 13 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

@@ -19,6 +19,10 @@
 package org.elasticsearch.client.ml.inference;
 
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
 import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -47,6 +51,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
 
         // Model
         namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
+
+        // Aggregating output
+        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
+            new ParseField(WeightedMode.NAME),
+            WeightedMode::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
+            new ParseField(WeightedSum.NAME),
+            WeightedSum::fromXContent));
 
         return namedXContent;
     }

+ 35 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java

@@ -0,0 +1,35 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel;
+
+import java.util.Locale;
+
+public enum TargetType {
+
+    REGRESSION, CLASSIFICATION;
+
+    public static TargetType fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+}

+ 188 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java

@@ -0,0 +1,188 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class Ensemble implements TrainedModel {
+
+    public static final String NAME = "ensemble";
+    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
+    public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
+    public static final ParseField AGGREGATE_OUTPUT  = new ParseField("aggregate_output");
+    public static final ParseField TARGET_TYPE = new ParseField("target_type");
+    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+
+    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
+        NAME,
+        true,
+        Ensemble.Builder::new);
+
+    static {
+        PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
+        PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels,
+            (p, c, n) ->
+                    p.namedObject(TrainedModel.class, n, null),
+            (ensembleBuilder) -> { /* Noop does not matter client side */ },
+            TRAINED_MODELS);
+        PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
+            (p, c, n) -> p.namedObject(OutputAggregator.class, n, null),
+            (ensembleBuilder) -> { /* Noop does not matter client side */ },
+            AGGREGATE_OUTPUT);
+        PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
+        PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+    }
+
+    public static Ensemble fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    private final List<String> featureNames;
+    private final List<TrainedModel> models;
+    private final OutputAggregator outputAggregator;
+    private final TargetType targetType;
+    private final List<String> classificationLabels;
+
+    Ensemble(List<String> featureNames,
+             List<TrainedModel> models,
+             @Nullable OutputAggregator outputAggregator,
+             TargetType targetType,
+             @Nullable List<String> classificationLabels) {
+        this.featureNames = featureNames;
+        this.models = models;
+        this.outputAggregator = outputAggregator;
+        this.targetType = targetType;
+        this.classificationLabels = classificationLabels;
+    }
+
+    @Override
+    public List<String> getFeatureNames() {
+        return featureNames;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (featureNames != null) {
+            builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
+        }
+        if (models != null) {
+            NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
+        }
+        if (outputAggregator != null) {
+            NamedXContentObjectHelper.writeNamedObjects(builder,
+                params,
+                false,
+                AGGREGATE_OUTPUT.getPreferredName(),
+                Collections.singletonList(outputAggregator));
+        }
+        if (targetType != null) {
+            builder.field(TARGET_TYPE.getPreferredName(), targetType);
+        }
+        if (classificationLabels != null) {
+            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Ensemble that = (Ensemble) o;
+        return Objects.equals(featureNames, that.featureNames)
+            && Objects.equals(models, that.models)
+            && Objects.equals(targetType, that.targetType)
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(outputAggregator, that.outputAggregator);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public static class Builder {
+        private List<String> featureNames;
+        private List<TrainedModel> trainedModels;
+        private OutputAggregator outputAggregator;
+        private TargetType targetType;
+        private List<String> classificationLabels;
+
+        public Builder setFeatureNames(List<String> featureNames) {
+            this.featureNames = featureNames;
+            return this;
+        }
+
+        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
+            this.trainedModels = trainedModels;
+            return this;
+        }
+
+        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
+            this.outputAggregator = outputAggregator;
+            return this;
+        }
+
+        public Builder setTargetType(TargetType targetType) {
+            this.targetType = targetType;
+            return this;
+        }
+
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
+            this.setOutputAggregator(outputAggregators.get(0));
+        }
+
+        private void setTargetType(String targetType) {
+            this.targetType = TargetType.fromString(targetType);
+        }
+
+        public Ensemble build() {
+            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
+        }
+    }
+}

+ 28 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java

@@ -0,0 +1,28 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.client.ml.inference.NamedXContentObject;
+
+public interface OutputAggregator extends NamedXContentObject {
+    /**
+     * @return The name of the output aggregator
+     */
+    String getName();
+}

+ 84 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java

@@ -0,0 +1,84 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+
+public class WeightedMode implements OutputAggregator {
+
+    public static final String NAME = "weighted_mode";
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new WeightedMode((List<Double>)a[0]));
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+    }
+
+    public static WeightedMode fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<Double> weights;
+
+    public WeightedMode(List<Double> weights) {
+        this.weights = weights;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        WeightedMode that = (WeightedMode) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}

+ 84 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java

@@ -0,0 +1,84 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+public class WeightedSum implements OutputAggregator {
+
+    public static final String NAME = "weighted_sum";
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<WeightedSum, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new WeightedSum((List<Double>)a[0]));
+
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+    }
+
+    public static WeightedSum fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<Double> weights;
+
+    public WeightedSum(List<Double> weights) {
+        this.weights = weights;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        WeightedSum that = (WeightedSum) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}

+ 57 - 9
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java

@@ -18,7 +18,9 @@
  */
 package org.elasticsearch.client.ml.inference.trainedmodel.tree;
 
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ObjectParser;
@@ -28,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.stream.Collectors;
@@ -39,12 +40,16 @@ public class Tree implements TrainedModel {
 
     public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
     public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
+    public static final ParseField TARGET_TYPE = new ParseField("target_type");
+    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
 
     private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, true, Builder::new);
 
     static {
         PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
         PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE);
+        PARSER.declareString(Builder::setTargetType, TARGET_TYPE);
+        PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
     }
 
     public static Tree fromXContent(XContentParser parser) {
@@ -53,10 +58,14 @@ public class Tree implements TrainedModel {
 
     private final List<String> featureNames;
     private final List<TreeNode> nodes;
-
-    Tree(List<String> featureNames, List<TreeNode> nodes) {
-        this.featureNames = Collections.unmodifiableList(Objects.requireNonNull(featureNames));
-        this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes));
+    private final TargetType targetType;
+    private final List<String> classificationLabels;
+
+    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
+        this.featureNames = featureNames;
+        this.nodes = nodes;
+        this.targetType = targetType;
+        this.classificationLabels = classificationLabels;
     }
 
     @Override
@@ -73,11 +82,30 @@ public class Tree implements TrainedModel {
         return nodes;
     }
 
+    @Nullable
+    public List<String> getClassificationLabels() {
+        return classificationLabels;
+    }
+
+    public TargetType getTargetType() {
+        return targetType;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
-        builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
+        if (featureNames != null) {
+            builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
+        }
+        if (nodes != null) {
+            builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
+        }
+        if (classificationLabels != null) {
+            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        }
+        if (targetType != null) {
+            builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
+        }
         builder.endObject();
         return  builder;
     }
@@ -93,12 +121,14 @@ public class Tree implements TrainedModel {
         if (o == null || getClass() != o.getClass()) return false;
         Tree that = (Tree) o;
         return Objects.equals(featureNames, that.featureNames)
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(targetType, that.targetType)
             && Objects.equals(nodes, that.nodes);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(featureNames, nodes);
+        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
     }
 
     public static Builder builder() {
@@ -109,6 +139,8 @@ public class Tree implements TrainedModel {
         private List<String> featureNames;
         private ArrayList<TreeNode.Builder> nodes;
         private int numNodes;
+        private TargetType targetType;
+        private List<String> classificationLabels;
 
         public Builder() {
             nodes = new ArrayList<>();
@@ -137,6 +169,20 @@ public class Tree implements TrainedModel {
             return setNodes(Arrays.asList(nodes));
         }
 
+        public Builder setTargetType(TargetType targetType) {
+            this.targetType = targetType;
+            return this;
+        }
+
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        private void setTargetType(String targetType) {
+            this.targetType = TargetType.fromString(targetType);
+        }
+
         /**
          * Add a decision node. Space for the child nodes is allocated
          * @param nodeIndex         Where to place the node. This is either 0 (root) or an existing child node index
@@ -185,7 +231,9 @@ public class Tree implements TrainedModel {
 
         public Tree build() {
             return new Tree(featureNames,
-                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
+                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
+                targetType,
+                classificationLabels);
         }
     }
 

+ 10 - 4
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -67,6 +67,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Binar
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
 import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
 import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
@@ -683,7 +686,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(44, namedXContents.size());
+        assertEquals(47, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -693,7 +696,7 @@ public class RestHighLevelClientTests extends ESTestCase {
                 categories.put(namedXContent.categoryClass, counter + 1);
             }
         }
-        assertEquals("Had: " + categories, 11, categories.size());
+        assertEquals("Had: " + categories, 12, categories.size());
         assertEquals(Integer.valueOf(3), categories.get(Aggregation.class));
         assertTrue(names.contains(ChildrenAggregationBuilder.NAME));
         assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME));
@@ -744,8 +747,11 @@ public class RestHighLevelClientTests extends ESTestCase {
                 RSquaredMetric.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
         assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME));
-        assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
-        assertThat(names, hasItems(Tree.NAME));
+        assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
+        assertThat(names, hasItems(Tree.NAME, Ensemble.NAME));
+        assertEquals(Integer.valueOf(2),
+            categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
+        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME));
     }
 
     public void testApiNamingConventions() throws Exception {

+ 97 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -0,0 +1,97 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> !field.isEmpty();
+    }
+
+    @Override
+    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
+        return Ensemble.fromXContent(parser);
+    }
+
+    public static Ensemble createRandom() {
+        int numberOfFeatures = randomIntBetween(1, 10);
+        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
+            .limit(numberOfFeatures)
+            .collect(Collectors.toList());
+        int numberOfModels = randomIntBetween(1, 10);
+        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
+            .limit(numberOfModels)
+            .collect(Collectors.toList());
+        OutputAggregator outputAggregator = null;
+        if (randomBoolean()) {
+            List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
+            outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
+        }
+        List<String> categoryLabels = null;
+        if (randomBoolean()) {
+            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
+        }
+        return new Ensemble(featureNames,
+            models,
+            outputAggregator,
+            randomFrom(TargetType.values()),
+            categoryLabels);
+    }
+
+    @Override
+    protected Ensemble createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+}

+ 51 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java

@@ -0,0 +1,51 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {
+
+    WeightedMode createTestInstance(int numberOfWeights) {
+        return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
+        return WeightedMode.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected WeightedMode createTestInstance() {
+        return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100));
+    }
+
+}

+ 51 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java

@@ -0,0 +1,51 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class WeightedSumTests extends AbstractXContentTestCase<WeightedSum> {
+
+    WeightedSum createTestInstance(int numberOfWeights) {
+        return new WeightedSum(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
+        return WeightedSum.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected WeightedSum createTestInstance() {
+        return randomBoolean() ? new WeightedSum(null) : createTestInstance(randomIntBetween(1, 100));
+    }
+
+}

+ 17 - 9
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java

@@ -18,11 +18,13 @@
  */
 package org.elasticsearch.client.ml.inference.trainedmodel.tree;
 
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractXContentTestCase;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.function.Predicate;
 
@@ -50,16 +52,17 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
     }
 
     public static Tree createRandom() {
-        return buildRandomTree(randomIntBetween(2, 15),  6);
+        int numberOfFeatures = randomIntBetween(1, 10);
+        List<String> featureNames = new ArrayList<>();
+        for (int i = 0; i < numberOfFeatures; i++) {
+            featureNames.add(randomAlphaOfLength(10));
+        }
+        return buildRandomTree(featureNames,  6);
     }
 
-    public static Tree buildRandomTree(int numFeatures, int depth) {
-
+    public static Tree buildRandomTree(List<String> featureNames, int depth) {
+        int numFeatures = featureNames.size();
         Tree.Builder builder = Tree.builder();
-        List<String> featureNames = new ArrayList<>(numFeatures);
-        for(int i = 0; i < numFeatures; i++) {
-            featureNames.add(randomAlphaOfLength(10));
-        }
         builder.setFeatureNames(featureNames);
 
         TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
@@ -80,8 +83,13 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
             }
             childNodes = nextNodes;
         }
-
-        return builder.build();
+        List<String> categoryLabels = null;
+        if (randomBoolean()) {
+            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
+        }
+        return builder.setClassificationLabels(categoryLabels)
+            .setTargetType(randomFrom(TargetType.values()))
+            .build();
     }
 
 }

+ 33 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -11,6 +11,12 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
@@ -46,9 +52,27 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
 
         // Model Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentLenient));
+
+        // Output Aggregator Lenient
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
+            WeightedMode.NAME,
+            WeightedMode::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
+            WeightedSum.NAME,
+            WeightedSum::fromXContentLenient));
 
         // Model Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentStrict));
+
+        // Output Aggregator Strict
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
+            WeightedMode.NAME,
+            WeightedMode::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
+            WeightedSum.NAME,
+            WeightedSum::fromXContentStrict));
 
         return namedXContent;
     }
@@ -66,6 +90,15 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
 
         // Model
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new));
+
+        // Output Aggregator
+        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
+            WeightedSum.NAME.getPreferredName(),
+            WeightedSum::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
+            WeightedMode.NAME.getPreferredName(),
+            WeightedMode::new));
 
         return namedWriteables;
     }

+ 36 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java

@@ -0,0 +1,36 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.Locale;
+
+public enum TargetType implements Writeable {
+
+    REGRESSION, CLASSIFICATION;
+
+    public static TargetType fromString(String name) {
+        return valueOf(name.trim().toUpperCase(Locale.ROOT));
+    }
+
+    public static TargetType fromStream(StreamInput in) throws IOException {
+        return in.readEnum(TargetType.class);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeEnum(this);
+    }
+
+    @Override
+    public String toString() {
+        return name().toLowerCase(Locale.ROOT);
+    }
+}

+ 34 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
@@ -28,17 +29,47 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
     double infer(Map<String, Object> fields);
 
     /**
-     * @return {@code true} if the model is classification, {@code false} otherwise.
+     * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles
+     * @return The predicted value.
      */
-    boolean isClassification();
+    double infer(List<Double> fields);
+
+    /**
+     * @return {@link TargetType} for the model.
+     */
+    TargetType targetType();
 
     /**
      * This gathers the probabilities for each potential classification value.
      *
+     * The probabilities are indexed by classification ordinal label encoding.
+     * The length of this list is equal to the number of classification labels.
+     *
      * This only should return if the implementation model is inferring classification values and not regression
      * @param fields The fields and their values to infer against
      * @return The probabilities of each classification value
      */
-    List<Double> inferProbabilities(Map<String, Object> fields);
+    List<Double> classificationProbability(Map<String, Object> fields);
+
+    /**
+     * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles
+     * @return The probabilities of each classification value
+     */
+    List<Double> classificationProbability(List<Double> fields);
 
+    /**
+     * The ordinal encoded list of the classification labels.
+     * @return Ordinal encoded list of classification labels.
+     */
+    @Nullable
+    List<String> classificationLabels();
+
+    /**
+     * Runs validations against the model.
+     *
+     * Example: {@link org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree} should check if there are any loops
+     *
+     * @throws org.elasticsearch.ElasticsearchException if validations fail
+     */
+    void validate();
 }

+ 311 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

@@ -0,0 +1,311 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
+
+    // TODO should we have regression/classification sub-classes that accept the builder?
+    public static final ParseField NAME = new ParseField("ensemble");
+    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
+    public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
+    public static final ParseField AGGREGATE_OUTPUT  = new ParseField("aggregate_output");
+    public static final ParseField TARGET_TYPE = new ParseField("target_type");
+    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
+
+    private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
+    private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient) {
+        ObjectParser<Ensemble.Builder, Void> parser = new ObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            Ensemble.Builder::builderForParser);
+        parser.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
+        parser.declareNamedObjects(Ensemble.Builder::setTrainedModels,
+            (p, c, n) ->
+                lenient ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
+                    p.namedObject(StrictlyParsedTrainedModel.class, n, null),
+            (ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
+            TRAINED_MODELS);
+        parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
+            (p, c, n) ->
+                lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
+                    p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
+            (ensembleBuilder) -> {/*No-op: the aggregator may be declared as an array or an object; there just has to be exactly one*/},
+            AGGREGATE_OUTPUT);
+        parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
+        parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
+        return parser;
+    }
+
+    public static Ensemble fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null).build();
+    }
+
+    public static Ensemble fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null).build();
+    }
+
+    private final List<String> featureNames;
+    private final List<TrainedModel> models;
+    private final OutputAggregator outputAggregator;
+    private final TargetType targetType;
+    private final List<String> classificationLabels;
+
+    Ensemble(List<String> featureNames,
+             List<TrainedModel> models,
+             OutputAggregator outputAggregator,
+             TargetType targetType,
+             @Nullable List<String> classificationLabels) {
+        this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
+        this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
+        this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
+        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
+        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
+    }
+
+    public Ensemble(StreamInput in) throws IOException {
+        this.featureNames = Collections.unmodifiableList(in.readStringList());
+        this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class));
+        this.outputAggregator = in.readNamedWriteable(OutputAggregator.class);
+        this.targetType = TargetType.fromStream(in);
+        if (in.readBoolean()) {
+            this.classificationLabels = in.readStringList();
+        } else {
+            this.classificationLabels = null;
+        }
+    }
+
+    @Override
+    public List<String> getFeatureNames() {
+        return featureNames;
+    }
+
+    @Override
+    public double infer(Map<String, Object> fields) {
+        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
+        return infer(features);
+    }
+
+    @Override
+    public double infer(List<Double> fields) {
+        List<Double> processedInferences = inferAndProcess(fields);
+        return outputAggregator.aggregate(processedInferences);
+    }
+
+    @Override
+    public TargetType targetType() {
+        return targetType;
+    }
+
+    @Override
+    public List<Double> classificationProbability(Map<String, Object> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
+        return classificationProbability(features);
+    }
+
+    @Override
+    public List<Double> classificationProbability(List<Double> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        return inferAndProcess(fields);
+    }
+
+    @Override
+    public List<String> classificationLabels() {
+        return classificationLabels;
+    }
+
+    private List<Double> inferAndProcess(List<Double> fields) {
+        List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
+        return outputAggregator.processValues(modelInferences);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeStringCollection(featureNames);
+        out.writeNamedWriteableList(models);
+        out.writeNamedWriteable(outputAggregator);
+        targetType.writeTo(out);
+        out.writeBoolean(classificationLabels != null);
+        if (classificationLabels != null) {
+            out.writeStringCollection(classificationLabels);
+        }
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
+        NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
+        NamedXContentObjectHelper.writeNamedObjects(builder,
+            params,
+            false,
+            AGGREGATE_OUTPUT.getPreferredName(),
+            Collections.singletonList(outputAggregator));
+        builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
+        if (classificationLabels != null) {
+            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Ensemble that = (Ensemble) o;
+        return Objects.equals(featureNames, that.featureNames)
+            && Objects.equals(models, that.models)
+            && Objects.equals(targetType, that.targetType)
+            && Objects.equals(classificationLabels, that.classificationLabels)
+            && Objects.equals(outputAggregator, that.outputAggregator);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
+    }
+
+    @Override
+    public void validate() {
+        if (this.featureNames != null) {
+            if (this.models.stream()
+                .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
+                throw ExceptionsHelper.badRequestException(
+                    "[{}] must be the same and in the same order for each of the {}",
+                    FEATURE_NAMES.getPreferredName(),
+                    TRAINED_MODELS.getPreferredName());
+            }
+        }
+        if (outputAggregator.expectedValueSize() != null &&
+            outputAggregator.expectedValueSize() != models.size()) {
+            throw ExceptionsHelper.badRequestException(
+                "[{}] expects value array of size [{}] but number of models is [{}]",
+                AGGREGATE_OUTPUT.getPreferredName(),
+                outputAggregator.expectedValueSize(),
+                models.size());
+        }
+        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
+            throw ExceptionsHelper.badRequestException(
+                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
+        }
+        this.models.forEach(TrainedModel::validate);
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public static class Builder {
+        private List<String> featureNames;
+        private List<TrainedModel> trainedModels;
+        private OutputAggregator outputAggregator = new WeightedSum();
+        private TargetType targetType = TargetType.REGRESSION;
+        private List<String> classificationLabels;
+        private boolean modelsAreOrdered;
+
+        private Builder (boolean modelsAreOrdered) {
+            this.modelsAreOrdered = modelsAreOrdered;
+        }
+
+        private static Builder builderForParser() {
+            return new Builder(false);
+        }
+
+        public Builder() {
+            this(true);
+        }
+
+        public Builder setFeatureNames(List<String> featureNames) {
+            this.featureNames = featureNames;
+            return this;
+        }
+
+        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
+            this.trainedModels = trainedModels;
+            return this;
+        }
+
+        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
+            this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
+            return this;
+        }
+
+        public Builder setTargetType(TargetType targetType) {
+            this.targetType = targetType;
+            return this;
+        }
+
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
+            if (outputAggregators.size() != 1) {
+                throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
+                    AGGREGATE_OUTPUT.getPreferredName());
+            }
+            this.setOutputAggregator(outputAggregators.get(0));
+        }
+
+        private void setTargetType(String targetType) {
+            this.targetType = TargetType.fromString(targetType);
+        }
+
+        private void setModelsAreOrdered(boolean value) {
+            this.modelsAreOrdered = value;
+        }
+
+        public Ensemble build() {
+            // This is essentially a serialization error but the underlying xcontent parsing does not allow us to inject this requirement
+            // So, we verify the models were parsed in an ordered fashion here instead.
+            if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
+                throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
+            }
+            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
+        }
+    }
+}

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java

@@ -0,0 +1,10 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+
+public interface LenientlyParsedOutputAggregator extends OutputAggregator {
+}

+ 47 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
+
+import java.util.List;
+
+/**
+ * Reduces the individual outputs of an ensemble's member models to a single value.
+ * Implementations are both named xcontent objects and named writeables, so they can be
+ * parsed from xcontent and serialized over the wire.
+ */
+public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
+
+    /**
+     * @return The expected size of the values array when aggregating. `null` implies there is no expected size.
+     */
+    Integer expectedValueSize();
+
+    /**
+     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
+     *
+     * Two major types of pre-processed values could be returned:
+     *   - The confidence/probability scaled values given the input values (See: {@link WeightedMode#processValues(List)}
+     *   - A simple transformation of the passed values in preparation for aggregation (See: {@link WeightedSum#processValues(List)}
+     * @param values the values to process
+     * @return A new list containing the processed values or the same list if no processing is required
+     */
+    List<Double> processValues(List<Double> values);
+
+    /**
+     * Function to aggregate the processed values into a single double
+     *
+     * This may be as simple as returning the index of the maximum value.
+     *
+     * Or as complex as a mathematical reduction of all the passed values (i.e. summation, average, etc.).
+     *
+     * @param processedValues The values to aggregate
+     * @return the aggregated value.
+     */
+    double aggregate(List<Double> processedValues);
+
+    /**
+     * @return The name of the output aggregator
+     */
+    String getName();
+}

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java

@@ -0,0 +1,10 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+
+/**
+ * Marker interface for {@link OutputAggregator} implementations that are registered
+ * with a strict xcontent parser (unknown fields rejected) — see the
+ * {@code fromXContentStrict} factories on the implementations.
+ */
+public interface StrictlyParsedOutputAggregator extends OutputAggregator {
+}

+ 161 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java

@@ -0,0 +1,161 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;
+
+/**
+ * {@link OutputAggregator} implementing an (optionally weighted) majority vote.
+ * Each input value is interpreted as a non-negative whole-number class label; the weights
+ * of the votes for each label are summed and the totals are passed through softmax so that
+ * the processed values form a probability distribution over labels. {@link #aggregate(List)}
+ * then returns the label (index) with the highest probability.
+ *
+ * When {@code weights} is null, every vote carries a weight of 1.0.
+ */
+public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
+
+    public static final ParseField NAME = new ParseField("weighted_mode");
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            a -> new WeightedMode((List<Double>)a[0]));
+        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+        return parser;
+    }
+
+    public static WeightedMode fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static WeightedMode fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    // Optional per-value vote weights; null means "unweighted" (each vote counts 1.0).
+    private final List<Double> weights;
+
+    // Unweighted mode.
+    WeightedMode() {
+        this.weights = null;
+    }
+
+    public WeightedMode(List<Double> weights) {
+        this.weights = weights == null ? null : Collections.unmodifiableList(weights);
+    }
+
+    // Wire deserialization: a leading boolean flags whether the weights list is present.
+    public WeightedMode(StreamInput in) throws IOException {
+        if (in.readBoolean()) {
+            this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
+        } else {
+            this.weights = null;
+        }
+    }
+
+    @Override
+    public Integer expectedValueSize() {
+        // With explicit weights we need exactly one value per weight; otherwise any size is fine.
+        return this.weights == null ? null : this.weights.size();
+    }
+
+    @Override
+    public List<Double> processValues(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (weights != null && values.size() != weights.size()) {
+            throw new IllegalArgumentException("values must be the same length as weights.");
+        }
+        List<Integer> freqArray = new ArrayList<>();
+        Integer maxVal = 0;
+        // Values encode class labels, so they must be whole, finite, and non-negative.
+        for (Double value : values) {
+            if (value == null) {
+                throw new IllegalArgumentException("values must not contain null values");
+            }
+            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
+                throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
+            }
+            Integer integerValue = value.intValue();
+            freqArray.add(integerValue);
+            if (integerValue > maxVal) {
+                maxVal = integerValue;
+            }
+        }
+        // NEGATIVE_INFINITY is a sentinel for "label received no votes"; softMax treats
+        // infinite entries as invalid and maps them to 0.0 in the resulting distribution.
+        List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
+        for (int i = 0; i < freqArray.size(); i++) {
+            Double weight = weights == null ? 1.0 : weights.get(i);
+            Integer value = freqArray.get(i);
+            Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
+            frequencies.set(value, frequency);
+        }
+        return softMax(frequencies);
+    }
+
+    @Override
+    public double aggregate(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        // The winning label is the index of the largest processed value; ties keep the
+        // lower index because only a strictly greater value replaces the current best.
+        int bestValue = 0;
+        double bestFreq = Double.NEGATIVE_INFINITY;
+        for (int i = 0; i < values.size(); i++) {
+            if (values.get(i) == null) {
+                throw new IllegalArgumentException("values must not contain null values");
+            }
+            if (values.get(i) > bestFreq) {
+                bestFreq = values.get(i);
+                bestValue = i;
+            }
+        }
+        return bestValue;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        // Presence flag mirrors the stream constructor above.
+        out.writeBoolean(weights != null);
+        if (weights != null) {
+            out.writeCollection(weights, StreamOutput::writeDouble);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        WeightedMode that = (WeightedMode) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}

+ 138 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java

@@ -0,0 +1,138 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * {@link OutputAggregator} that computes an (optionally weighted) sum of the member model
+ * outputs: {@link #processValues(List)} multiplies each value by its weight (or returns
+ * the values untouched when no weights are configured) and {@link #aggregate(List)} sums
+ * the processed values.
+ */
+public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
+
+    public static final ParseField NAME = new ParseField("weighted_sum");
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    private static final ConstructingObjectParser<WeightedSum, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<WeightedSum, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<WeightedSum, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<WeightedSum, Void> parser = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            a -> new WeightedSum((List<Double>)a[0]));
+        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+        return parser;
+    }
+
+    public static WeightedSum fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static WeightedSum fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    // Optional per-value multipliers; null means an unweighted (plain) sum.
+    private final List<Double> weights;
+
+    // Unweighted sum.
+    WeightedSum() {
+        this.weights = null;
+    }
+
+    public WeightedSum(List<Double> weights) {
+        this.weights = weights == null ? null : Collections.unmodifiableList(weights);
+    }
+
+    // Wire deserialization: a leading boolean flags whether the weights list is present.
+    public WeightedSum(StreamInput in) throws IOException {
+        if (in.readBoolean()) {
+            this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
+        } else {
+            this.weights = null;
+        }
+    }
+
+    @Override
+    public List<Double> processValues(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (weights == null) {
+            // No weights configured: pass values through unchanged (documented contract of
+            // processValues allows returning the same list when no processing is required).
+            return values;
+        }
+        if (values.size() != weights.size()) {
+            throw new IllegalArgumentException("values must be the same length as weights.");
+        }
+        return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
+    }
+
+    @Override
+    public double aggregate(List<Double> values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (values.isEmpty()) {
+            throw new IllegalArgumentException("values must not be empty");
+        }
+        // NOTE(review): a null element would throw NullPointerException during unboxing inside
+        // Double::sum before the fallback below is reached, so the "must not contain null
+        // values" message appears unreachable in practice — TODO confirm intended.
+        Optional<Double> summation = values.stream().reduce(Double::sum);
+        if (summation.isPresent()) {
+            return summation.get();
+        }
+        throw new IllegalArgumentException("values must not contain null values");
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        // Presence flag mirrors the stream constructor above.
+        out.writeBoolean(weights != null);
+        if (weights != null) {
+            out.writeCollection(weights, StreamOutput::writeDouble);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        WeightedSum that = (WeightedSum) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+
+    @Override
+    public Integer expectedValueSize() {
+        // With explicit weights we need exactly one value per weight; otherwise any size is fine.
+        return weights == null ? null : this.weights.size();
+    }
+}

+ 166 - 67
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

@@ -9,11 +9,13 @@ import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.util.CachedSupplier;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -31,10 +33,13 @@ import java.util.stream.Collectors;
 
 public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
 
+    // TODO should we have regression/classification sub-classes that accept the builder?
     public static final ParseField NAME = new ParseField("tree");
 
     public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
     public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
+    public static final ParseField TARGET_TYPE = new ParseField("target_type");
+    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
 
     private static final ObjectParser<Tree.Builder, Void> LENIENT_PARSER = createParser(true);
     private static final ObjectParser<Tree.Builder, Void> STRICT_PARSER = createParser(false);
@@ -46,6 +51,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             Tree.Builder::new);
         parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES);
         parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
+        parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE);
+        parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
         return parser;
     }
 
@@ -59,15 +66,28 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
 
     private final List<String> featureNames;
     private final List<TreeNode> nodes;
+    private final TargetType targetType;
+    private final List<String> classificationLabels;
+    private final CachedSupplier<Double> highestOrderCategory;
 
-    Tree(List<String> featureNames, List<TreeNode> nodes) {
+    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
         this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
         this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
+        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
+        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
+        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
     }
 
     public Tree(StreamInput in) throws IOException {
         this.featureNames = Collections.unmodifiableList(in.readStringList());
         this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
+        this.targetType = TargetType.fromStream(in);
+        if (in.readBoolean()) {
+            this.classificationLabels = Collections.unmodifiableList(in.readStringList());
+        } else {
+            this.classificationLabels = null;
+        }
+        this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
     }
 
     @Override
@@ -90,7 +110,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         return infer(features);
     }
 
-    private double infer(List<Double> features) {
+    @Override
+    public double infer(List<Double> features) {
         TreeNode node = nodes.get(0);
         while(node.isLeaf() == false) {
             node = nodes.get(node.compare(features));
@@ -115,13 +136,40 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     }
 
     @Override
-    public boolean isClassification() {
-        return false;
+    public TargetType targetType() {
+        return targetType;
+    }
+
+    @Override
+    public List<Double> classificationProbability(Map<String, Object> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()));
+    }
+
+    @Override
+    public List<Double> classificationProbability(List<Double> fields) {
+        if ((targetType == TargetType.CLASSIFICATION) == false) {
+            throw new UnsupportedOperationException(
+                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
+        }
+        double label = infer(fields);
+        // If we are classification, we should assume that the inference return value is whole.
+        assert label == Math.rint(label);
+        double maxCategory = this.highestOrderCategory.get();
+        // If we are classification, we should assume that the largest leaf value is whole.
+        assert maxCategory == Math.rint(maxCategory);
+        List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
+        // TODO, eventually have TreeNodes contain confidence levels
+        list.set(Double.valueOf(label).intValue(), 1.0);
+        return list;
     }
 
     @Override
-    public List<Double> inferProbabilities(Map<String, Object> fields) {
-        throw new UnsupportedOperationException("Cannot infer probabilities against a regression model.");
+    public List<String> classificationLabels() {
+        return classificationLabels;
     }
 
     @Override
@@ -133,6 +181,11 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
     public void writeTo(StreamOutput out) throws IOException {
         out.writeStringCollection(featureNames);
         out.writeCollection(nodes);
+        targetType.writeTo(out);
+        out.writeBoolean(classificationLabels != null);
+        if (classificationLabels != null) {
+            out.writeStringCollection(classificationLabels);
+        }
     }
 
     @Override
@@ -140,6 +193,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         builder.startObject();
         builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
         builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
+        builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
+        if(classificationLabels != null) {
+            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
+        }
         builder.endObject();
         return  builder;
     }
@@ -155,22 +212,96 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         if (o == null || getClass() != o.getClass()) return false;
         Tree that = (Tree) o;
         return Objects.equals(featureNames, that.featureNames)
-            && Objects.equals(nodes, that.nodes);
+            && Objects.equals(nodes, that.nodes)
+            && Objects.equals(targetType, that.targetType)
+            && Objects.equals(classificationLabels, that.classificationLabels);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(featureNames, nodes);
+        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
     }
 
     public static Builder builder() {
         return new Builder();
     }
 
+    // Post-construction structural validation: label/target-type consistency,
+    // dangling child references, and cycles in the node graph.
+    @Override
+    public void validate() {
+        checkTargetType();
+        detectMissingNodes();
+        detectCycle();
+    }
+
+    // classification_labels must be present if and only if target_type is classification.
+    private void checkTargetType() {
+        if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
+            throw ExceptionsHelper.badRequestException(
+                "[target_type] should be [classification] if [classification_labels] is provided, and vice versa");
+        }
+    }
+
+    // Breadth-first walk from the root (index 0): reaching the same node twice means the
+    // child pointers do not form a tree.
+    // NOTE(review): a node reachable from two different parents is also reported as a
+    // "cycle" — acceptable for trees, where every node has a single parent; confirm that
+    // is the intended contract.
+    private void detectCycle() {
+        if (nodes.isEmpty()) {
+            return;
+        }
+        Set<Integer> visited = new HashSet<>(nodes.size());
+        Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
+        toVisit.add(0);
+        while(toVisit.isEmpty() == false) {
+            Integer nodeIdx = toVisit.remove();
+            if (visited.contains(nodeIdx)) {
+                throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
+            }
+            visited.add(nodeIdx);
+            TreeNode treeNode = nodes.get(nodeIdx);
+            // Negative child indices denote a leaf (no child to enqueue).
+            if (treeNode.getLeftChild() >= 0) {
+                toVisit.add(treeNode.getLeftChild());
+            }
+            if (treeNode.getRightChild() >= 0) {
+                toVisit.add(treeNode.getRightChild());
+            }
+        }
+    }
+
+    // Collects every child index that points past the end of the node list and reports
+    // them all in a single bad-request error. Null entries are skipped here because
+    // Builder.build() already rejects trees containing null nodes.
+    private void detectMissingNodes() {
+        if (nodes.isEmpty()) {
+            return;
+        }
+
+        List<Integer> missingNodes = new ArrayList<>();
+        for (int i = 0; i < nodes.size(); i++) {
+            TreeNode currentNode = nodes.get(i);
+            if (currentNode == null) {
+                continue;
+            }
+            if (nodeMissing(currentNode.getLeftChild(), nodes)) {
+                missingNodes.add(currentNode.getLeftChild());
+            }
+            if (nodeMissing(currentNode.getRightChild(), nodes)) {
+                missingNodes.add(currentNode.getRightChild());
+            }
+        }
+        if (missingNodes.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
+        }
+    }
+
+    // A child index is "missing" when it points beyond the parsed node list.
+    // Negative indices denote a leaf and are therefore never considered missing.
+    private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
+        return nodeIdx >= nodes.size();
+    }
+
+    // Largest leaf value, cached via highestOrderCategory and used as the highest class
+    // label for classification trees; null for regression.
+    // NOTE(review): getAsDouble() throws NoSuchElementException for a classification tree
+    // with no leaf nodes — TODO confirm validation guarantees at least one leaf exists.
+    private Double maxLeafValue() {
+        return targetType == TargetType.CLASSIFICATION ?
+            this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() :
+            null;
+    }
+
     public static class Builder {
         private List<String> featureNames;
         private ArrayList<TreeNode.Builder> nodes;
         private int numNodes;
+        private TargetType targetType = TargetType.REGRESSION;
+        private List<String> classificationLabels;
 
         public Builder() {
             nodes = new ArrayList<>();
@@ -185,13 +316,18 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             return this;
         }
 
+        public Builder setRoot(TreeNode.Builder root) {
+            nodes.set(0, root);
+            return this;
+        }
+
         public Builder addNode(TreeNode.Builder node) {
             nodes.add(node);
             return this;
         }
 
         public Builder setNodes(List<TreeNode.Builder> nodes) {
-            this.nodes = new ArrayList<>(nodes);
+            this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName()));
             return this;
         }
 
@@ -199,6 +335,21 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             return setNodes(Arrays.asList(nodes));
         }
 
+
+        public Builder setTargetType(TargetType targetType) {
+            this.targetType = targetType;
+            return this;
+        }
+
+        public Builder setClassificationLabels(List<String> classificationLabels) {
+            this.classificationLabels = classificationLabels;
+            return this;
+        }
+
+        private void setTargetType(String targetType) {
+            this.targetType = TargetType.fromString(targetType);
+        }
+
         /**
          * Add a decision node. Space for the child nodes is allocated
          * @param nodeIndex         Where to place the node. This is either 0 (root) or an existing child node index
@@ -231,61 +382,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
             return node;
         }
 
-        void detectCycle(List<TreeNode.Builder> nodes) {
-            if (nodes.isEmpty()) {
-                return;
-            }
-            Set<Integer> visited = new HashSet<>();
-            Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
-            toVisit.add(0);
-            while(toVisit.isEmpty() == false) {
-                Integer nodeIdx = toVisit.remove();
-                if (visited.contains(nodeIdx)) {
-                    throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx);
-                }
-                visited.add(nodeIdx);
-                TreeNode.Builder treeNode = nodes.get(nodeIdx);
-                if (treeNode.getLeftChild() != null) {
-                    toVisit.add(treeNode.getLeftChild());
-                }
-                if (treeNode.getRightChild() != null) {
-                    toVisit.add(treeNode.getRightChild());
-                }
-            }
-        }
-
-        void detectNullOrMissingNode(List<TreeNode.Builder> nodes) {
-            if (nodes.isEmpty()) {
-                return;
-            }
-            if (nodes.get(0) == null) {
-                throw new IllegalArgumentException("[tree] must have non-null root node.");
-            }
-            List<Integer> nullOrMissingNodes = new ArrayList<>();
-            for (int i = 0; i < nodes.size(); i++) {
-                TreeNode.Builder currentNode = nodes.get(i);
-                if (currentNode == null) {
-                    continue;
-                }
-                if (nodeNullOrMissing(currentNode.getLeftChild())) {
-                    nullOrMissingNodes.add(currentNode.getLeftChild());
-                }
-                if (nodeNullOrMissing(currentNode.getRightChild())) {
-                    nullOrMissingNodes.add(currentNode.getRightChild());
-                }
-            }
-            if (nullOrMissingNodes.isEmpty() == false) {
-                throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes);
-            }
-        }
-
-        private boolean nodeNullOrMissing(Integer nodeIdx) {
-            if (nodeIdx == null) {
-                return false;
-            }
-            return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null;
-        }
-
         /**
          * Sets the node at {@code nodeIndex} to a leaf node.
          * @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)}
@@ -301,10 +397,13 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
         }
 
         public Tree build() {
-            detectNullOrMissingNode(nodes);
-            detectCycle(nodes);
+            if (nodes.stream().anyMatch(Objects::isNull)) {
+                throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
+            }
             return new Tree(featureNames,
-                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()));
+                nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()),
+                targetType,
+                classificationLabels);
         }
     }
 

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java

@@ -143,7 +143,7 @@ public class TreeNode implements ToXContentObject, Writeable {
     }
 
     public boolean isLeaf() {
-        return leftChild < 1;
+        return leftChild < 0;
     }
 
     public int compare(List<Double> features) {

+ 52 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.utils;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Static numeric helpers for inference models. Not instantiable.
+ */
+public final class Statistics {
+
+    private Statistics(){}
+
+    /**
+     * Calculates the softMax of the passed values.
+     *
+     * Any {@link Double#isInfinite()}, {@link Double#NaN}, or `null` values are ignored in calculation and returned as 0.0 in the
+     * softMax.
+     * @param values Values on which to run SoftMax.
+     * @return A new list containing the softmax of the passed values
+     */
+    public static List<Double> softMax(List<Double> values) {
+        Double expSum = 0.0;
+        Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
+        if (max == null) {
+            throw new IllegalArgumentException("no valid values present");
+        }
+        // Subtracting the max before exponentiating is the standard numerical-stability
+        // trick: Math.exp never overflows because every exponent is <= 0. Invalid entries
+        // are tagged with NEGATIVE_INFINITY so the loops below can recognize them.
+        List<Double> exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
+            .collect(Collectors.toList());
+        for (int i = 0; i < exps.size(); i++) {
+            if (isInvalid(exps.get(i)) == false) {
+                Double exp = Math.exp(exps.get(i));
+                expSum += exp;
+                exps.set(i, exp);
+            }
+        }
+        // The max itself contributes exp(0) == 1, so expSum >= 1 and the division is safe.
+        for (int i = 0; i < exps.size(); i++) {
+            if (isInvalid(exps.get(i))) {
+                exps.set(i, 0.0);
+            } else {
+                exps.set(i, exps.get(i)/expSum);
+            }
+        }
+        return exps;
+    }
+
+    // A value is invalid for softmax purposes when it is null, infinite, or NaN.
+    public static boolean isInvalid(Double v) {
+        return v == null || Double.isInfinite(v) || Double.isNaN(v);
+    }
+
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/NamedXContentObjectsTests.java

@@ -17,6 +17,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
@@ -157,7 +158,7 @@ public class NamedXContentObjectsTests extends AbstractXContentTestCase<NamedXCo
         NamedObjectContainer container = new NamedObjectContainer();
         container.setPreProcessors(preProcessors);
         container.setUseExplicitPreprocessorOrder(true);
-        container.setModel(TreeTests.buildRandomTree(5, 4));
+        container.setModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
         return container;
     }
 

+ 402 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -0,0 +1,402 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+
/**
 * Serialization (XContent and wire) round-trip tests for {@link Ensemble}, plus
 * validation checks and classification/regression inference tests against
 * small hand-built trees with known outputs.
 */
public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {

    // Chosen once per test run: whether to parse with the lenient or the strict parser.
    private boolean lenient;

    @Before
    public void chooseStrictOrLenient() {
        lenient = randomBoolean();
    }

    @Override
    protected boolean supportsUnknownFields() {
        // Unknown fields may only be tolerated when the lenient parser is in use.
        return lenient;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // Only permit random unknown fields at the object root; nested named objects
        // (trained models, aggregators) are parsed by their own parsers.
        return field -> !field.isEmpty();
    }

    @Override
    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
        return lenient ? Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser);
    }

    /**
     * @return a randomly populated {@link Ensemble} whose member trees all share one feature-name list
     */
    public static Ensemble createRandom() {
        int numberOfFeatures = randomIntBetween(1, 10);
        List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList());
        int numberOfModels = randomIntBetween(1, 10);
        List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
            .limit(numberOfModels)
            .collect(Collectors.toList());
        // null weights exercises the aggregators' default (unweighted) behavior
        List<Double> weights = randomBoolean() ?
            null :
            Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
        List<String> categoryLabels = null;
        if (randomBoolean()) {
            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
        }

        return new Ensemble(featureNames,
            models,
            outputAggregator,
            randomFrom(TargetType.values()),
            categoryLabels);
    }

    @Override
    protected Ensemble createTestInstance() {
        return createRandom();
    }

    @Override
    protected Writeable.Reader<Ensemble> instanceReader() {
        return Ensemble::new;
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Ensembles contain named sub-objects (models, aggregators), so parsing needs
        // the ML inference named XContent entries registered.
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(namedXContent);
    }

    @Override
    protected NamedWriteableRegistry getNamedWriteableRegistry() {
        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
        return new NamedWriteableRegistry(entries);
    }

    // validate() must reject member models whose feature names differ in content or order.
    public void testEnsembleWithModelsThatHaveDifferentFeatureNames() {
        List<String> featureNames = Arrays.asList("foo", "bar", "baz", "farequote");
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder().setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6)))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));

        ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder().setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6)))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models"));
    }

    // validate() must reject an aggregator whose weight count (7) does not match the model count (5).
    public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        int numberOfModels = 5;
        List<Double> weights = new ArrayList<>(numberOfModels + 2);
        for (int i = 0; i < numberOfModels + 2; i++) {
            weights.add(randomDouble());
        }
        OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));

        List<TrainedModel> models = new ArrayList<>(numberOfModels);
        for (int i = 0; i < numberOfModels; i++) {
            models.add(TreeTests.buildRandomTree(featureNames, 6));
        }
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setTrainedModels(models)
                .setOutputAggregator(outputAggregator)
                .setFeatureNames(featureNames)
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]"));
    }

    // An ensemble containing a structurally invalid member model must fail validation.
    public void testEnsembleWithInvalidModel() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                // Tree with loop
                Tree.builder()
                    .setNodes(TreeNode.builder(0)
                    .setLeftChild(1)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()),
                TreeNode.builder(0)
                    .setLeftChild(0)
                    .setSplitFeature(1)
                    .setThreshold(randomDouble()))
                    .setFeatureNames(featureNames)
                    .build()))
                .build()
                .validate();
        });
    }

    // classification_labels and target_type must be provided together (labels imply classification).
    public void testEnsembleWithTargetTypeAndLabelsMismatch() {
        List<String> featureNames = Arrays.asList("foo", "bar");
        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                                .setLeftChild(1)
                                .setSplitFeature(1)
                                .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .setClassificationLabels(Arrays.asList("label1", "label2"))
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo(msg));
        ex = expectThrows(ElasticsearchException.class, () -> {
            Ensemble.builder()
                .setFeatureNames(featureNames)
                .setTrainedModels(Arrays.asList(
                    Tree.builder()
                        .setNodes(TreeNode.builder(0)
                            .setLeftChild(1)
                            .setSplitFeature(1)
                            .setThreshold(randomDouble()))
                        .setFeatureNames(featureNames)
                        .build()))
                .setTargetType(TargetType.CLASSIFICATION)
                .build()
                .validate();
        });
        assertThat(ex.getMessage(), equalTo(msg));
    }

    public void testClassificationProbability() {
        // Three small hand-built trees over features ("foo", "bar") combined with a
        // WeightedMode aggregator; expected probabilities below were precomputed
        // offline for this fixed ensemble.
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
        double eps = 0.000001;
        List<Double> probabilities = ensemble.classificationProbability(featureMap);
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.3100255188, 0.689974481);
        probabilities = ensemble.classificationProbability(featureMap);
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        expected = Arrays.asList(0.231475216, 0.768524783);
        probabilities = ensemble.classificationProbability(featureMap);
        for(int i = 0; i < expected.size(); i++) {
            assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
        }
    }

    public void testClassificationInference() {
        // Same three trees and WeightedMode weights as testClassificationProbability;
        // here we only check the winning class value returned by infer().
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.0))
            .addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.0))
            .addNode(TreeNode.builder(2).setLeafValue(1.0))
            .build();
        Tree tree3 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(1)
                .setThreshold(1.0))
            .addNode(TreeNode.builder(1).setLeafValue(1.0))
            .addNode(TreeNode.builder(2).setLeafValue(0.0))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.CLASSIFICATION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
            .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(0.0, 1.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
    }

    public void testRegressionInference() {
        // Two regression trees; with WeightedSum(0.5, 0.5) the prediction is the
        // mean of the two tree outputs.
        List<String> featureNames = Arrays.asList("foo", "bar");
        Tree tree1 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(0.3))
            .addNode(TreeNode.builder(2)
                .setThreshold(0.8)
                .setSplitFeature(1)
                .setLeftChild(3)
                .setRightChild(4))
            .addNode(TreeNode.builder(3).setLeafValue(0.1))
            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
        Tree tree2 = Tree.builder()
            .setFeatureNames(featureNames)
            .setRoot(TreeNode.builder(0)
                .setLeftChild(1)
                .setRightChild(2)
                .setSplitFeature(0)
                .setThreshold(0.5))
            .addNode(TreeNode.builder(1).setLeafValue(1.5))
            .addNode(TreeNode.builder(2).setLeafValue(0.9))
            .build();
        Ensemble ensemble = Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2))
            .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
            .build();

        List<Double> featureVector = Arrays.asList(0.4, 0.0);
        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.9, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(0.5, ensemble.infer(featureMap), 0.00001);

        // Test with NO aggregator supplied, verifies default behavior of non-weighted sum
        ensemble = Ensemble.builder()
            .setTargetType(TargetType.REGRESSION)
            .setFeatureNames(featureNames)
            .setTrainedModels(Arrays.asList(tree1, tree2))
            .build();

        featureVector = Arrays.asList(0.4, 0.0);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.8, ensemble.infer(featureMap), 0.00001);

        featureVector = Arrays.asList(2.0, 0.7);
        featureMap = zipObjMap(featureNames, featureVector);
        assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
    }

    // Pairs each feature name with the value at the same index into a feature map.
    private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
        return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
    }
}

+ 51 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public abstract class WeightedAggregatorTests<T extends OutputAggregator> extends AbstractSerializingTestCase<T> {
+
+    protected boolean lenient;
+
+    @Before
+    public void chooseStrictOrLenient() {
+        lenient = randomBoolean();
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    public void testWithNullValues() {
+        OutputAggregator outputAggregator = createTestInstance();
+        NullPointerException ex = expectThrows(NullPointerException.class, () -> outputAggregator.processValues(null));
+        assertThat(ex.getMessage(), equalTo("values must not be null"));
+    }
+
+    public void testWithValuesOfWrongLength() {
+        int numberOfValues = randomIntBetween(5, 10);
+        List<Double> values = new ArrayList<>(numberOfValues);
+        for (int i = 0; i < numberOfValues; i++) {
+            values.add(randomDouble());
+        }
+
+        OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1));
+        expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooFewWeights.processValues(values));
+
+        OutputAggregator outputAggregatorWithTooManyWeights = createTestInstance(randomIntBetween(numberOfValues + 1, numberOfValues + 10));
+        expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooManyWeights.processValues(values));
+    }
+
+    abstract T createTestInstance(int numberOfWeights);
+}

+ 58 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
+
+    @Override
+    WeightedMode createTestInstance(int numberOfWeights) {
+        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
+        return new WeightedMode(weights);
+    }
+
+    @Override
+    protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? WeightedMode.fromXContentLenient(parser) : WeightedMode.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected WeightedMode createTestInstance() {
+        return randomBoolean() ? new WeightedMode() : createTestInstance(randomIntBetween(1, 100));
+    }
+
+    @Override
+    protected Writeable.Reader<WeightedMode> instanceReader() {
+        return WeightedMode::new;
+    }
+
+    public void testAggregate() {
+        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
+        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+
+        WeightedMode weightedMode = new WeightedMode(ones);
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
+
+        List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
+
+        weightedMode = new WeightedMode(variedWeights);
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));
+
+        weightedMode = new WeightedMode();
+        assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
+    }
+}

+ 58 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
+
+    @Override
+    WeightedSum createTestInstance(int numberOfWeights) {
+        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
+        return new WeightedSum(weights);
+    }
+
+    @Override
+    protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? WeightedSum.fromXContentLenient(parser) : WeightedSum.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected WeightedSum createTestInstance() {
+        return randomBoolean() ? new WeightedSum() : createTestInstance(randomIntBetween(1, 100));
+    }
+
+    @Override
+    protected Writeable.Reader<WeightedSum> instanceReader() {
+        return WeightedSum::new;
+    }
+
+    public void testAggregate() {
+        List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
+        List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
+
+        WeightedSum weightedSum = new WeightedSum(ones);
+        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
+
+        List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
+
+        weightedSum = new WeightedSum(variedWeights);
+        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0));
+
+        weightedSum = new WeightedSum();
+        assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
+    }
+}

+ 103 - 24
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -5,9 +5,12 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
 
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -47,23 +50,23 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         return field -> field.startsWith("feature_names");
     }
 
-
     @Override
     protected Tree createTestInstance() {
         return createRandom();
     }
 
     public static Tree createRandom() {
-        return buildRandomTree(randomIntBetween(2, 15),  6);
+        int numberOfFeatures = randomIntBetween(1, 10);
+        List<String> featureNames = new ArrayList<>();
+        for (int i = 0; i < numberOfFeatures; i++) {
+            featureNames.add(randomAlphaOfLength(10));
+        }
+        return buildRandomTree(featureNames,  6);
     }
 
-    public static Tree buildRandomTree(int numFeatures, int depth) {
-
+    public static Tree buildRandomTree(List<String> featureNames, int depth) {
         Tree.Builder builder = Tree.builder();
-        List<String> featureNames = new ArrayList<>(numFeatures);
-        for(int i = 0; i < numFeatures; i++) {
-            featureNames.add(randomAlphaOfLength(10));
-        }
+        int numFeatures = featureNames.size() - 1;
         builder.setFeatureNames(featureNames);
 
         TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble());
@@ -84,8 +87,14 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             }
             childNodes = nextNodes;
         }
+        List<String> categoryLabels = null;
+        if (randomBoolean()) {
+            categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
+        }
 
-        return builder.build();
+        return builder.setTargetType(randomFrom(TargetType.values()))
+            .setClassificationLabels(categoryLabels)
+            .build();
     }
 
     @Override
@@ -96,7 +105,7 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
     public void testInfer() {
         // Build a tree with 2 nodes and 3 leaves using 2 features
         // The leaves have unique values 0.1, 0.2, 0.3
-        Tree.Builder builder = Tree.builder();
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
         TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
         builder.addLeaf(rootNode.getRightChild(), 0.3);
         TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
@@ -124,37 +133,76 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
         assertEquals(0.2, tree.infer(featureMap), 0.00001);
     }
 
+    public void testTreeClassificationProbability() {
+        // Build a tree with 2 nodes and 3 leaves using 2 features
+        // The leaves have unique values 0.1, 0.2, 0.3
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 1.0);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 1.0);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.0);
+
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        Tree tree = builder.setFeatureNames(featureNames).build();
+
+        // This feature vector should hit the right child of the root node
+        List<Double> featureVector = Arrays.asList(0.6, 0.0);
+        Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
+
+        // This should hit the left child of the left child of the root node
+        // i.e. it takes the path left, left
+        featureVector = Arrays.asList(0.3, 0.7);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
+
+        // This should hit the right child of the left child of the root node
+        // i.e. it takes the path left, right
+        featureVector = Arrays.asList(0.3, 0.9);
+        featureMap = zipObjMap(featureNames, featureVector);
+        assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
+    }
+
     public void testTreeWithNullRoot() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(Collections.singletonList(null))
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(Collections.singletonList(null))
+                .setFeatureNames(Arrays.asList("foo", "bar"))
                 .build());
-        assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node."));
+        assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
     }
 
     public void testTreeWithInvalidNode() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                 .setLeftChild(1)
                 .setSplitFeature(1)
                 .setThreshold(randomDouble()))
-                .build());
-        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build().validate());
+        assertThat(ex.getMessage(), equalTo("[tree] contains missing nodes [1]"));
     }
 
     public void testTreeWithNullNode() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                 .setLeftChild(1)
                 .setSplitFeature(1)
                 .setThreshold(randomDouble()),
                 null)
-                .build());
-        assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]"));
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build()
+                .validate());
+        assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes"));
     }
 
     // Validation must detect a cycle in the node graph (node 1 points back to node 0).
     public void testTreeWithCycle() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class,
-            () -> Tree.builder().setNodes(TreeNode.builder(0)
+        ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
+            () -> Tree.builder()
+                .setNodes(TreeNode.builder(0)
                     .setLeftChild(1)
                     .setSplitFeature(1)
                     .setThreshold(randomDouble()),
// NOTE(review): the diff hunk boundary below hides the middle of this method
// (presumably the second TreeNode.builder(...) call creating the back-edge) —
// consult the full file before editing this test.
@@ -162,10 +210,41 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
                     .setLeftChild(0)
                     .setSplitFeature(1)
                     .setThreshold(randomDouble()))
-                .build());
+                .setFeatureNames(Arrays.asList("foo", "bar"))
+                .build()
+                .validate());
         assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0"));
     }
 
+    public void testTreeWithTargetTypeAndLabelsMismatch() {
+        List<String> featureNames = Arrays.asList("foo", "bar");
+        String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
+            Tree.builder()
+                .setRoot(TreeNode.builder(0)
+                        .setLeftChild(1)
+                        .setSplitFeature(1)
+                        .setThreshold(randomDouble()))
+                .setFeatureNames(featureNames)
+                .setClassificationLabels(Arrays.asList("label1", "label2"))
+                .build()
+                .validate();
+        });
+        assertThat(ex.getMessage(), equalTo(msg));
+        ex = expectThrows(ElasticsearchException.class, () -> {
+            Tree.builder()
+                .setRoot(TreeNode.builder(0)
+                    .setLeftChild(1)
+                    .setSplitFeature(1)
+                    .setThreshold(randomDouble()))
+                .setFeatureNames(featureNames)
+                .setTargetType(TargetType.CLASSIFICATION)
+                .build()
+                .validate();
+        });
+        assertThat(ex.getMessage(), equalTo(msg));
+    }
+
     private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
         return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
     }

+ 33 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.utils;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.hamcrest.Matchers.closeTo;
+
+public class StatisticsTests extends ESTestCase {
+
+    public void testSoftMax() {
+        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0);
+        List<Double> softMax = Statistics.softMax(values);
+
+        List<Double> expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042);
+
+        for(int i = 0; i < expected.size(); i++) {
+            assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001));
+        }
+    }
+
+    public void testSoftMaxWithNoValidValues() {
+        List<Double> values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY);
+        expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values));
+    }
+
+}