Browse Source

Pipeline Inference Aggregation (#58193)

Adds a pipeline aggregation that loads a model and performs inference on the 
input aggregation results.
David Kyle 5 years ago
parent
commit
7daed3b8af
45 changed files with 2116 additions and 209 deletions
  1. 74 0
      docs/reference/aggregations/pipeline/inference-bucket-aggregation.asciidoc
  2. 18 1
      test/framework/src/main/java/org/elasticsearch/search/aggregations/BasePipelineAggregationTestCase.java
  3. 7 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  4. 19 72
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java
  5. 47 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java
  6. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java
  7. 10 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java
  8. 31 7
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java
  9. 2 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java
  10. 138 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TopClassEntry.java
  11. 13 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java
  12. 26 12
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java
  13. 39 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java
  14. 8 8
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java
  15. 14 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java
  16. 137 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ResultsFieldUpdate.java
  17. 5 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java
  18. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java
  19. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java
  20. 54 15
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java
  21. 9 4
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java
  22. 16 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java
  23. 43 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopClassEntryTests.java
  24. 21 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java
  25. 26 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java
  26. 22 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdateTests.java
  27. 65 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ResultsFieldUpdateTests.java
  28. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java
  29. 2 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java
  30. 1 1
      x-pack/plugin/ml/qa/ml-with-security/build.gradle
  31. 1 1
      x-pack/plugin/ml/qa/ml-with-security/roles.yml
  32. 17 1
      x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityInsufficientRoleIT.java
  33. 1 1
      x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java
  34. 21 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  35. 214 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java
  36. 133 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java
  37. 98 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregation.java
  38. 131 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java
  39. 0 29
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  40. 18 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
  41. 130 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java
  42. 218 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java
  43. 0 22
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java
  44. 12 11
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java
  45. 266 0
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/pipeline_inference.yml

+ 74 - 0
docs/reference/aggregations/pipeline/inference-bucket-aggregation.asciidoc

@@ -0,0 +1,74 @@
+[role="xpack"]
+[testenv="basic"]
+[[search-aggregations-pipeline-inference-bucket-aggregation]]
+=== Inference Bucket Aggregation
+
+A parent pipeline aggregation which loads a pre-trained model and performs inference on the
+collated result field from the parent bucket aggregation.
+
+[[inference-bucket-agg-syntax]]
+==== Syntax
+
+A `inference` aggregation looks like this in isolation:
+
+[source,js]
+--------------------------------------------------
+{
+    "inference": {
+        "model_id": "a_model_for_inference", <1>
+        "inference_config": { <2>
+            "regression_config": {
+                "num_top_feature_importance_values": 2
+            }
+        },
+        "buckets_path": {
+            "avg_cost": "avg_agg", <3>
+            "max_cost": "max_agg"
+        }
+    }
+}
+--------------------------------------------------
+// NOTCONSOLE
+<1> The ID of model to use.
+<2> The optional inference config which overrides the model's default settings
+<3> Map the value of `avg_agg` to the model's input field `avg_cost`
+
+[[inference-bucket-params]]
+.`inference` Parameters
+[options="header"]
+|===
+|Parameter Name |Description |Required |Default Value
+| `model_id`         | The ID of the model to load and infer against       | Required  | -
+| `inference_config` | Contains the inference type and its options. There are two types: <<inference-agg-regression-opt,`regression`>> and <<inference-agg-classification-opt,`classification`>>  | Optional | -
+| `buckets_path`     | Defines the paths to the input aggregations and maps the aggregation names to the field names expected by the model.
+See <<buckets-path-syntax>> for more details | Required       | -
+|===
+
+
+==== Configuration options for {infer} models
+The `inference_config` setting is optional and usaully isn't required as the pre-trained models come equipped with sensible defaults.
+In the context of aggregations some options can overridden for each of the 2 types of model.
+
+[discrete]
+[[inference-agg-regression-opt]]
+===== Configuration options for {regression} models
+
+`num_top_feature_importance_values`::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-regression-num-top-feature-importance-values]
+
+[discrete]
+[[inference-agg-classification-opt]]
+===== Configuration options for {classification} models
+
+`num_top_classes`::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-classes]
+
+`num_top_feature_importance_values`::
+(Optional, integer)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-num-top-feature-importance-values]
+
+`prediction_field_type`::
+(Optional, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=inference-config-classification-prediction-field-type]

+ 18 - 1
test/framework/src/main/java/org/elasticsearch/search/aggregations/BasePipelineAggregationTestCase.java

@@ -82,8 +82,11 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
         List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
         entries.addAll(IndicesModule.getNamedWriteables());
         entries.addAll(searchModule.getNamedWriteables());
+        entries.addAll(additionalNamedWriteables());
         namedWriteableRegistry = new NamedWriteableRegistry(entries);
-        xContentRegistry = new NamedXContentRegistry(searchModule.getNamedXContents());
+        List<NamedXContentRegistry.Entry> xContentEntries = searchModule.getNamedXContents();
+        xContentEntries.addAll(additionalNamedContents());
+        xContentRegistry = new NamedXContentRegistry(xContentEntries);
         //create some random type with some default field, those types will stick around for all of the subclasses
         currentTypes = new String[randomIntBetween(0, 5)];
         for (int i = 0; i < currentTypes.length; i++) {
@@ -99,6 +102,20 @@ public abstract class BasePipelineAggregationTestCase<AF extends AbstractPipelin
         return emptyList();
     }
 
