Browse Source

[ML] add exponent output aggregator to inference (#58933)

* [ML] add exponent output aggregator to inference

* fixing docs
Benjamin Trent 5 years ago
parent
commit
6238d4fc49

+ 4 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

@@ -24,6 +24,7 @@ import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
@@ -82,6 +83,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
             new ParseField(LogisticRegression.NAME),
             LogisticRegression::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class,
+            new ParseField(Exponent.NAME),
+            Exponent::fromXContent));
 
         return namedXContent;
     }

+ 83 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Exponent.java

@@ -0,0 +1,83 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+
+public class Exponent implements OutputAggregator {
+
+    public static final String NAME = "exponent";
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<Exponent, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new Exponent((List<Double>)a[0]));
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+    }
+
+    public static Exponent fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<Double> weights;
+
+    public Exponent(List<Double> weights) {
+        this.weights = weights;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Exponent that = (Exponent) o;
+        return Objects.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(weights);
+    }
+}

+ 4 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -80,6 +80,7 @@ import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
 import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
 import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
@@ -703,7 +704,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(68, namedXContents.size());
+        assertEquals(69, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -788,9 +789,9 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
         assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
-        assertEquals(Integer.valueOf(3),
+        assertEquals(Integer.valueOf(4),
             categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class));
-        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME));
+        assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME, LogisticRegression.NAME, Exponent.NAME));
         assertEquals(Integer.valueOf(2),
             categories.get(org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig.class));
         assertThat(names, hasItems(ClassificationConfig.NAME.getPreferredName(), RegressionConfig.NAME.getPreferredName()));

+ 2 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -73,7 +73,8 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
             categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
         }
         List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