+    /**
+     * Any extra named writeables required not registered by {@link SearchModule}
+     */
+    protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
+        return emptyList();
+    }
+
+    /**
+     * Any extra named xcontents required not registered by {@link SearchModule}
+     */
+    protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
+        return emptyList();
+    }
+
     /**
      * Generic test that creates new AggregatorFactory from the test
      * AggregatorFactory and checks both for equality and asserts equality on

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbeddi
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -20,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedInf
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
@@ -121,6 +123,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             ClassificationConfigUpdate::fromXContentStrict));
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
             RegressionConfigUpdate::fromXContentStrict));
+        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME,
+            ResultsFieldUpdate::fromXContent));
 
         // Inference models
         namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
@@ -170,6 +174,9 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
             RegressionInferenceResults.NAME,
             RegressionInferenceResults::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
+            WarningInferenceResults.NAME,
+            WarningInferenceResults::new));
 
         // Inference Configs
         namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,

+ 19 - 72
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

@@ -6,10 +6,9 @@
 package org.elasticsearch.xpack.core.ml.inference.results;
 
 import org.elasticsearch.Version;
-import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -18,9 +17,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.Objects;
 import java.util.stream.Collectors;
 
@@ -85,6 +82,10 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         return topClasses;
     }
 
+    public PredictionFieldType getPredictionFieldType() {
+        return predictionFieldType;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
@@ -127,6 +128,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         return classificationLabel == null ? super.valueAsString() : classificationLabel;
     }
 
+    @Override
+    public Object predictedValue() {
+        return predictionFieldType.transformPredictedValue(value(), valueAsString());
+    }
+
     @Override
     public void writeResult(IngestDocument document, String parentResultField) {
         ExceptionsHelper.requireNonNull(document, "document");
@@ -138,7 +144,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
                 topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
         }
         if (getFeatureImportance().size() > 0) {
-            document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
+            document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
                 .stream()
                 .map(FeatureImportance::toMap)
                 .collect(Collectors.toList()));
@@ -150,74 +156,15 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults
         return NAME;
     }
 
-    public static class TopClassEntry implements Writeable {
-
-        public final ParseField CLASS_NAME = new ParseField("class_name");
-        public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
-        public final ParseField CLASS_SCORE = new ParseField("class_score");
-
-        private final Object classification;
-        private final double probability;
-        private final double score;
-
-        public TopClassEntry(Object classification, double probability, double score) {
-            this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
-            this.probability = probability;
-            this.score = score;
-        }
-
-        public TopClassEntry(StreamInput in) throws IOException {
-            if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
-                this.classification = in.readGenericValue();
-            } else {
-                this.classification = in.readString();
-            }
-            this.probability = in.readDouble();
-            this.score = in.readDouble();
-        }
-
-        public Object getClassification() {
-            return classification;
-        }
-
-        public double getProbability() {
-            return probability;
-        }
-
-        public double getScore() {
-            return score;
-        }
-
-        public Map<String, Object> asValueMap() {
-            Map<String, Object> map = new HashMap<>(3, 1.0f);
-            map.put(CLASS_NAME.getPreferredName(), classification);
-            map.put(CLASS_PROBABILITY.getPreferredName(), probability);
-            map.put(CLASS_SCORE.getPreferredName(), score);
-            return map;
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
-                out.writeGenericValue(classification);
-            } else {
-                out.writeString(classification.toString());
-            }
-            out.writeDouble(probability);
-            out.writeDouble(score);
-        }
-
-        @Override
-        public boolean equals(Object object) {
-            if (object == this) { return true; }
-            if (object == null || getClass() != object.getClass()) { return false; }
-            TopClassEntry that = (TopClassEntry) object;
-            return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
+        if (topClasses.size() > 0) {
+            builder.field(topNumClassesField, topClasses);
         }
-
-        @Override
-        public int hashCode() {
-            return Objects.hash(classification, probability, score);
+        if (getFeatureImportance().size() > 0) {
+            builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
         }
+        return builder;
     }
 }

+ 47 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

@@ -5,23 +5,33 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Objects;
 
-public class FeatureImportance implements Writeable {
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class FeatureImportance implements Writeable, ToXContentObject {
 
     private final Map<String, Double> classImportance;
     private final double importance;
     private final String featureName;
-    private static final String IMPORTANCE = "importance";
-    private static final String FEATURE_NAME = "feature_name";
+    static final String IMPORTANCE = "importance";
+    static final String FEATURE_NAME = "feature_name";
+    static final String CLASS_IMPORTANCE = "class_importance";
 
     public static FeatureImportance forRegression(String featureName, double importance) {
         return new FeatureImportance(featureName, importance, null);
@@ -31,7 +41,24 @@ public class FeatureImportance implements Writeable {
         return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
     }
 
-    private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
+        new ConstructingObjectParser<>("feature_importance",
+            a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
+        );
+
+    static {
+        PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
+        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
+        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
+            new ParseField(FeatureImportance.CLASS_IMPORTANCE));
+    }
+
+    public static FeatureImportance fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
         this.featureName = Objects.requireNonNull(featureName);
         this.importance = importance;
         this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
@@ -79,6 +106,22 @@ public class FeatureImportance implements Writeable {
         return map;
     }
 
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAME, featureName);
+        builder.field(IMPORTANCE, importance);
+        if (classImportance != null && classImportance.isEmpty() == false) {
+            builder.startObject(CLASS_IMPORTANCE);
+            for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
+                builder.field(entry.getKey(), entry.getValue());
+            }
+            builder.endObject();
+        }
+        builder.endObject();
+        return builder;
+    }
+
     @Override
     public boolean equals(Object object) {
         if (object == this) { return true; }
@@ -93,5 +136,4 @@ public class FeatureImportance implements Writeable {
     public int hashCode() {
         return Objects.hash(featureName, importance, classImportance);
     }
-
 }

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java

@@ -6,10 +6,12 @@
 package org.elasticsearch.xpack.core.ml.inference.results;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentFragment;
 import org.elasticsearch.ingest.IngestDocument;
 
-public interface InferenceResults extends NamedWriteable {
+public interface InferenceResults extends NamedWriteable, ToXContentFragment {
 
     void writeResult(IngestDocument document, String parentResultField);
 
+    Object predictedValue();
 }

+ 10 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.results;
 
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.ingest.IngestDocument;
 
 import java.io.IOException;
@@ -56,9 +57,18 @@ public class RawInferenceResults implements InferenceResults {
         throw new UnsupportedOperationException("[raw] does not support writing inference results");
     }
 
+    @Override
+    public Object predictedValue() {
+        return null;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME;
     }
 
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        throw new UnsupportedOperationException("[raw] does not support toXContent");
+    }
 }

+ 31 - 7
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@@ -25,18 +26,28 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
     private final String resultsField;
 
     public RegressionInferenceResults(double value, InferenceConfig config) {
-        this(value, (RegressionConfig) config, Collections.emptyList());
+        this(value, config, Collections.emptyList());
     }
 
     public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
-        this(value, (RegressionConfig)config, featureImportance);
+        this(value, ((RegressionConfig)config).getResultsField(),
+            ((RegressionConfig)config).getNumTopFeatureImportanceValues(), featureImportance);
     }
 
-    private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
+    public RegressionInferenceResults(double value, String resultsField) {
+        this(value, resultsField, 0, Collections.emptyList());
+    }
+
+    public RegressionInferenceResults(double value, String resultsField,
+                                      List<FeatureImportance> featureImportance) {
+        this(value, resultsField, featureImportance.size(), featureImportance);
+    }
+
+    public RegressionInferenceResults(double value, String resultsField, int topNFeatures,
+                                       List<FeatureImportance> featureImportance) {
         super(value,
-            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
-                regressionConfig.getNumTopFeatureImportanceValues()));
-        this.resultsField = regressionConfig.getResultsField();
+            SingleValueInferenceResults.takeTopFeatureImportances(featureImportance, topNFeatures));
+        this.resultsField = resultsField;
     }
 
     public RegressionInferenceResults(StreamInput in) throws IOException {
@@ -65,6 +76,11 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         return Objects.hash(value(), resultsField, getFeatureImportance());
     }
 
+    @Override
+    public Object predictedValue() {
+        return super.value();
+    }
+
     @Override
     public void writeResult(IngestDocument document, String parentResultField) {
         ExceptionsHelper.requireNonNull(document, "document");
@@ -78,9 +94,17 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
         }
     }
 
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(resultsField, value());
+        if (getFeatureImportance().size() > 0) {
+            builder.field(FEATURE_IMPORTANCE, getFeatureImportance());
+        }
+        return builder;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME;
     }
-
 }

+ 2 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java

@@ -16,6 +16,8 @@ import java.util.stream.Collectors;
 
 public abstract class SingleValueInferenceResults implements InferenceResults {
 
+    public static final String FEATURE_IMPORTANCE = "feature_importance";
+
     private final double value;
     private final List<FeatureImportance> featureImportance;
 

+ 138 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TopClassEntry.java

@@ -0,0 +1,138 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class TopClassEntry implements Writeable, ToXContentObject {
+
+    public static final ParseField CLASS_NAME = new ParseField("class_name");
+    public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
+    public static final ParseField CLASS_SCORE = new ParseField("class_score");
+
+    public static final String NAME = "top_class";
+
+    private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
+        new ConstructingObjectParser<>(NAME, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2]));
+
+    static {
+        PARSER.declareField(constructorArg(), (p, n) -> {
+            Object o;
+            XContentParser.Token token = p.currentToken();
+            if (token == XContentParser.Token.VALUE_STRING) {
+                o = p.text();
+            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
+                o = p.booleanValue();
+            } else if (token == XContentParser.Token.VALUE_NUMBER) {
+                o = p.doubleValue();
+            } else {
+                throw new XContentParseException(p.getTokenLocation(),
+                    "[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
+                    + "] is not a string, boolean or number");
+            }
+            return o;
+        }, CLASS_NAME, ObjectParser.ValueType.VALUE);
+        PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
+        PARSER.declareDouble(constructorArg(), CLASS_SCORE);
+    }
+
+    public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    private final Object classification;
+    private final double probability;
+    private final double score;
+
+    public TopClassEntry(Object classification, double probability, double score) {
+        this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
+        this.probability = probability;
+        this.score = score;
+    }
+
+    public TopClassEntry(StreamInput in) throws IOException {
+        if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
+            this.classification = in.readGenericValue();
+        } else {
+            this.classification = in.readString();
+        }
+        this.probability = in.readDouble();
+        this.score = in.readDouble();
+    }
+
+    public Object getClassification() {
+        return classification;
+    }
+
+    public double getProbability() {
+        return probability;
+    }
+
+    public double getScore() {
+        return score;
+    }
+
+    public Map<String, Object> asValueMap() {
+        Map<String, Object> map = new HashMap<>(3, 1.0f);
+        map.put(CLASS_NAME.getPreferredName(), classification);
+        map.put(CLASS_PROBABILITY.getPreferredName(), probability);
+        map.put(CLASS_SCORE.getPreferredName(), score);
+        return map;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(CLASS_NAME.getPreferredName(), classification);
+        builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
+        builder.field(CLASS_SCORE.getPreferredName(), score);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
+            out.writeGenericValue(classification);
+        } else {
+            out.writeString(classification.toString());
+        }
+        out.writeDouble(probability);
+        out.writeDouble(score);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        TopClassEntry that = (TopClassEntry) object;
+        return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classification, probability, score);
+    }
+}

+ 13 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.ml.inference.results;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -55,12 +56,22 @@ public class WarningInferenceResults implements InferenceResults {
     public void writeResult(IngestDocument document, String parentResultField) {
         ExceptionsHelper.requireNonNull(document, "document");
         ExceptionsHelper.requireNonNull(parentResultField, "resultField");
-        document.setFieldValue(parentResultField + "." + "warning", warning);
+        document.setFieldValue(parentResultField + "." + NAME, warning);
+    }
+
+    @Override
+    public Object predictedValue() {
+        return null;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(NAME, warning);
+        return builder;
     }
 
     @Override
     public String getWriteableName() {
         return NAME;
     }
-
 }

+ 26 - 12
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdate.java

@@ -63,18 +63,18 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
             config.getPredictionFieldType());
     }
 
-    private static final ObjectParser<ClassificationConfigUpdate.Builder, Void> STRICT_PARSER = createParser(false);
+    private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);
 
-    private static ObjectParser<ClassificationConfigUpdate.Builder, Void> createParser(boolean lenient) {
-        ObjectParser<ClassificationConfigUpdate.Builder, Void> parser = new ObjectParser<>(
+    private static ObjectParser<Builder, Void> createParser(boolean lenient) {
+        ObjectParser<Builder, Void> parser = new ObjectParser<>(
             NAME.getPreferredName(),
             lenient,
-            ClassificationConfigUpdate.Builder::new);
-        parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopClasses, NUM_TOP_CLASSES);
-        parser.declareString(ClassificationConfigUpdate.Builder::setResultsField, RESULTS_FIELD);
-        parser.declareString(ClassificationConfigUpdate.Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
-        parser.declareInt(ClassificationConfigUpdate.Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
-        parser.declareString(ClassificationConfigUpdate.Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
+            Builder::new);
+        parser.declareInt(Builder::setNumTopClasses, NUM_TOP_CLASSES);
+        parser.declareString(Builder::setResultsField, RESULTS_FIELD);
+        parser.declareString(Builder::setTopClassesResultsField, TOP_CLASSES_RESULTS_FIELD);
+        parser.declareInt(Builder::setNumTopFeatureImportanceValues, NUM_TOP_FEATURE_IMPORTANCE_VALUES);
+        parser.declareString(Builder::setPredictionFieldType, PREDICTION_FIELD_TYPE);
         return parser;
     }
 
@@ -96,6 +96,8 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
         }
         this.numTopFeatureImportanceValues = featureImportance;
         this.predictionFieldType = predictionFieldType;
+
+        InferenceConfigUpdate.checkFieldUniqueness(resultsField, topClassesResultsField);
     }
 
     public ClassificationConfigUpdate(StreamInput in) throws IOException {
@@ -118,6 +120,16 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
         return resultsField;
     }
 
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder()
+            .setNumTopClasses(numTopClasses)
+            .setTopClassesResultsField(topClassesResultsField)
+            .setResultsField(resultsField)
+            .setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
+            .setPredictionFieldType(predictionFieldType);
+    }
+
     public Integer getNumTopFeatureImportanceValues() {
         return numTopFeatureImportanceValues;
     }
@@ -235,14 +247,14 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
             && (predictionFieldType == null || predictionFieldType.equals(originalConfig.getPredictionFieldType()));
     }
 
-    public static class Builder {
+    public static class Builder implements InferenceConfigUpdate.Builder<Builder, ClassificationConfigUpdate> {
         private Integer numTopClasses;
         private String topClassesResultsField;
         private String resultsField;
         private Integer numTopFeatureImportanceValues;
         private PredictionFieldType predictionFieldType;
 
-        public Builder setNumTopClasses(int numTopClasses) {
+        public Builder setNumTopClasses(Integer numTopClasses) {
             this.numTopClasses = numTopClasses;
             return this;
         }
@@ -252,12 +264,13 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
             return this;
         }
 
+        @Override
         public Builder setResultsField(String resultsField) {
             this.resultsField = resultsField;
             return this;
         }
 
-        public Builder setNumTopFeatureImportanceValues(int numTopFeatureImportanceValues) {
+        public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) {
             this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
             return this;
         }
@@ -271,6 +284,7 @@ public class ClassificationConfigUpdate implements InferenceConfigUpdate {
             return setPredictionFieldType(PredictionFieldType.fromString(predictionFieldType));
         }
 
+        @Override
         public ClassificationConfigUpdate build() {
             return new ClassificationConfigUpdate(numTopClasses,
                 resultsField,

+ 39 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfigUpdate.java

@@ -6,14 +6,53 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
 
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
 
 public interface InferenceConfigUpdate extends NamedXContentObject, NamedWriteable {
+    Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
+        WarningInferenceResults.WARNING.getPreferredName(),
+        TrainedModelConfig.MODEL_ID.getPreferredName()));
 
     InferenceConfig apply(InferenceConfig originalConfig);
 
     InferenceConfig toConfig();
 
     boolean isSupported(InferenceConfig config);
+
+    String getResultsField();
+
+    interface Builder<T extends Builder<T, U>, U extends InferenceConfigUpdate> {
+        U build();
+        T setResultsField(String resultsField);
+    }
+
+    Builder<? extends Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder();
+
+    static void checkFieldUniqueness(String... fieldNames) {
+        Set<String> duplicatedFieldNames = new HashSet<>();
+        Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
+        for(String fieldName : fieldNames) {
+            if (fieldName == null) {
+                continue;
+            }
+            if (currentFieldNames.contains(fieldName)) {
+                duplicatedFieldNames.add(fieldName);
+            } else {
+                currentFieldNames.add(fieldName);
+            }
+        }
+        if (duplicatedFieldNames.isEmpty() == false) {
+            throw ExceptionsHelper.badRequestException("Cannot apply inference config." +
+                    " More than one field is configured as {}",
+                duplicatedFieldNames);
+        }
+    }
 }

+ 8 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -7,8 +7,8 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
-import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.util.ArrayList;
@@ -28,11 +28,11 @@ public final class InferenceHelpers {
     /**
      * @return Tuple of the highest scored index and the top classes
      */
-    public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(double[] probabilities,
-                                                                                                List<String> classificationLabels,
-                                                                                                @Nullable double[] classificationWeights,
-                                                                                                int numToInclude,
-                                                                                                PredictionFieldType predictionFieldType) {
+    public static Tuple<Integer, List<TopClassEntry>> topClasses(double[] probabilities,
+                                                                 List<String> classificationLabels,
+                                                                 @Nullable double[] classificationWeights,
+                                                                 int numToInclude,
+                                                                 PredictionFieldType predictionFieldType) {
 
         if (classificationLabels != null && probabilities.length != classificationLabels.size()) {
             throw ExceptionsHelper
@@ -65,10 +65,10 @@ public final class InferenceHelpers {
             classificationLabels;
 
         int count = numToInclude < 0 ? probabilities.length : Math.min(numToInclude, probabilities.length);
-        List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
+        List<TopClassEntry> topClassEntries = new ArrayList<>(count);
         for(int i = 0; i < count; i++) {
             int idx = sortedIndices[i];
-            topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(
+            topClassEntries.add(new TopClassEntry(
                 predictionFieldType.transformPredictedValue((double)idx, labels.get(idx)),
                 probabilities[idx],
                 scores[idx]));

+ 14 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdate.java

@@ -18,7 +18,6 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
 
-import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.DEFAULT_RESULTS_FIELD;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.NUM_TOP_FEATURE_IMPORTANCE_VALUES;
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
 
@@ -68,6 +67,8 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
                 "] must be greater than or equal to 0");
         }
         this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
+
+        InferenceConfigUpdate.checkFieldUniqueness(resultsField);
     }
 
     public RegressionConfigUpdate(StreamInput in) throws IOException {
@@ -75,12 +76,19 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
         this.numTopFeatureImportanceValues = in.readOptionalVInt();
     }
 
-    public int getNumTopFeatureImportanceValues() {
-        return numTopFeatureImportanceValues == null ? 0 : numTopFeatureImportanceValues;
+    public Integer getNumTopFeatureImportanceValues() {
+        return numTopFeatureImportanceValues;
     }
 
     public String getResultsField() {
-        return resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField;
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder()
+            .setNumTopFeatureImportanceValues(numTopFeatureImportanceValues)
+            .setResultsField(resultsField);
     }
 
     @Override
@@ -165,10 +173,11 @@ public class RegressionConfigUpdate implements InferenceConfigUpdate {
                 || originalConfig.getNumTopFeatureImportanceValues() == numTopFeatureImportanceValues);
     }
 
-    public static class Builder {
+    public static class Builder implements InferenceConfigUpdate.Builder<Builder, RegressionConfigUpdate> {
         private String resultsField;
         private Integer numTopFeatureImportanceValues;
 
+        @Override
         public Builder setResultsField(String resultsField) {
             this.resultsField = resultsField;
             return this;

+ 137 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ResultsFieldUpdate.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig.RESULTS_FIELD;
+
+/**
+ * A config update that sets the results field only.
+ * Supports any type of {@link InferenceConfig}
+ */
+public class ResultsFieldUpdate implements InferenceConfigUpdate {
+
+    public static final ParseField NAME = new ParseField("field_update");
+
+    private static final ConstructingObjectParser<ResultsFieldUpdate, Void> PARSER =
+        new ConstructingObjectParser<>(NAME.getPreferredName(), args -> new ResultsFieldUpdate((String) args[0]));
+
+    static {
+        PARSER.declareString(constructorArg(), RESULTS_FIELD);
+    }
+
+    public static ResultsFieldUpdate fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final String resultsField;
+
+    public ResultsFieldUpdate(String resultsField) {
+        this.resultsField = Objects.requireNonNull(resultsField);
+    }
+
+    public ResultsFieldUpdate(StreamInput in) throws IOException {
+        resultsField = in.readString();
+    }
+
+    @Override
+    public InferenceConfig apply(InferenceConfig originalConfig) {
+        if (originalConfig instanceof ClassificationConfig) {
+            ClassificationConfigUpdate update = new ClassificationConfigUpdate(null, resultsField, null, null, null);
+            return update.apply(originalConfig);
+        } else if (originalConfig instanceof RegressionConfig) {
+            RegressionConfigUpdate update = new RegressionConfigUpdate(resultsField, null);
+            return update.apply(originalConfig);
+        } else {
+            throw ExceptionsHelper.badRequestException(
+                "Inference config of unknown type [{}] can not be updated", originalConfig.getName());
+        }
+    }
+
+    @Override
+    public InferenceConfig toConfig() {
+        return new RegressionConfig(resultsField);
+    }
+
+    @Override
+    public boolean isSupported(InferenceConfig config) {
+        return true;
+    }
+
+    @Override
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    @Override
+    public InferenceConfigUpdate.Builder<? extends InferenceConfigUpdate.Builder<?, ?>, ? extends InferenceConfigUpdate> newBuilder() {
+        return new Builder().setResultsField(resultsField);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(resultsField);
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ResultsFieldUpdate that = (ResultsFieldUpdate) o;
+        return Objects.equals(resultsField, that.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(resultsField);
+    }
+
+    public static class Builder implements InferenceConfigUpdate.Builder<Builder, ResultsFieldUpdate> {
+        private String resultsField;
+
+        @Override
+        public Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        @Override
+        public ResultsFieldUpdate build() {
+            return new ResultsFieldUpdate(resultsField);
+        }
+    }
+}

+ 5 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
 
 import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -14,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -85,12 +87,12 @@ public class EnsembleInferenceModel implements InferenceModel {
     private EnsembleInferenceModel(List<InferenceModel> models,
                                    OutputAggregator outputAggregator,
                                    TargetType targetType,
-                                   List<String> classificationLabels,
+                                   @Nullable List<String> classificationLabels,
                                    List<Double> classificationWeights) {
         this.models = ExceptionsHelper.requireNonNull(models, TRAINED_MODELS);
         this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT);
         this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
-        this.classificationLabels = classificationLabels == null ? null : classificationLabels;
+        this.classificationLabels = classificationLabels;
         this.classificationWeights = classificationWeights == null ?
             null :
             classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
@@ -204,7 +206,7 @@ public class EnsembleInferenceModel implements InferenceModel {
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
                 assert classificationWeights == null || processedInferences.length == classificationWeights.length;
                 // Adjust the probabilities according to the thresholds
-                Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
+                Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     processedInferences,
                     classificationLabels,
                     classificationWeights,

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java

@@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInference
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -173,7 +174,7 @@ public class TreeInferenceModel implements InferenceModel {
         switch (targetType) {
             case CLASSIFICATION:
                 ClassificationConfig classificationConfig = (ClassificationConfig) config;
-                Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
+                Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                     classificationProbability(value),
                     classificationLabels,
                     null,

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
@@ -134,7 +135,7 @@ public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, Lenie
         double[] probabilities = softMax(scores);
 
         ClassificationConfig classificationConfig = (ClassificationConfig) config;
-        Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses = InferenceHelpers.topClasses(
+        Tuple<Integer, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
             probabilities,
             LANGUAGE_NAMES,
             null,

+ 54 - 15
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
@@ -31,21 +32,24 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
             FeatureImportanceTests::randomClassification :
             FeatureImportanceTests::randomRegression;
 
-        return new ClassificationInferenceResults(randomDouble(),
+        ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
+        Double value = randomDouble();
+        if (config.getPredictionFieldType() == PredictionFieldType.BOOLEAN) {
+            // value must be close to 0 or 1
+            value = randomBoolean() ? 0.0 : 1.0;
+        }
+
+        return new ClassificationInferenceResults(value,
             randomBoolean() ? null : randomAlphaOfLength(10),
             randomBoolean() ? null :
-                Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
+                Stream.generate(TopClassEntryTests::createRandomTopClassEntry)
                     .limit(randomIntBetween(0, 10))
                     .collect(Collectors.toList()),
             randomBoolean() ? null :
                 Stream.generate(featureImportanceCtor)
                     .limit(randomIntBetween(1, 10))
                     .collect(Collectors.toList()),
-            ClassificationConfigTests.randomClassificationConfig());
-    }
-
-    private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
-        return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble(), randomDouble());
+            config);
     }
 
     public void testWriteResultsWithClassificationLabel() {
@@ -70,10 +74,10 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
 
     @SuppressWarnings("unchecked")
     public void testWriteResultsWithTopClasses() {
-        List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
-            new ClassificationInferenceResults.TopClassEntry("foo", 0.7, 0.7),
-            new ClassificationInferenceResults.TopClassEntry("bar", 0.2, 0.2),
-            new ClassificationInferenceResults.TopClassEntry("baz", 0.1, 0.1));
+        List<TopClassEntry> entries = Arrays.asList(
+            new TopClassEntry("foo", 0.7, 0.7),
+            new TopClassEntry("bar", 0.2, 0.2),
+            new TopClassEntry("baz", 0.1, 0.1));
         ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
             "foo",
             entries,
@@ -84,8 +88,8 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
         List<?> list = document.getFieldValue("result_field.bar", List.class);
         assertThat(list.size(), equalTo(3));
 
-        for(int i = 0; i < 3; i++) {
-            Map<String, Object> map = (Map<String, Object>)list.get(i);
+        for (int i = 0; i < 3; i++) {
+            Map<String, Object> map = (Map<String, Object>) list.get(i);
             assertThat(map, equalTo(entries.get(i).asValueMap()));
         }
 
@@ -110,11 +114,11 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
 
         assertThat(document.getFieldValue("result_field.predicted_value", String.class), equalTo("foo"));
         @SuppressWarnings("unchecked")
-        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>)document.getFieldValue(
+        List<Map<String, Object>> writtenImportance = (List<Map<String, Object>>) document.getFieldValue(
             "result_field.feature_importance",
             List.class);
         assertThat(writtenImportance, hasSize(3));
-        importanceList.sort((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
+        importanceList.sort((l, r) -> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())));
         for (int i = 0; i < 3; i++) {
             Map<String, Object> objectMap = writtenImportance.get(i);
             FeatureImportance importance = importanceList.get(i);
@@ -135,4 +139,39 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
     protected Writeable.Reader<ClassificationInferenceResults> instanceReader() {
         return ClassificationInferenceResults::new;
     }
+
+    public void testToXContent() {
+        ClassificationConfig toStringConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.STRING);
+        ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, null, toStringConfig);
+        String stringRep = Strings.toString(result);
+        String expected = "{\"predicted_value\":\"1.0\"}";
+        assertEquals(expected, stringRep);
+
+        ClassificationConfig toDoubleConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.NUMBER);
+        result = new ClassificationInferenceResults(1.0, null, null, toDoubleConfig);
+        stringRep = Strings.toString(result);
+        expected = "{\"predicted_value\":1.0}";
+        assertEquals(expected, stringRep);
+
+        ClassificationConfig boolFieldConfig = new ClassificationConfig(1, null, null, null, PredictionFieldType.BOOLEAN);
+        result = new ClassificationInferenceResults(1.0, null, null, boolFieldConfig);
+        stringRep = Strings.toString(result);
+        expected = "{\"predicted_value\":true}";
+        assertEquals(expected, stringRep);
+
+        ClassificationConfig config = new ClassificationConfig(1);
+        result = new ClassificationInferenceResults(1.0, "label1", null, config);
+        stringRep = Strings.toString(result);
+        expected = "{\"predicted_value\":\"label1\"}";
+        assertEquals(expected, stringRep);
+
+        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
+        TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
+        result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
+            Collections.singletonList(fi), config);
+        stringRep = Strings.toString(result);
+        expected = "{\"predicted_value\":\"label1\"," +
+            "\"top_classes\":[{\"class_name\":\"class\",\"class_probability\":1.0,\"class_score\":1.0}]}";
+        assertEquals(expected, stringRep);
+    }
 }

+ 9 - 4
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java

@@ -6,14 +6,15 @@
 package org.elasticsearch.xpack.core.ml.inference.results;
 
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
 
+import java.io.IOException;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
-
-public class FeatureImportanceTests extends AbstractWireSerializingTestCase<FeatureImportance> {
+public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureImportance> {
 
     public static FeatureImportance createRandomInstance() {
         return randomBoolean() ? randomClassification() : randomRegression();
@@ -29,7 +30,6 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
             Stream.generate(() -> randomAlphaOfLength(10))
                 .limit(randomLongBetween(2, 10))
                 .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
-
     }
 
     @Override
@@ -41,4 +41,9 @@ public class FeatureImportanceTests extends AbstractWireSerializingTestCase<Feat
     protected Writeable.Reader<FeatureImportance> instanceReader() {
         return FeatureImportance::new;
     }
+
+    @Override
+    protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return FeatureImportance.fromXContent(parser);
+    }
 }

+ 16 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

@@ -5,12 +5,14 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.ingest.IngestDocument;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -75,4 +77,18 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
     protected Writeable.Reader<RegressionInferenceResults> instanceReader() {
         return RegressionInferenceResults::new;
     }
+
+    public void testToXContent() {
+        String resultsField = "ml.results";
+        RegressionInferenceResults result = new RegressionInferenceResults(1.0, resultsField);
+        String stringRep = Strings.toString(result);
+        String expected = "{\"" + resultsField + "\":1.0}";
+        assertEquals(expected, stringRep);
+
+        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
+        result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
+        stringRep = Strings.toString(result);
+        expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";
+        assertEquals(expected, stringRep);
+    }
 }

+ 43 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TopClassEntryTests.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+public class TopClassEntryTests extends AbstractSerializingTestCase<TopClassEntry> {
+
+    public static TopClassEntry createRandomTopClassEntry() {
+        Object classification;
+        if (randomBoolean()) {
+            classification = randomAlphaOfLength(10);
+        } else if (randomBoolean()) {
+            classification = randomBoolean();
+        } else {
+            classification = randomDouble();
+        }
+        return new TopClassEntry(classification, randomDouble(), randomDouble());
+    }
+
+    @Override
+    protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
+        return TopClassEntry.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<TopClassEntry> instanceReader() {
+        return TopClassEntry::new;
+    }
+
+    @Override
+    protected TopClassEntry createTestInstance() {
+        return createRandomTopClassEntry();
+    }
+}

+ 21 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java

@@ -5,15 +5,29 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.ingest.IngestDocument;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.test.AbstractSerializingTestCase;
 
+import java.io.IOException;
 import java.util.HashMap;
 
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.hamcrest.Matchers.equalTo;
 
-public class WarningInferenceResultsTests extends AbstractWireSerializingTestCase<WarningInferenceResults> {
+public class WarningInferenceResultsTests extends AbstractSerializingTestCase<WarningInferenceResults> {
+
+    private static final ConstructingObjectParser<WarningInferenceResults, Void> PARSER =
+        new ConstructingObjectParser<>("inference_warning",
+            a -> new WarningInferenceResults((String) a[0])
+        );
+
+    static {
+        PARSER.declareString(constructorArg(), new ParseField(WarningInferenceResults.NAME));
+    }
 
     public static WarningInferenceResults createRandomResults() {
         return new WarningInferenceResults(randomAlphaOfLength(10));
@@ -36,4 +50,9 @@ public class WarningInferenceResultsTests extends AbstractWireSerializingTestCas
     protected Writeable.Reader<WarningInferenceResults> instanceReader() {
         return WarningInferenceResults::new;
     }
+
+    @Override
+    protected WarningInferenceResults doParseInstance(XContentParser parser) throws IOException {
+        return PARSER.apply(parser, null);
+    }
 }

+ 26 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigUpdateTests.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -18,6 +19,7 @@ import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests.randomClassificationConfig;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTestCase<ClassificationConfigUpdate> {
 
@@ -74,6 +76,30 @@ public class ClassificationConfigUpdateTests extends AbstractBWCSerializationTes
             ));
     }
 
+    public void testDuplicateFieldNamesThrow() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new ClassificationConfigUpdate(5, "foo", "foo", 1, PredictionFieldType.BOOLEAN));
+
+        assertEquals("Cannot apply inference config. More than one field is configured as [foo]", e.getMessage());
+    }
+
+    public void testDuplicateWithResultsField() {
+        ClassificationConfigUpdate update = randomClassificationConfigUpdate();
+        String newFieldName = update.getResultsField() + "_value";
+
+        InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
+
+        assertNotSame(updateWithField, update);
+        assertEquals(newFieldName, updateWithField.getResultsField());
+        // other fields are the same
+        assertThat(updateWithField, instanceOf(ClassificationConfigUpdate.class));
+        ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)updateWithField;
+        assertEquals(update.getTopClassesResultsField(), classUpdate.getTopClassesResultsField());
+        assertEquals(update.getNumTopClasses(), classUpdate.getNumTopClasses());
+        assertEquals(update.getPredictionFieldType(), classUpdate.getPredictionFieldType());
+        assertEquals(update.getNumTopFeatureImportanceValues(), classUpdate.getNumTopFeatureImportanceValues());
+    }
+
     @Override
     protected ClassificationConfigUpdate createTestInstance() {
         return randomClassificationConfigUpdate();

+ 22 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigUpdateTests.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
 
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -18,6 +19,7 @@ import java.util.Map;
 
 import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests.randomRegressionConfig;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 
 public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCase<RegressionConfigUpdate> {
 
@@ -60,6 +62,26 @@ public class RegressionConfigUpdateTests extends AbstractBWCSerializationTestCas
             ));
     }
 
+    public void testInvalidResultFieldNotUnique() {
+        ElasticsearchStatusException e =
+            expectThrows(ElasticsearchStatusException.class, () -> new RegressionConfigUpdate("warning", 0));
+        assertEquals("Cannot apply inference config. More than one field is configured as [warning]", e.getMessage());
+    }
+
+    public void testNewBuilder() {
+        RegressionConfigUpdate update = randomRegressionConfigUpdate();
+        String newFieldName = update.getResultsField() + "_value";
+
+        InferenceConfigUpdate updateWithField = update.newBuilder().setResultsField(newFieldName).build();
+
+        assertNotSame(updateWithField, update);
+        assertEquals(newFieldName, updateWithField.getResultsField());
+        // other fields are the same
+        assertThat(updateWithField, instanceOf(RegressionConfigUpdate.class));
+        assertEquals(update.getNumTopFeatureImportanceValues(),
+            ((RegressionConfigUpdate)updateWithField).getNumTopFeatureImportanceValues());
+    }
+
     @Override
     protected RegressionConfigUpdate createTestInstance() {
         return randomRegressionConfigUpdate();

+ 65 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ResultsFieldUpdateTests.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
+
+public class ResultsFieldUpdateTests extends AbstractSerializingTestCase<ResultsFieldUpdate> {
+
+    @Override
+    protected ResultsFieldUpdate doParseInstance(XContentParser parser) throws IOException {
+        return ResultsFieldUpdate.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<ResultsFieldUpdate> instanceReader() {
+        return ResultsFieldUpdate::new;
+    }
+
+    @Override
+    protected ResultsFieldUpdate createTestInstance() {
+        return new ResultsFieldUpdate(randomAlphaOfLength(4));
+    }
+
+    public void testIsSupported() {
+        ResultsFieldUpdate update = new ResultsFieldUpdate("foo");
+        assertTrue(update.isSupported(mock(InferenceConfig.class)));
+    }
+
+    public void testApply_OnlyTheResultsFieldIsChanged() {
+        if (randomBoolean()) {
+            ClassificationConfig config = ClassificationConfigTests.randomClassificationConfig();
+            String newResultsField = config.getResultsField() + "foobar";
+            ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
+            InferenceConfig applied = update.apply(config);
+
+            assertThat(applied, instanceOf(ClassificationConfig.class));
+            ClassificationConfig appliedConfig = (ClassificationConfig)applied;
+            assertEquals(newResultsField, appliedConfig.getResultsField());
+
+            assertEquals(appliedConfig, new ClassificationConfig.Builder(config).setResultsField(newResultsField).build());
+        } else {
+            RegressionConfig config = RegressionConfigTests.randomRegressionConfig();
+            String newResultsField = config.getResultsField() + "foobar";
+            ResultsFieldUpdate update = new ResultsFieldUpdate(newResultsField);
+            InferenceConfig applied = update.apply(config);
+
+            assertThat(applied, instanceOf(RegressionConfig.class));
+            RegressionConfig appliedConfig = (RegressionConfig)applied;
+            assertEquals(newResultsField, appliedConfig.getResultsField());
+
+            assertEquals(appliedConfig, new RegressionConfig.Builder(config).setResultsField(newResultsField).build());
+        }
+    }
+}

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
@@ -140,7 +141,7 @@ public class EnsembleInferenceModelTests extends ESTestCase {
         List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
         List<Double> scores   = Arrays.asList(0.230557435, 0.162032651);
         double eps = 0.000001;
-        List<ClassificationInferenceResults.TopClassEntry> probabilities =
+        List<TopClassEntry> probabilities =
             ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
                 .getTopClasses();
         for(int i = 0; i < expected.size(); i++) {

+ 2 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
@@ -182,7 +183,7 @@ public class TreeInferenceModelTests extends ESTestCase {
         List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
         List<String> expectedFields = Arrays.asList("dog", "cat");
         Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
-        List<ClassificationInferenceResults.TopClassEntry> probabilities =
+        List<TopClassEntry> probabilities =
             ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2), Collections.emptyMap()))
                 .getTopClasses();
         for(int i = 0; i < expectedProbs.size(); i++) {

+ 1 - 1
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -11,7 +11,7 @@ dependencies {
 // bring in machine learning rest test suite
 restResources {
   restApi {
-    includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest'
+    includeCore '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'count', 'ingest', 'bulk'
     includeXpack 'ml', 'cat'
   }
   restTests {

+ 1 - 1
x-pack/plugin/ml/qa/ml-with-security/roles.yml

@@ -7,7 +7,7 @@ minimal:
     # Give all users involved in these tests access to the indices where the data to
     # be analyzed is stored, because the ML roles alone do not provide access to
     # non-ML indices
-    - names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia' ]
+    - names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia', 'store' ]
       privileges:
         - create_index
         - indices:admin/refresh

+ 17 - 1
x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityInsufficientRoleIT.java

@@ -8,6 +8,8 @@ package org.elasticsearch.smoketest;
 import com.carrotsearch.randomizedtesting.annotations.Name;
 
 import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate;
+import org.elasticsearch.test.rest.yaml.section.DoSection;
+import org.elasticsearch.test.rest.yaml.section.ExecutableSection;
 
 import java.io.IOException;
 
@@ -16,8 +18,11 @@ import static org.hamcrest.Matchers.either;
 
 public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
 
+    private final ClientYamlTestCandidate testCandidate;
+
     public MlWithSecurityInsufficientRoleIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {
         super(testCandidate);
+        this.testCandidate = testCandidate;
     }
 
     @Override
@@ -26,7 +31,18 @@ public class MlWithSecurityInsufficientRoleIT extends MlWithSecurityIT {
             // Cannot use expectThrows here because blacklisted tests will throw an
             // InternalAssumptionViolatedException rather than an AssertionError
             super.test();
-            fail("should have failed because of missing role");
+
+            // We should have got here if and only if no ML endpoints were called
+            for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
+                if (section instanceof DoSection) {
+                    String apiName = ((DoSection) section).getApiCallSection().getApi();
+
+                    if (apiName.startsWith("ml.")) {
+                        fail("call to ml endpoint should have failed because of missing role");
+                    }
+                }
+            }
+
         } catch (AssertionError ae) {
             // Some tests assert on searches of wildcarded ML indices rather than on ML endpoints.  For these we expect no hits.
             if (ae.getMessage().contains("hits.total didn't match expected value")) {

+ 1 - 1
x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java

@@ -47,7 +47,7 @@ public class MlWithSecurityUserRoleIT extends MlWithSecurityIT {
                 if (section instanceof DoSection) {
                     String apiName = ((DoSection) section).getApiCallSection().getApi();
 
-                    if (((DoSection) section).getApiCallSection().getApi().startsWith("ml.") && isAllowed(apiName) == false) {
+                    if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
                         fail("should have failed because of missing role");
                     }
                 }

+ 21 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -37,6 +37,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.xcontent.ContextParser;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
@@ -57,11 +58,13 @@ import org.elasticsearch.plugins.CircuitBreakerPlugin;
 import org.elasticsearch.plugins.IngestPlugin;
 import org.elasticsearch.plugins.PersistentTaskPlugin;
 import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.repositories.RepositoriesService;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
 import org.elasticsearch.threadpool.ExecutorBuilder;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -233,6 +236,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationP
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
+import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder;
+import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
@@ -351,7 +356,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                                                        AnalysisPlugin,
                                                        CircuitBreakerPlugin,
                                                        IngestPlugin,
-                                                       PersistentTaskPlugin {
+                                                       PersistentTaskPlugin,
+                                                       SearchPlugin {
     public static final String NAME = "ml";
     public static final String BASE_PATH = "/_ml/";
     public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@@ -468,6 +474,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     private final SetOnce<MlMemoryTracker> memoryTracker = new SetOnce<>();
     private final SetOnce<ActionFilter> mlUpgradeModeActionFilter = new SetOnce<>();
     private final SetOnce<CircuitBreaker> inferenceModelBreaker = new SetOnce<>();
+    private final SetOnce<ModelLoadingService> modelLoadingService = new SetOnce<>();
 
     public MachineLearning(Settings settings, Path configPath) {
         this.settings = settings;
@@ -696,6 +703,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             settings,
             clusterService.getNodeName(),
             inferenceModelBreaker.get());
+        this.modelLoadingService.set(modelLoadingService);
 
         // Data frame analytics components
         AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@@ -971,6 +979,18 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         return Collections.singletonMap(MlClassicTokenizer.NAME, MlClassicTokenizerFactory::new);
     }
 
+    @Override
+    public List<PipelineAggregationSpec> getPipelineAggregations() {
+        PipelineAggregationSpec spec = new PipelineAggregationSpec(InferencePipelineAggregationBuilder.NAME,
+            in -> new InferencePipelineAggregationBuilder(in, modelLoadingService),
+            (ContextParser<String, ? extends PipelineAggregationBuilder>)
+                (parser, name) -> InferencePipelineAggregationBuilder.parse(modelLoadingService, name, parser
+                ));
+        spec.addResultReader(InternalInferenceAggregation::new);
+
+        return Collections.singletonList(spec);
+    }
+
     @Override
     public UnaryOperator<Map<String, IndexTemplateMetadata>> getIndexTemplateMetadataUpgrader() {
         return UnaryOperator.identity();

+ 214 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilder.java

@@ -0,0 +1,214 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
+import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+import java.util.concurrent.CountDownLatch;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
+
+    public static String NAME = "inference";
+
+    public static final ParseField MODEL_ID = new ParseField("model_id");
+    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
+
+    static String AGGREGATIONS_RESULTS_FIELD = "value";
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder,
+        Tuple<SetOnce<ModelLoadingService>, String>> PARSER = new ConstructingObjectParser<>(
+        NAME, false,
+        (args, context) -> new InferencePipelineAggregationBuilder(context.v2(), context.v1(), (Map<String, String>) args[0])
+    );
+
+    static {
+        PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
+        PARSER.declareString(InferencePipelineAggregationBuilder::setModelId, MODEL_ID);
+        PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
+            (p, c, n) -> p.namedObject(InferenceConfigUpdate.class, n, c), INFERENCE_CONFIG);
+    }
+
+    private final Map<String, String> bucketPathMap;
+    private String modelId;
+    private InferenceConfigUpdate inferenceConfig;
+    private final SetOnce<ModelLoadingService> modelLoadingService;
+
+    public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService,
+                                                            String pipelineAggregatorName,
+                                                            XContentParser parser) {
+        Tuple<SetOnce<ModelLoadingService>, String> context = new Tuple<>(modelLoadingService, pipelineAggregatorName);
+        return PARSER.apply(parser, context);
+    }
+
+    public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService,
+                                               Map<String, String> bucketsPath) {
+        super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
+        this.modelLoadingService = modelLoadingService;
+        this.bucketPathMap = bucketsPath;
+    }
+
+    public InferencePipelineAggregationBuilder(StreamInput in, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
+        super(in, NAME);
+        modelId = in.readString();
+        bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
+        inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
+        this.modelLoadingService = modelLoadingService;
+    }
+
+    void setModelId(String modelId) {
+        this.modelId = modelId;
+    }
+
+    void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
+        this.inferenceConfig = inferenceConfig;
+    }
+
+    @Override
+    protected void validate(ValidationContext context) {
+        context.validateHasParent(NAME, name);
+        if (modelId == null) {
+            context.addValidationError("[model_id] must be set");
+        }
+
+        if (inferenceConfig != null) {
+            // error if the results field is set and not equal to the only acceptable value
+            String resultsField = inferenceConfig.getResultsField();
+            if (Strings.isNullOrEmpty(resultsField) == false && AGGREGATIONS_RESULTS_FIELD.equals(resultsField) == false) {
+                context.addValidationError("setting option [" + ClassificationConfig.RESULTS_FIELD.getPreferredName()
+                    + "] to [" + resultsField + "] is not valid for inference aggregations");
+            }
+
+            if (inferenceConfig instanceof ClassificationConfigUpdate) {
+                ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate)inferenceConfig;
+
+                // error if the top classes result field is set and not equal to the only acceptable value
+                String topClassesField = classUpdate.getTopClassesResultsField();
+                if (Strings.isNullOrEmpty(topClassesField) == false &&
+                    ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD.equals(topClassesField) == false) {
+                    context.addValidationError("setting option [" + ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD
+                        + "] to [" + topClassesField + "] is not valid for inference aggregations");
+                }
+            }
+        }
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(modelId);
+        out.writeMap(bucketPathMap, StreamOutput::writeString, StreamOutput::writeString);
+        out.writeOptionalNamedWriteable(inferenceConfig);
+    }
+
+    @Override
+    protected PipelineAggregator createInternal(Map<String, Object> metaData) {
+
+        SetOnce<LocalModel> model = new SetOnce<>();
+        SetOnce<Exception> error = new SetOnce<>();
+        CountDownLatch latch = new CountDownLatch(1);
+        ActionListener<LocalModel> listener = new LatchedActionListener<>(
+            ActionListener.wrap(model::set, error::set), latch);
+
+        modelLoadingService.get().getModelForSearch(modelId, listener);
+        try {
+            // TODO Avoid the blocking wait
+            latch.await();
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException("Inference aggregation interrupted loading model", e);
+        }
+
+        Exception e = error.get();
+        if (e != null) {
+            if (e instanceof RuntimeException) {
+                throw (RuntimeException)e;
+            } else {
+                throw new RuntimeException(error.get());
+            }
+        }
+
+        InferenceConfigUpdate update = adaptForAggregation(inferenceConfig);
+
+        return new InferencePipelineAggregator(name, bucketPathMap, metaData, update, model.get());
+    }
+
+    static InferenceConfigUpdate adaptForAggregation(InferenceConfigUpdate originalUpdate) {
+        InferenceConfigUpdate updated;
+        if (originalUpdate == null) {
+            updated = new ResultsFieldUpdate(AGGREGATIONS_RESULTS_FIELD);
+        } else {
+            // Create an update that changes the default results field.
+            // This isn't necessary for top classes as the default is the same one used here
+            updated = originalUpdate.newBuilder().setResultsField(AGGREGATIONS_RESULTS_FIELD).build();
+        }
+
+        return updated;
+    }
+
+    @Override
+    protected boolean overrideBucketsPath() {
+        return true;
+    }
+
+    @Override
+    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID.getPreferredName(), modelId);
+        builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
+        if (inferenceConfig != null) {
+            builder.startObject(INFERENCE_CONFIG.getPreferredName());
+            builder.field(inferenceConfig.getName(), inferenceConfig);
+            builder.endObject();
+        }
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null || getClass() != obj.getClass()) return false;
+        if (super.equals(obj) == false) return false;
+
+        InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
+        return Objects.equals(bucketPathMap, other.bucketPathMap)
+            && Objects.equals(modelId, other.modelId)
+            && Objects.equals(inferenceConfig, other.inferenceConfig);
+    }
+}

+ 133 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregator.java

@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.elasticsearch.search.aggregations.AggregationExecutionException;
+import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InternalAggregations;
+import org.elasticsearch.search.aggregations.InternalMultiBucketAggregation;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
+import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
+import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
+import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.aggregations.support.AggregationPath;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+
+public class InferencePipelineAggregator extends PipelineAggregator {
+
+    private final Map<String, String> bucketPathMap;
+    private final InferenceConfigUpdate configUpdate;
+    private final LocalModel model;
+
+    public InferencePipelineAggregator(String name, Map<String,
+                                       String> bucketPathMap,
+                                       Map<String, Object> metaData,
+                                       InferenceConfigUpdate configUpdate,
+                                       LocalModel model) {
+        super(name, bucketPathMap.values().toArray(new String[] {}), metaData);
+        this.bucketPathMap = bucketPathMap;
+        this.configUpdate = configUpdate;
+        this.model = model;
+    }
+
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    @Override
+    public InternalAggregation reduce(InternalAggregation aggregation, InternalAggregation.ReduceContext reduceContext) {
+
+        InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
+            (InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
+        List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
+
+        List<InternalMultiBucketAggregation.InternalBucket> newBuckets = new ArrayList<>();
+        for (InternalMultiBucketAggregation.InternalBucket bucket : buckets) {
+            Map<String, Object> inputFields = new HashMap<>();
+
+            if (bucket.getDocCount() == 0) {
+                // ignore this empty bucket unless the doc count is used
+                if (bucketPathMap.containsKey("_count") == false) {
+                    newBuckets.add(bucket);
+                    continue;
+                }
+            }
+
+            for (Map.Entry<String, String> entry : bucketPathMap.entrySet()) {
+                String aggName = entry.getKey();
+                String bucketPath = entry.getValue();
+                Object propertyValue = resolveBucketValue(originalAgg, bucket, bucketPath);
+
+                if (propertyValue instanceof Number) {
+                    double doubleVal = ((Number) propertyValue).doubleValue();
+                    // NaN or infinite values indicate a missing value or a
+                    // valid result of an invalid calculation. Either way only
+                    // a valid number will do
+                    if (Double.isFinite(doubleVal)) {
+                        inputFields.put(aggName, doubleVal);
+                    }
+                } else if (propertyValue instanceof InternalNumericMetricsAggregation.SingleValue) {
+                    double doubleVal = ((InternalNumericMetricsAggregation.SingleValue) propertyValue).value();
+                    if (Double.isFinite(doubleVal)) {
+                        inputFields.put(aggName, doubleVal);
+                    }
+                } else if (propertyValue instanceof StringTerms.Bucket) {
+                    StringTerms.Bucket b = (StringTerms.Bucket) propertyValue;
+                    inputFields.put(aggName, b.getKeyAsString());
+                } else if (propertyValue instanceof String) {
+                    inputFields.put(aggName, propertyValue);
+                } else if (propertyValue != null) {
+                    // Doubles, String terms or null are valid, any other type is an error
+                    throw invalidAggTypeError(bucketPath, propertyValue);
+                }
+            }
+
+
+            InferenceResults inference;
+            try {
+                inference = model.infer(inputFields, configUpdate);
+            } catch (Exception e) {
+                inference = new WarningInferenceResults(e.getMessage());
+            }
+
+            final List<InternalAggregation> aggs = bucket.getAggregations().asList().stream().map(
+                (p) -> (InternalAggregation) p).collect(Collectors.toList());
+
+            InternalInferenceAggregation aggResult = new InternalInferenceAggregation(name(), metadata(), inference);
+            aggs.add(aggResult);
+            InternalMultiBucketAggregation.InternalBucket newBucket = originalAgg.createBucket(InternalAggregations.from(aggs), bucket);
+            newBuckets.add(newBucket);
+        }
+
+        return originalAgg.create(newBuckets);
+    }
+
+    public static Object resolveBucketValue(MultiBucketsAggregation agg,
+                                            InternalMultiBucketAggregation.InternalBucket bucket,
+                                            String aggPath) {
+
+        List<String> aggPathsList = AggregationPath.parse(aggPath).getPathElementsAsStringList();
+        return bucket.getProperty(agg.getName(), aggPathsList);
+    }
+
+    private static AggregationExecutionException invalidAggTypeError(String aggPath, Object propertyValue) {
+
+        String msg = AbstractPipelineAggregationBuilder.BUCKETS_PATH_FIELD.getPreferredName() +
+            " must reference either a number value, a single value numeric metric aggregation or a string: got [" +
+            propertyValue + "] of type [" + propertyValue.getClass().getSimpleName() + "] " +
+            "] at aggregation [" + aggPath + "]";
+        return new AggregationExecutionException(msg);
+    }
+}

+ 98 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregation.java

@@ -0,0 +1,98 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.search.aggregations.InternalAggregation;
+import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+public class InternalInferenceAggregation extends InternalAggregation {
+
+    private final InferenceResults inferenceResult;
+
+    protected InternalInferenceAggregation(String name, Map<String, Object> metadata,
+                                           InferenceResults inferenceResult) {
+        super(name, metadata);
+        this.inferenceResult = inferenceResult;
+    }
+
+    public InternalInferenceAggregation(StreamInput in) throws IOException {
+        super(in);
+        inferenceResult = in.readNamedWriteable(InferenceResults.class);
+    }
+
+    public InferenceResults getInferenceResult() {
+        return inferenceResult;
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeNamedWriteable(inferenceResult);
+    }
+
+    @Override
+    public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
+        throw new UnsupportedOperationException("Reducing an inference aggregation is not supported");
+    }
+
+
+    @Override
+    public Object getProperty(List<String> path) {
+        Object propertyValue;
+
+        if (path.isEmpty()) {
+            propertyValue = this;
+        } else if (path.size() == 1) {
+            if (CommonFields.VALUE.getPreferredName().equals(path.get(0))) {
+                propertyValue = inferenceResult.predictedValue();
+            } else {
+                throw invalidPathException(path);
+            }
+        } else {
+            throw invalidPathException(path);
+        }
+
+        return propertyValue;
+    }
+
+    private InvalidAggregationPathException invalidPathException(List<String> path) {
+        return new InvalidAggregationPathException("unknown property " +  path + " for " +
+            InferencePipelineAggregationBuilder.NAME + " aggregation [" + getName() + "]");
+    }
+
+    @Override
+    public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        return inferenceResult.toXContent(builder, params);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return "inference";
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), inferenceResult);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null || getClass() != obj.getClass()) return false;
+        if (super.equals(obj) == false) return false;
+        InternalInferenceAggregation other = (InternalInferenceAggregation) obj;
+        return Objects.equals(inferenceResult, other.inferenceResult);
+    }
+}