-        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
+        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
+            randomFrom(new WeightedSum(weights), new Exponent(weights)) :
             randomFrom(
                 new WeightedMode(
                     categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),

+ 51 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/ExponentTests.java

@@ -0,0 +1,51 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class ExponentTests extends AbstractXContentTestCase<Exponent> {
+
+    Exponent createTestInstance(int numberOfWeights) {
+        return new Exponent(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
+    }
+
+    @Override
+    protected Exponent doParseInstance(XContentParser parser) throws IOException {
+        return Exponent.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Exponent createTestInstance() {
+        return randomBoolean() ? new Exponent(null) : createTestInstance(randomIntBetween(1, 100));
+    }
+
+}

+ 105 - 75
docs/reference/ml/df-analytics/apis/put-inference.asciidoc

@@ -27,7 +27,7 @@ experimental[]
 [[ml-put-inference-prereq]]
 ==== {api-prereq-title}
 
-If the {es} {security-features} are enabled, you must have the following 
+If the {es} {security-features} are enabled, you must have the following
 built-in roles and privileges:
 
 * `machine_learning_admin`
@@ -38,15 +38,15 @@ For more information, see <<security-privileges>> and <<built-in-roles>>.
 [[ml-put-inference-desc]]
 ==== {api-description-title}
 
-The create {infer} trained model API enables you to supply a trained model that 
-is not created by {dfanalytics}. 
+The create {infer} trained model API enables you to supply a trained model that
+is not created by {dfanalytics}.
 
 
 [[ml-put-inference-path-params]]
 ==== {api-path-parms-title}
 
 `<model_id>`::
-(Required, string) 
+(Required, string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
 [role="child_attributes"]
@@ -54,14 +54,14 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 ==== {api-request-body-title}
 
 `compressed_definition`::
-(Required, string) 
-The compressed (GZipped and Base64 encoded) {infer} definition of the model. 
+(Required, string)
+The compressed (GZipped and Base64 encoded) {infer} definition of the model.
 If `compressed_definition` is specified, then `definition` cannot be specified.
 
 //Begin definition
 `definition`::
-(Required, object) 
-The {infer} definition for the model. If `definition` is specified, then 
+(Required, object)
+The {infer} definition for the model. If `definition` is specified, then
 `compressed_definition` cannot be specified.
 +
 .Properties of `definition`
@@ -77,58 +77,58 @@ Collection of preprocessors. See <<ml-put-inference-preprocessor-example>>.
 =====
 //Begin frequency encoding
 `frequency_encoding`::
-(Required, object) 
+(Required, object)
 Defines a frequency encoding for a field.
 +
 .Properties of `frequency_encoding`
 [%collapsible%open]
 ======
 `feature_name`::
-(Required, string) 
+(Required, string)
 The name of the resulting feature.
 
 `field`::
-(Required, string) 
+(Required, string)
 The field name to encode.
 
 `frequency_map`::
-(Required, object map of string:double) 
+(Required, object map of string:double)
 Object that maps the field value to the frequency encoded value.
 ======
 //End frequency encoding
 
 //Begin one hot encoding
 `one_hot_encoding`::
-(Required, object) 
+(Required, object)
 Defines a one hot encoding map for a field.
 +
 .Properties of `one_hot_encoding`
 [%collapsible%open]
 ======
 `field`::
-(Required, string) 
+(Required, string)
 The field name to encode.
 
 `hot_map`::
-(Required, object map of strings) 
+(Required, object map of strings)
 String map of "field_value: one_hot_column_name".
 ======
 //End one hot encoding
 
 //Begin target mean encoding
 `target_mean_encoding`::
-(Required, object) 
+(Required, object)
 Defines a target mean encoding for a field.
 +
 .Properties of `target_mean_encoding`
 [%collapsible%open]
 ======
 `default_value`:::
-(Required, double) 
+(Required, double)
 The feature value if the field value is not in the `target_map`.
 
 `feature_name`:::
-(Required, string) 
+(Required, string)
 The name of the resulting feature.
 
 `field`:::
@@ -136,7 +136,7 @@ The name of the resulting feature.
 The field name to encode.
 
 `target_map`:::
-(Required, object map of string:double) 
+(Required, object map of string:double)
 Object that maps the field value to the target mean value.
 ======
 //End target mean encoding
@@ -145,7 +145,7 @@ Object that maps the field value to the target mean value.
 
 //Begin trained model
 `trained_model`::
-(Required, object) 
+(Required, object)
 The definition of the trained model.
 +
 .Properties of `trained_model`
@@ -153,14 +153,14 @@ The definition of the trained model.
 =====
 //Begin tree
 `tree`::
-(Required, object) 
+(Required, object)
 The definition for a binary decision tree.
 +
 .Properties of `tree`
 [%collapsible%open]
 ======
 `classification_labels`:::
-(Optional, string) An array of classification labels (used for 
+(Optional, string) An array of classification labels (used for
 `classification`).
 
 `feature_names`:::
@@ -168,26 +168,26 @@ The definition for a binary decision tree.
 Features expected by the tree, in their expected order.
 
 `target_type`:::
-(Required, string) 
+(Required, string)
 String indicating the model target type; `regression` or `classification`.
 
 `tree_structure`:::
-(Required, object) 
-An array of `tree_node` objects. The nodes must be in ordinal order by their 
+(Required, object)
+An array of `tree_node` objects. The nodes must be in ordinal order by their
 `tree_node.node_index` value.
 ======
 //End tree
 
 //Begin tree node
 `tree_node`::
-(Required, object) 
+(Required, object)
 The definition of a node in a tree.
 +
 --
 There are two major types of nodes: leaf nodes and not-leaf nodes.
 
 * Leaf nodes only need `node_index` and `leaf_value` defined.
-* All other nodes need `split_feature`, `left_child`, `right_child`, 
+* All other nodes need `split_feature`, `left_child`, `right_child`,
   `threshold`, `decision_type`, and `default_left` defined.
 --
 +
@@ -195,41 +195,41 @@ There are two major types of nodes: leaf nodes and not-leaf nodes.
 [%collapsible%open]
 ======
 `decision_type`::
-(Optional, string) 
-Indicates the positive value (in other words, when to choose the left node) 
+(Optional, string)
+Indicates the positive value (in other words, when to choose the left node)
 decision type. Supported `lt`, `lte`, `gt`, `gte`. Defaults to `lte`.
 
 `default_left`::
-(Optional, boolean) 
-Indicates whether to default to the left when the feature is missing. Defaults 
+(Optional, boolean)
+Indicates whether to default to the left when the feature is missing. Defaults
 to `true`.
 
 `leaf_value`::
-(Optional, double) 
-The leaf value of the of the node, if the value is a leaf (in other words, no 
+(Optional, double)
+The leaf value of the of the node, if the value is a leaf (in other words, no
 children).
 
 `left_child`::
-(Optional, integer) 
+(Optional, integer)
 The index of the left child.
 
 `node_index`::
-(Integer) 
+(Integer)
 The index of the current node.
 
 `right_child`::
-(Optional, integer) 
+(Optional, integer)
 The index of the right child.
 
 `split_feature`::
-(Optional, integer) 
+(Optional, integer)
 The index of the feature value in the feature array.
 
 `split_gain`::
 (Optional, double) The information gain from the split.
 
 `threshold`::
-(Optional, double) 
+(Optional, double)
 The decision threshold with which to compare the feature value.
 ======
 //End tree node
@@ -244,9 +244,9 @@ The definition for an ensemble model. See <<ml-put-inference-model-example>>.
 ======
 //Begin aggregate output
 `aggregate_output`::
-(Required, object) 
-An aggregated output object that defines how to aggregate the outputs of the 
-`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and 
+(Required, object)
+An aggregated output object that defines how to aggregate the outputs of the
+`trained_models`. Supported objects are `weighted_mode`, `weighted_sum`, and
 `logistic_regression`. See <<ml-put-inference-aggregated-output-example>>.
 +
 .Properties of `aggregate_output`
@@ -254,65 +254,82 @@ An aggregated output object that defines how to aggregate the outputs of the
 =======
 //Begin logistic regression
 `logistic_regression`::
-(Optional, object) 
-This `aggregated_output` type works with binary classification (classification 
-for values [0, 1]). It multiplies the outputs (in the case of the `ensemble` 
-model, the inference model values) by the supplied `weights`. The resulting 
-vector is summed and passed to a 
-https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result 
-of the `sigmoid` function is considered the probability of class 1 (`P_1`), 
-consequently, the probability of class 0 is `1 - P_1`. The class with the 
-highest probability (either 0 or 1) is then returned. For more information about 
-logistic regression, see 
+(Optional, object)
+This `aggregated_output` type works with binary classification (classification
+for values [0, 1]). It multiplies the outputs (in the case of the `ensemble`
+model, the inference model values) by the supplied `weights`. The resulting
+vector is summed and passed to a
+https://en.wikipedia.org/wiki/Sigmoid_function[`sigmoid` function]. The result
+of the `sigmoid` function is considered the probability of class 1 (`P_1`),
+consequently, the probability of class 0 is `1 - P_1`. The class with the
+highest probability (either 0 or 1) is then returned. For more information about
+logistic regression, see
 https://en.wikipedia.org/wiki/Logistic_regression[this wiki article].
 +
 .Properties of `logistic_regression`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double) 
-The weights to multiply by the input values (the inference values of the trained 
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End logistic regression
 
 //Begin weighted sum
 `weighted_sum`::
-(Optional, object) 
-This `aggregated_output` type works with regression. The weighted sum of the 
+(Optional, object)
+This `aggregated_output` type works with regression. The weighted sum of the
 input values.
 +
 .Properties of `weighted_sum`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double) 
-The weights to multiply by the input values (the inference values of the trained 
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End weighted sum
 
 //Begin weighted mode
 `weighted_mode`::
-(Optional, object) 
-This `aggregated_output` type works with regression or classification. It takes 
-a weighted vote of the input values. The most common input value (taking the 
+(Optional, object)
+This `aggregated_output` type works with regression or classification. It takes
+a weighted vote of the input values. The most common input value (taking the
 weights into account) is returned.
 +
 .Properties of `weighted_mode`
 [%collapsible%open]
 ========
 `weights`:::
-(Required, double) 
-The weights to multiply by the input values (the inference values of the trained 
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
 models).
 ========
 //End weighted mode
+
+//Begin exponent
+`exponent`::
+(Optional, object)
+This `aggregated_output` type works with regression. It takes a weighted sum of
+the input values and passes the result to an exponent function
+(`e^x` where `x` is the sum of the weighted values).
++
+.Properties of `exponent`
+[%collapsible%open]
+========
+`weights`:::
+(Required, double)
+The weights to multiply by the input values (the inference values of the trained
+models).
+========
+//End exponent
 =======
 //End aggregate output
 
 `classification_labels`::
-(Optional, string) 
+(Optional, string)
 An array of classification labels.
 
 `feature_names`::
@@ -320,12 +337,12 @@ An array of classification labels.
 Features expected by the ensemble, in their expected order.
 
 `target_type`::
-(Required, string) 
+(Required, string)
 String indicating the model target type; `regression` or `classification.`
 
 `trained_models`::
 (Required, object)
-An array of `trained_model` objects. Supported trained models are `tree` and 
+An array of `trained_model` objects. Supported trained models are `tree` and
 `ensemble`.
 ======
 //End ensemble
@@ -337,7 +354,7 @@ An array of `trained_model` objects. Supported trained models are `tree` and
 //End definition
 
 `description`::
-(Optional, string) 
+(Optional, string)
 A human-readable description of the {infer} trained model.
 
 //Begin inference_config
@@ -398,24 +415,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification
 
 //Begin input
 `input`::
-(Required, object) 
+(Required, object)
 The input field names for the model definition.
 +
 .Properties of `input`
 [%collapsible%open]
 ====
 `field_names`:::
-(Required, string) 
+(Required, string)
 An array of input field names for the model.
 ====
 //End input
 
 `metadata`::
-(Optional, object) 
+(Optional, object)
 An object map that contains metadata about the model.
 
 `tags`::
-(Optional, string) 
+(Optional, string)
 An array of tags to organize the model.
 
 
@@ -451,10 +468,10 @@ The next example shows a `one_hot_encoding` preprocessor object:
 
 [source,js]
 ----------------------------------
-{ 
-   "one_hot_encoding":{ 
+{
+   "one_hot_encoding":{
       "field":"FlightDelayType",
-      "hot_map":{ 
+      "hot_map":{
          "Carrier Delay":"FlightDelayType_Carrier Delay",
          "NAS Delay":"FlightDelayType_NAS Delay",
          "No Delay":"FlightDelayType_No Delay",
@@ -521,7 +538,7 @@ The first example shows a `trained_model` object:
             "left_child":1,
             "right_child":2
          },
-         ...         
+         ...
          {
             "node_index":9,
             "leaf_value":-27.68987349695448
@@ -615,8 +632,21 @@ Example of a `weighted_mode` object:
 //NOTCONSOLE
 
 
+Example of an `exponent` object:
+
+[source,js]
+----------------------------------
+"aggregate_output" : {
+  "exponent" : {
+    "weights" : [1.0, 1.0, 1.0, 1.0, 1.0]
+  }
+}
+----------------------------------
+//NOTCONSOLE
+
+
 [[ml-put-inference-json-schema]]
 ===== {infer-cap} JSON schema
 
-For the full JSON schema of model {infer}, 
+For the full JSON schema of model {infer},
 https://github.com/elastic/ml-json-schemas[click here].

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInfe
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Exponent;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
@@ -91,6 +92,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
             LogisticRegression.NAME,
             LogisticRegression::fromXContentLenient));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class,
+            Exponent.NAME,
+            Exponent::fromXContentLenient));
 
         // Model Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict));
@@ -109,6 +113,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
             LogisticRegression.NAME,
             LogisticRegression::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class,
+            Exponent.NAME,
+            Exponent::fromXContentStrict));
 
         // Inference Configs
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedInferenceConfig.class, ClassificationConfig.NAME,
@@ -165,6 +172,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
             LogisticRegression.NAME.getPreferredName(),
             LogisticRegression::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class,
+            Exponent.NAME.getPreferredName(),
+            Exponent::new));
 
         // Inference Results
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,

+ 157 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Exponent.java

@@ -0,0 +1,157 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+public class Exponent implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
+
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Exponent.class);
+    public static final ParseField NAME = new ParseField("exponent");
+    public static final ParseField WEIGHTS = new ParseField("weights");
+
+    private static final ConstructingObjectParser<Exponent, Void> LENIENT_PARSER = createParser(true);
+    private static final ConstructingObjectParser<Exponent, Void> STRICT_PARSER = createParser(false);
+
+    @SuppressWarnings("unchecked")
+    private static ConstructingObjectParser<Exponent, Void> createParser(boolean lenient) {
+        ConstructingObjectParser<Exponent, Void> parser = new ConstructingObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            a -> new Exponent((List<Double>)a[0]));
+        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
+        return parser;
+    }
+
+    public static Exponent fromXContentStrict(XContentParser parser) {
+        return STRICT_PARSER.apply(parser, null);
+    }
+
+    public static Exponent fromXContentLenient(XContentParser parser) {
+        return LENIENT_PARSER.apply(parser, null);
+    }
+
+    private final double[] weights;
+
+    Exponent() {
+        this((List<Double>) null);
+    }
+
+    private Exponent(List<Double> weights) {
+        this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
+    }
+
+    public Exponent(double[] weights) {
+        this.weights = weights;
+    }
+
+    public Exponent(StreamInput in) throws IOException {
+        if (in.readBoolean()) {
+            this.weights = in.readDoubleArray();
+        } else {
+            this.weights = null;
+        }
+    }
+
+    @Override
+    public Integer expectedValueSize() {
+        return this.weights == null ? null : this.weights.length;
+    }
+
+    @Override
+    public double[] processValues(double[][] values) {
+        Objects.requireNonNull(values, "values must not be null");
+        if (weights != null && values.length != weights.length) {
+            throw new IllegalArgumentException("values must be the same length as weights.");
+        }
+        assert values[0].length == 1;
+        double[] processed = new double[values.length];
+        for (int i = 0; i < values.length; ++i) {
+            if (weights != null) {
+                processed[i] = weights[i] * values[i][0];
+            } else {
+                processed[i] = values[i][0];
+            }
+        }
+        return processed;
+    }
+
+    @Override
+    public double aggregate(double[] values) {
+        Objects.requireNonNull(values, "values must not be null");
+        double sum = 0.0;
+        for (double val : values) {
+            if (Double.isFinite(val)) {
+                sum += val;
+            }
+        }
+        return Math.exp(sum);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean compatibleWith(TargetType targetType) {
+        return TargetType.REGRESSION.equals(targetType);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(weights != null);
+        if (weights != null) {
+            out.writeDoubleArray(weights);
+        }
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (weights != null) {
+            builder.field(WEIGHTS.getPreferredName(), weights);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Exponent that = (Exponent) o;
+        return Arrays.equals(weights, that.weights);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(weights);
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
+        return SHALLOW_SIZE + weightSize;
+    }
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -80,7 +80,8 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
         }
 
-        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
+        OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ?
+            randomFrom(new WeightedSum(weights), new Exponent(weights)) :
             randomFrom(
                 new WeightedMode(
                     weights,

+ 69 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/ExponentTests.java

@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+
+import java.io.IOException;
+import java.util.stream.Stream;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.closeTo;
+
+public class ExponentTests extends WeightedAggregatorTests<Exponent> {
+
+    @Override
+    Exponent createTestInstance(int numberOfWeights) {
+        double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
+        return new Exponent(weights);
+    }
+
+    @Override
+    protected Exponent doParseInstance(XContentParser parser) throws IOException {
+        return lenient ? Exponent.fromXContentLenient(parser) : Exponent.fromXContentStrict(parser);
+    }
+
+    @Override
+    protected Exponent createTestInstance() {
+        return randomBoolean() ? new Exponent() : createTestInstance(randomIntBetween(1, 100));
+    }
+
+    @Override
+    protected Writeable.Reader<Exponent> instanceReader() {
+        return Exponent::new;
+    }
+
+    public void testAggregate() {
+        double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
+        double[][] values = new double[][]{
+            new double[] {.01},
+            new double[] {.2},
+            new double[] {.002},
+            new double[] {-.01},
+            new double[] {.1}
+        };
+
+        Exponent exponent = new Exponent(ones);
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001));
+
+        double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0};
+
+        exponent = new Exponent(variedWeights);
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(0.81897, 0.00001));
+
+        exponent = new Exponent();
+        assertThat(exponent.aggregate(exponent.processValues(values)), closeTo(1.35256, 0.00001));
+    }
+
+    public void testCompatibleWith() {
+        Exponent exponent = createTestInstance();
+        assertThat(exponent.compatibleWith(TargetType.CLASSIFICATION), is(false));
+        assertThat(exponent.compatibleWith(TargetType.REGRESSION), is(true));
+    }
+}