+ 131 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java

@@ -0,0 +1,131 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.ParsedAggregation;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults.FEATURE_IMPORTANCE;
+
+
+/**
+ * There isn't enough information in toXContent representation of the
+ * {@link org.elasticsearch.xpack.core.ml.inference.results.InferenceResults}
+ * objects to fully reconstruct them. In particular, depending on which
+ * fields are written (result value, feature importance) it is not possible to
+ * distinguish between a Regression result and a Classification result.
+ *
+ * This class parses the union all possible fields that may be written by
+ * InferenceResults.
+ *
+ * The warning field is mutually exclusive with all the other fields.
+ */
+public class ParsedInference extends ParsedAggregation {
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
+        new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
+            args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
+                (List<TopClassEntry>) args[2], (String) args[3]));
+
+    static {
+        PARSER.declareField(optionalConstructorArg(), (p, n) -> {
+            Object o;
+            XContentParser.Token token = p.currentToken();
+            if (token == XContentParser.Token.VALUE_STRING) {
+                o = p.text();
+            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
+                o = p.booleanValue();
+            } else if (token == XContentParser.Token.VALUE_NUMBER) {
+                o = p.doubleValue();
+            } else {
+                throw new XContentParseException(p.getTokenLocation(),
+                    "[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
+                        + "value [" + token + "] is not a string, boolean or number");
+            }
+            return o;
+        }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
+        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p),
+            new ParseField(SingleValueInferenceResults.FEATURE_IMPORTANCE));
+        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p),
+            new ParseField(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD));
+        PARSER.declareString(optionalConstructorArg(), new ParseField(WarningInferenceResults.NAME));
+        declareAggregationFields(PARSER);
+    }
+
+    public static ParsedInference fromXContent(XContentParser parser, final String name) {
+        ParsedInference parsed = PARSER.apply(parser, null);
+        parsed.setName(name);
+        return parsed;
+    }
+
+    private final Object value;
+    private final List<FeatureImportance> featureImportance;
+    private final List<TopClassEntry> topClasses;
+    private final String warning;
+
+    ParsedInference(Object value,
+                    List<FeatureImportance> featureImportance,
+                    List<TopClassEntry> topClasses,
+                    String warning) {
+        this.value = value;
+        this.warning = warning;
+        this.featureImportance = featureImportance;
+        this.topClasses = topClasses;
+    }
+
+    public Object getValue() {
+        return value;
+    }
+
+    public List<FeatureImportance> getFeatureImportance() {
+        return featureImportance;
+    }
+
+    public List<TopClassEntry> getTopClasses() {
+        return topClasses;
+    }
+
+    public String getWarning() {
+        return warning;
+    }
+
+    @Override
+    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        if (warning != null) {
+            builder.field(WarningInferenceResults.WARNING.getPreferredName(), warning);
+        } else {
+            builder.field(CommonFields.VALUE.getPreferredName(), value);
+            if (topClasses != null && topClasses.size() > 0) {
+                builder.field(ClassificationConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD, topClasses);
+            }
+            if (featureImportance != null && featureImportance.size() > 0) {
+                builder.field(FEATURE_IMPORTANCE, featureImportance);
+            }
+        }
+        return builder;
+    }
+
+    @Override
+    public String getType() {
+        return InferencePipelineAggregationBuilder.NAME;
+    }
+}

+ 0 - 29
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -28,7 +28,6 @@ import org.elasticsearch.ingest.PipelineConfiguration;
 import org.elasticsearch.ingest.Processor;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
@@ -42,10 +41,8 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.Arrays;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
@@ -172,10 +169,6 @@ public class InferenceProcessor extends AbstractProcessor {
         private static final int MAX_INFERENCE_PROCESSOR_SEARCH_RECURSIONS = 10;
         private static final Logger logger = LogManager.getLogger(Factory.class);
 
-        private static final Set<String> RESERVED_ML_FIELD_NAMES = new HashSet<>(Arrays.asList(
-            WarningInferenceResults.WARNING.getPreferredName(),
-            MODEL_ID));
-
         private final Client client;
         private final InferenceAuditor auditor;
         private volatile int currentInferenceProcessors;
@@ -333,12 +326,10 @@ public class InferenceProcessor extends AbstractProcessor {
             if (inferenceConfig.containsKey(ClassificationConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
                 ClassificationConfigUpdate config = ClassificationConfigUpdate.fromMap(valueMap);
-                checkFieldUniqueness(config.getResultsField(), config.getTopClassesResultsField());
                 return config;
             } else if (inferenceConfig.containsKey(RegressionConfig.NAME.getPreferredName())) {
                 checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                 RegressionConfigUpdate config = RegressionConfigUpdate.fromMap(valueMap);
-                checkFieldUniqueness(config.getResultsField());
                 return config;
             } else {
                 throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
@@ -347,26 +338,6 @@ public class InferenceProcessor extends AbstractProcessor {
             }
         }
 
-        private static void checkFieldUniqueness(String... fieldNames) {
-            Set<String> duplicatedFieldNames = new HashSet<>();
-            Set<String> currentFieldNames = new HashSet<>(RESERVED_ML_FIELD_NAMES);
-            for(String fieldName : fieldNames) {
-                if (fieldName == null) {
-                    continue;
-                }
-                if (currentFieldNames.contains(fieldName)) {
-                    duplicatedFieldNames.add(fieldName);
-                } else {
-                    currentFieldNames.add(fieldName);
-                }
-            }
-            if (duplicatedFieldNames.isEmpty() == false) {
-                throw ExceptionsHelper.badRequestException("Cannot create processor as configured." +
-                        " More than one field is configured as {}",
-                    duplicatedFieldNames);
-            }
-        }
-
         void checkSupportedVersion(InferenceConfig config) {
             if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
                 throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,

+ 18 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.LongAdder;
 
 import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_WARNING_ALL_FIELDS_MISSING;
@@ -78,7 +79,7 @@ public class LocalModel {
             persistenceQuotient = 10_000;
         }
     }
-    
+
     public void infer(Map<String, Object> fields, InferenceConfigUpdate update, ActionListener<InferenceResults> listener) {
         if (update.isSupported(this.inferenceConfig) == false) {
             listener.onFailure(ExceptionsHelper.badRequestException(
@@ -116,6 +117,22 @@ public class LocalModel {
         }
     }
 
+    public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) throws Exception {
+        AtomicReference<InferenceResults> result = new AtomicReference<>();
+        AtomicReference<Exception> exception = new AtomicReference<>();
+        ActionListener<InferenceResults> listener = ActionListener.wrap(
+            result::set,
+            exception::set
+        );
+
+        infer(fields, update, listener);
+        if (exception.get() != null) {
+            throw exception.get();
+        }
+
+        return result.get();
+    }
+
     /**
      * Used for translating field names in according to the passed `fieldMappings` parameter.
      *

+ 130 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InferencePipelineAggregationBuilderTests.java

@@ -0,0 +1,130 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class InferencePipelineAggregationBuilderTests extends BasePipelineAggregationTestCase<InferencePipelineAggregationBuilder> {
+
+    private static final String NAME = "inf-agg";
+
+    @Override
+    protected List<SearchPlugin> plugins() {
+        return Collections.singletonList(new MachineLearning(Settings.EMPTY, null));
+    }
+
+    @Override
+    protected List<NamedXContentRegistry.Entry> additionalNamedContents() {
+        return new MlInferenceNamedXContentProvider().getNamedXContentParsers();
+    }
+
+    @Override
+    protected List<NamedWriteableRegistry.Entry> additionalNamedWriteables() {
+        return new MlInferenceNamedXContentProvider().getNamedWriteables();
+    }
+
+    @Override
+    protected InferencePipelineAggregationBuilder createTestAggregatorFactory() {
+        Map<String, String> bucketPaths = Stream.generate(() -> randomAlphaOfLength(8))
+            .limit(randomIntBetween(1, 4))
+            .collect(Collectors.toMap(Function.identity(), (t) -> randomAlphaOfLength(5)));
+
+        InferencePipelineAggregationBuilder builder =
+            new InferencePipelineAggregationBuilder(NAME, new SetOnce<>(mock(ModelLoadingService.class)), bucketPaths);
+        builder.setModelId(randomAlphaOfLength(6));
+
+        if (randomBoolean()) {
+            InferenceConfigUpdate config;
+            if (randomBoolean()) {
+                config = ClassificationConfigUpdateTests.randomClassificationConfigUpdate();
+            } else {
+                config = RegressionConfigUpdateTests.randomRegressionConfigUpdate();
+            }
+            builder.setInferenceConfig(config);
+        }
+
+        return builder;
+    }
+
+    public void testAdaptForAggregation_givenNull() {
+        InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(null);
+        assertThat(update, is(instanceOf(ResultsFieldUpdate.class)));
+        assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
+    }
+
+    public void testAdaptForAggregation() {
+        RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate(null, 20);
+        InferenceConfigUpdate update = InferencePipelineAggregationBuilder.adaptForAggregation(regressionConfigUpdate);
+        assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
+
+        ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, null, null, null);
+        update = InferencePipelineAggregationBuilder.adaptForAggregation(configUpdate);
+        assertEquals(InferencePipelineAggregationBuilder.AGGREGATIONS_RESULTS_FIELD, update.getResultsField());
+    }
+
+    public void testValidate() {
+        InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
+        PipelineAggregationBuilder.ValidationContext validationContext =
+            PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
+
+        aggregationBuilder.setModelId(null);
+        aggregationBuilder.validate(validationContext);
+        List<String> errors = validationContext.getValidationException().validationErrors();
+        assertEquals("[model_id] must be set", errors.get(0));
+    }
+
+    public void testValidate_invalidResultsField() {
+        InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
+        PipelineAggregationBuilder.ValidationContext validationContext =
+            PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
+
+        RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null);
+        aggregationBuilder.setInferenceConfig(regressionConfigUpdate);
+        aggregationBuilder.validate(validationContext);
+        List<String> errors = validationContext.getValidationException().validationErrors();
+        assertEquals("setting option [results_field] to [foo] is not valid for inference aggregations", errors.get(0));
+    }
+
+    public void testValidate_invalidTopClassesField() {
+        InferencePipelineAggregationBuilder aggregationBuilder = createTestAggregatorFactory();
+        PipelineAggregationBuilder.ValidationContext validationContext =
+            PipelineAggregationBuilder.ValidationContext.forInsideTree(mock(AggregationBuilder.class), null);
+
+        ClassificationConfigUpdate configUpdate = new ClassificationConfigUpdate(1, null, "some_other_field", null, null);
+        aggregationBuilder.setInferenceConfig(configUpdate);
+        aggregationBuilder.validate(validationContext);
+        List<String> errors = validationContext.getValidationException().validationErrors();
+        assertEquals("setting option [top_classes] to [some_other_field] is not valid for inference aggregations", errors.get(0));
+    }
+}

+ 218 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/InternalInferenceAggregationTests.java

@@ -0,0 +1,218 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.inference.aggs;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.Aggregation;
+import org.elasticsearch.search.aggregations.InvalidAggregationPathException;
+import org.elasticsearch.search.aggregations.ParsedAggregation;
+import org.elasticsearch.test.InternalAggregationTestCase;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
+import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+
+import static org.hamcrest.Matchers.sameInstance;
+
+public class InternalInferenceAggregationTests extends InternalAggregationTestCase<InternalInferenceAggregation> {
+
+    @Override
+    protected SearchPlugin registerPlugin() {
+        return new MachineLearning(Settings.EMPTY, null);
+    }
+
+    @Override
+    protected List<NamedXContentRegistry.Entry> getNamedXContents() {
+        List<NamedXContentRegistry.Entry> entries = new ArrayList<>(super.getNamedXContents());
+        entries.add(new NamedXContentRegistry.Entry(Aggregation.class,
+            new ParseField(InferencePipelineAggregationBuilder.NAME), (p, c) -> ParsedInference.fromXContent(p, (String)c)));
+
+        return entries;
+    }
+
+    @Override
+    protected Predicate<String> excludePathsFromXContentInsertion() {
+        return p -> p.contains("top_classes") || p.contains("feature_importance");
+    }
+
+    @Override
+    protected InternalInferenceAggregation createTestInstance(String name, Map<String, Object> metadata) {
+        InferenceResults result;
+
+        if (randomBoolean()) {
+            // build a random result with the result field set to `value`
+            ClassificationInferenceResults randomResults = ClassificationInferenceResultsTests.createRandomResults();
+            result = new ClassificationInferenceResults(
+                randomResults.value(),
+                randomResults.getClassificationLabel(),
+                randomResults.getTopClasses(),
+                randomResults.getFeatureImportance(),
+                new ClassificationConfig(null, "value", null, null, randomResults.getPredictionFieldType())
+            );
+        } else if (randomBoolean()) {
+            // build a random result with the result field set to `value`
+            RegressionInferenceResults randomResults = RegressionInferenceResultsTests.createRandomResults();
+            result = new RegressionInferenceResults(
+                randomResults.value(),
+                "value",
+                randomResults.getFeatureImportance());
+        } else {
+            result = new WarningInferenceResults("this is a warning");
+        }
+
+        return new InternalInferenceAggregation(name, metadata, result);
+    }
+
+    @Override
+    public void testReduceRandom() {
+        expectThrows(UnsupportedOperationException.class, () -> createTestInstance("name", null).reduce(null, null));
+    }
+
+    @Override
+    protected void assertReduced(InternalInferenceAggregation reduced, List<InternalInferenceAggregation> inputs) {
+        // no test since reduce operation is unsupported
+    }
+
+    @Override
+    protected void assertFromXContent(InternalInferenceAggregation agg, ParsedAggregation parsedAggregation) {
+        ParsedInference parsed = ((ParsedInference) parsedAggregation);
+
+        InferenceResults result = agg.getInferenceResult();
+        if (result instanceof WarningInferenceResults) {
+            WarningInferenceResults warning = (WarningInferenceResults) result;
+            assertEquals(warning.getWarning(), parsed.getWarning());
+        } else if (result instanceof RegressionInferenceResults) {
+            RegressionInferenceResults regression = (RegressionInferenceResults) result;
+            assertEquals(regression.value(), parsed.getValue());
+            List<FeatureImportance> featureImportance = regression.getFeatureImportance();
+            if (featureImportance.isEmpty()) {
+                featureImportance = null;
+            }
+            assertEquals(featureImportance, parsed.getFeatureImportance());
+        } else if (result instanceof ClassificationInferenceResults) {
+            ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
+            assertEquals(classification.predictedValue(), parsed.getValue());
+
+            List<FeatureImportance> featureImportance = classification.getFeatureImportance();
+            if (featureImportance.isEmpty()) {
+                featureImportance = null;
+            }
+            assertEquals(featureImportance, parsed.getFeatureImportance());
+
+            List<TopClassEntry> topClasses = classification.getTopClasses();
+            if (topClasses.isEmpty()) {
+                topClasses = null;
+            }
+            assertEquals(topClasses, parsed.getTopClasses());
+        }
+    }
+
+    public void testGetProperty_givenEmptyPath() {
+        InternalInferenceAggregation internalAgg = createTestInstance();
+        assertThat(internalAgg, sameInstance(internalAgg.getProperty(Collections.emptyList())));
+    }
+
+    public void testGetProperty_givenTooLongPath() {
+        InternalInferenceAggregation internalAgg = createTestInstance();
+        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
+            () -> internalAgg.getProperty(Arrays.asList("one", "two")));
+
+        String message = "unknown property [one, two] for inference aggregation [" + internalAgg.getName() + "]";
+        assertEquals(message, e.getMessage());
+    }
+
+    public void testGetProperty_givenWrongPath() {
+        InternalInferenceAggregation internalAgg = createTestInstance();
+        InvalidAggregationPathException e = expectThrows(InvalidAggregationPathException.class,
+            () -> internalAgg.getProperty(Collections.singletonList("bar")));
+
+        String message = "unknown property [bar] for inference aggregation [" + internalAgg.getName() + "]";
+        assertEquals(message, e.getMessage());
+    }
+
+    public void testGetProperty_value() {
+        {
+            ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            assertEquals(results.predictedValue(), internalAgg.getProperty(Collections.singletonList("value")));
+        }
+
+        {
+            RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            assertEquals(results.value(), internalAgg.getProperty(Collections.singletonList("value")));
+        }
+
+        {
+            WarningInferenceResults results = new WarningInferenceResults("a warning from history");
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            assertNull(internalAgg.getProperty(Collections.singletonList("value")));
+        }
+    }
+
+    public void testGetProperty_featureImportance() {
+        {
+            ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
+        }
+
+        {
+            RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
+        }
+
+        {
+            WarningInferenceResults results = new WarningInferenceResults("a warning from history");
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("feature_importance")));
+        }
+    }
+
+    public void testGetProperty_topClasses() {
+        {
+            ClassificationInferenceResults results = ClassificationInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("top_classes")));
+        }
+
+        {
+            RegressionInferenceResults results = RegressionInferenceResultsTests.createRandomResults();
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("top_classes")));
+        }
+
+        {
+            WarningInferenceResults results = new WarningInferenceResults("a warning from history");
+            InternalInferenceAggregation internalAgg = new InternalInferenceAggregation("foo", Collections.emptyMap(), results);
+            expectThrows(InvalidAggregationPathException.class,
+                () -> internalAgg.getProperty(Collections.singletonList("top_classes")));
+        }
+    }
+}

+ 0 - 22
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java

@@ -272,28 +272,6 @@ public class InferenceProcessorFactoryTests extends ESTestCase {
         }
     }
 
-    public void testCreateProcessorWithDuplicateFields() {
-        InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
-            clusterService,
-            Settings.EMPTY);
-
-        Map<String, Object> regression = new HashMap<>() {{
-            put(InferenceProcessor.FIELD_MAP, Collections.emptyMap());
-            put(InferenceProcessor.MODEL_ID, "my_model");
-            put(InferenceProcessor.TARGET_FIELD, "ml");
-            put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(),
-                Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning")));
-        }};
-
-        try {
-            processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression);
-            fail("should not have succeeded creating with duplicate fields");
-        } catch (Exception ex) {
-            assertThat(ex.getMessage(), equalTo("Cannot create processor as configured. " +
-                "More than one field is configured as [warning]"));
-        }
-    }
-
     private static ClusterState buildClusterState(Metadata metadata) {
        return ClusterState.builder(new ClusterName("_name")).metadata(metadata).build();
     }

+ 12 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
@@ -92,9 +93,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
-        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
-        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
+        List<TopClassEntry> classes = new ArrayList<>(2);
+        classes.add(new TopClassEntry("foo", 0.6, 0.6));
+        classes.add(new TopClassEntry("bar", 0.4, 0.4));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@@ -102,7 +103,7 @@ public class InferenceProcessorTests extends ESTestCase {
         inferenceProcessor.mutateDocument(response, document);
 
         assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.top_classes", List.class),
-            contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
+            contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
         assertThat(document.getFieldValue("ml.my_processor.predicted_value", String.class), equalTo("foo"));
     }
@@ -122,9 +123,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
-        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
-        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
+        List<TopClassEntry> classes = new ArrayList<>(2);
+        classes.add(new TopClassEntry("foo", 0.6, 0.6));
+        classes.add(new TopClassEntry("bar", 0.4, 0.4));
 
         List<FeatureImportance> featureInfluence = new ArrayList<>();
         featureInfluence.add(FeatureImportance.forRegression("feature_1", 1.13));
@@ -163,9 +164,9 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
-        classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6, 0.6));
-        classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4, 0.4));
+        List<TopClassEntry> classes = new ArrayList<>(2);
+        classes.add(new TopClassEntry("foo", 0.6, 0.6));
+        classes.add(new TopClassEntry("bar", 0.4, 0.4));
 
         InternalInferModelAction.Response response = new InternalInferModelAction.Response(
             Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes, classificationConfig)),
@@ -173,7 +174,7 @@ public class InferenceProcessorTests extends ESTestCase {
         inferenceProcessor.mutateDocument(response, document);
 
         assertThat((List<Map<?,?>>)document.getFieldValue("ml.my_processor.tops", List.class),
-            contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
+            contains(classes.stream().map(TopClassEntry::asValueMap).toArray(Map[]::new)));
         assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("classification_model"));
         assertThat(document.getFieldValue("ml.my_processor.result", String.class), equalTo("foo"));
     }

+ 266 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/pipeline_inference.yml

@@ -0,0 +1,266 @@
+setup:
+  - skip:
+      features: headers
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: a-complex-regression-model
+        body: >
+          {
+            "description": "super complex model for tests",
+            "input": {"field_names": ["avg_cost", "item"]},
+            "inference_config": {
+              "regression": {
+                "results_field": "regression-value",
+                "num_top_feature_importance_values": 2
+              }
+            },
+            "definition": {
+              "preprocessors" : [{
+                "one_hot_encoding": {
+                  "field": "product_type",
+                  "hot_map": {
+                    "TV": "type_tv",
+                    "VCR": "type_vcr",
+                    "Laptop": "type_laptop"
+                  }
+                }
+              }],
+              "trained_model": {
+                "ensemble": {
+                  "feature_names": [],
+                  "target_type": "regression",
+                  "trained_models": [
+                  {
+                    "tree": {
+                      "feature_names": [
+                        "avg_cost", "type_tv", "type_vcr", "type_laptop"
+                      ],
+                      "tree_structure": [
+                      {
+                        "node_index": 0,
+                        "split_feature": 0,
+                        "split_gain": 12,
+                        "threshold": 38,
+                        "decision_type": "lte",
+                        "default_left": true,
+                        "left_child": 1,
+                        "right_child": 2
+                      },
+                      {
+                        "node_index": 1,
+                        "leaf_value": 5.0
+                      },
+                      {
+                        "node_index": 2,
+                        "leaf_value": 2.0
+                      }
+                      ],
+                      "target_type": "regression"
+                    }
+                  }
+                  ]
+                }
+              }
+            }
+          }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      indices.create:
+        index: store
+        body:
+          mappings:
+            properties:
+              product:
+                type: keyword
+              cost:
+                type: integer
+              time:
+                type: date
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+        Content-Type: application/json
+      bulk:
+        index: store
+        refresh: true
+        body: |
+          { "index": {} }
+          { "product": "TV", "cost": 300, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "TV", "cost": 400, "time": 1587501233000}
+          { "index": {} }
+          { "product": "VCR", "cost": 150, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "VCR", "cost": 180, "time": 1587501233000 }
+          { "index": {} }
+          { "product": "Laptop", "cost": 15000, "time": 1587501233000 }
+
+---
+"Test pipeline regression simple":
+
+  - do:
+      search:
+        index: store
+        body: |
+          {
+            "size": 0,
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "avg_cost_agg": {
+                    "avg": {
+                      "field": "cost"
+                    }
+                  },
+                  "regression_agg": {
+                    "inference": {
+                      "model_id": "a-complex-regression-model",
+                      "inference_config": {
+                        "regression": {
+                          "results_field": "value"
+                        }
+                      },
+                      "buckets_path": {
+                        "avg_cost": "avg_cost_agg"
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+  - match: { aggregations.good.buckets.0.regression_agg.value: 2.0 }
+---
+"Test pipeline agg referencing a single bucket":
+
+  - do:
+      search:
+        index: store
+        body: |
+            {
+              "size": 0,
+              "query": {
+                "match_all": {}
+              },
+              "aggs": {
+                "date_histo": {
+                  "date_histogram": {
+                    "field": "time",
+                    "fixed_interval": "1d"
+                  },
+                  "aggs": {
+                    "good": {
+                      "terms": {
+                        "field": "product",
+                        "size": 10
+                      },
+                      "aggs": {
+                        "avg_cost_agg": {
+                          "avg": {
+                            "field": "cost"
+                          }
+                        }
+                      }
+                    },
+                    "regression_agg": {
+                      "inference": {
+                        "model_id": "a-complex-regression-model",
+                        "buckets_path": {
+                          "avg_cost": "good['TV']>avg_cost_agg",
+                          "product_type": "good['TV']"
+                        }
+                      }
+                    }
+                  }
+                }
+              }
+            }
+  - match: { aggregations.date_histo.buckets.0.regression_agg.value: 2.0 }
+
+---
+"Test all fields missing warning":
+
+  - do:
+      search:
+        index: store
+        body: |
+          {
+            "size": 0,
+            "query": { "match_all" : { } },
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "avg_cost_agg": {
+                    "avg": {
+                      "field": "cost"
+                    }
+                  },
+                  "regression_agg" : {
+                    "inference": {
+                      "model_id": "a-complex-regression-model",
+                      "buckets_path": {
+                        "cost" : "avg_cost_agg"
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+  - match: { aggregations.good.buckets.0.regression_agg.warning: "Model [a-complex-regression-model] could not be inferred as all fields were missing" }
+
+---
+"Test setting results field is invalid":
+
+  - do:
+      catch: /action_request_validation_exception/
+      search:
+        index: store
+        body: |
+          {
+            "size": 0,
+            "query": { "match_all" : { } },
+            "aggs": {
+              "good": {
+                "terms": {
+                  "field": "product",
+                  "size": 10
+                },
+                "aggs": {
+                  "avg_cost_agg": {
+                    "avg": {
+                      "field": "cost"
+                    }
+                  },
+                  "regression_agg" : {
+                    "inference": {
+                      "model_id": "a-complex-regression-model",
+                      "inference_config": {
+                        "regression": {
+                          "results_field": "banana"
+                        }
+                      },
+                      "buckets_path": {
+                        "cost" : "avg_cost_agg"
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+  - match: { error.root_cause.0.type: "action_request_validation_exception" }
+  - match: { error.root_cause.0.reason: "Validation Failed: 1: setting option [results_field] to [banana] is not valid for inference aggregations;" }