Browse Source

Add Inference Pipeline aggregation to HLRC (#59086)

Adds InferencePipelineAggregationBuilder to the HLRC duplicating 
the server side classes
David Kyle 5 years ago
parent
commit
49f9431fe7

+ 3 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java

@@ -54,6 +54,8 @@ import org.elasticsearch.action.search.SearchScrollRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.action.update.UpdateResponse;
+import org.elasticsearch.client.analytics.InferencePipelineAggregationBuilder;
+import org.elasticsearch.client.analytics.ParsedInference;
 import org.elasticsearch.client.analytics.ParsedStringStats;
 import org.elasticsearch.client.analytics.ParsedTopMetrics;
 import org.elasticsearch.client.analytics.StringStatsAggregationBuilder;
@@ -1957,6 +1959,7 @@ public class RestHighLevelClient implements Closeable {
         map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c));
         map.put(StringStatsAggregationBuilder.NAME, (p, c) -> ParsedStringStats.PARSER.parse(p, (String) c));
         map.put(TopMetricsAggregationBuilder.NAME, (p, c) -> ParsedTopMetrics.PARSER.parse(p, (String) c));
+        map.put(InferencePipelineAggregationBuilder.NAME, (p, c) -> ParsedInference.fromXContent(p, (String ) (c)));
         List<NamedXContentRegistry.Entry> entries = map.entrySet().stream()
                 .map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue()))
                 .collect(Collectors.toList());

+ 141 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/InferencePipelineAggregationBuilder.java

@@ -0,0 +1,141 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.analytics;
+
+import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+/**
+ * For building inference pipeline aggregations
+ *
+ * NOTE: This extends {@linkplain AbstractPipelineAggregationBuilder} for compatibility
+ * with {@link SearchSourceBuilder#aggregation(PipelineAggregationBuilder)} but it
+ * doesn't support any "server" side things like {@linkplain #doWriteTo(StreamOutput)}
+ * or {@linkplain #createInternal(Map)}
+ */
+public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
+
+    public static String NAME = "inference";
+
+    public static final ParseField MODEL_ID = new ParseField("model_id");
+    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
+
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, String> PARSER = new ConstructingObjectParser<>(
+        NAME, false,
+        (args, name) -> new InferencePipelineAggregationBuilder(name, (String)args[0], (Map<String, String>) args[1])
+    );
+
+    static {
+        PARSER.declareString(constructorArg(), MODEL_ID);
+        PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
+        PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
+            (p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG);
+    }
+
+    private final Map<String, String> bucketPathMap;
+    private final String modelId;
+    private InferenceConfig inferenceConfig;
+
+    public static InferencePipelineAggregationBuilder parse(String pipelineAggregatorName,
+                                                            XContentParser parser) {
+        return PARSER.apply(parser, pipelineAggregatorName);
+    }
+
+    public InferencePipelineAggregationBuilder(String name, String modelId, Map<String, String> bucketsPath) {
+        super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
+        this.modelId = modelId;
+        this.bucketPathMap = bucketsPath;
+    }
+
+    public void setInferenceConfig(InferenceConfig inferenceConfig) {
+        this.inferenceConfig = inferenceConfig;
+    }
+
+    @Override
+    protected void validate(ValidationContext context) {
+        // validation occurs on the server
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected PipelineAggregator createInternal(Map<String, Object> metaData) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected boolean overrideBucketsPath() {
+        return true;
+    }
+
+    @Override
+    protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MODEL_ID.getPreferredName(), modelId);
+        builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
+        if (inferenceConfig != null) {
+            builder.startObject(INFERENCE_CONFIG.getPreferredName());
+            builder.field(inferenceConfig.getName(), inferenceConfig);
+            builder.endObject();
+        }
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) return true;
+        if (obj == null || getClass() != obj.getClass()) return false;
+        if (super.equals(obj) == false) return false;
+
+        InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
+        return Objects.equals(bucketPathMap, other.bucketPathMap)
+            && Objects.equals(modelId, other.modelId)
+            && Objects.equals(inferenceConfig, other.inferenceConfig);
+    }
+}

+ 137 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/analytics/ParsedInference.java

@@ -0,0 +1,137 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.analytics;
+
+import org.elasticsearch.client.ml.inference.results.FeatureImportance;
+import org.elasticsearch.client.ml.inference.results.TopClassEntry;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.ParsedAggregation;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * This class parses the superset of all possible fields that may be written by
+ * InferenceResults. The warning field is mutually exclusive with all the other fields.
+ *
+ * In the case of classification results {@link #getValue()} may return a String,
+ * Boolean or a Double. For regression results {@link #getValue()} is always
+ * a Double.
+ */
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class ParsedInference extends ParsedAggregation {
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
+        new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
+            args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
+                (List<TopClassEntry>) args[2], (String) args[3]));
+
+    public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance");
+    public static final ParseField WARNING = new ParseField("warning");
+    public static final ParseField TOP_CLASSES = new ParseField("top_classes");
+
+    static {
+        PARSER.declareField(optionalConstructorArg(), (p, n) -> {
+            Object o;
+            XContentParser.Token token = p.currentToken();
+            if (token == XContentParser.Token.VALUE_STRING) {
+                o = p.text();
+            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
+                o = p.booleanValue();
+            } else if (token == XContentParser.Token.VALUE_NUMBER) {
+                o = p.doubleValue();
+            } else {
+                throw new XContentParseException(p.getTokenLocation(),
+                    "[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
+                        + "value [" + token + "] is not a string, boolean or number");
+            }
+            return o;
+        }, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
+        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE);
+        PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES);
+        PARSER.declareString(optionalConstructorArg(), WARNING);
+        declareAggregationFields(PARSER);
+    }
+
+    public static ParsedInference fromXContent(XContentParser parser, final String name) {
+        ParsedInference parsed = PARSER.apply(parser, null);
+        parsed.setName(name);
+        return parsed;
+    }
+
+    private final Object value;
+    private final List<FeatureImportance> featureImportance;
+    private final List<TopClassEntry> topClasses;
+    private final String warning;
+
+    ParsedInference(Object value,
+                    List<FeatureImportance> featureImportance,
+                    List<TopClassEntry> topClasses,
+                    String warning) {
+        this.value = value;
+        this.warning = warning;
+        this.featureImportance = featureImportance;
+        this.topClasses = topClasses;
+    }
+
+    public Object getValue() {
+        return value;
+    }
+
+    public List<FeatureImportance> getFeatureImportance() {
+        return featureImportance;
+    }
+
+    public List<TopClassEntry> getTopClasses() {
+        return topClasses;
+    }
+
+    public String getWarning() {
+        return warning;
+    }
+
+    @Override
+    protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
+        if (warning != null) {
+            builder.field(WARNING.getPreferredName(), warning);
+        } else {
+            builder.field(CommonFields.VALUE.getPreferredName(), value);
+            if (topClasses != null && topClasses.size() > 0) {
+                builder.field(TOP_CLASSES.getPreferredName(), topClasses);
+            }
+            if (featureImportance != null && featureImportance.size() > 0) {
+                builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance);
+            }
+        }
+        return builder;
+    }
+
+    @Override
+    public String getType() {
+        return InferencePipelineAggregationBuilder.NAME;
+    }
+}

+ 112 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/FeatureImportance.java

@@ -0,0 +1,112 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class FeatureImportance implements ToXContentObject {
+
+    public static final String IMPORTANCE = "importance";
+    public static final String FEATURE_NAME = "feature_name";
+    public static final String CLASS_IMPORTANCE = "class_importance";
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
+        new ConstructingObjectParser<>("feature_importance", true,
+            a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
+        );
+
+    static {
+        PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
+        PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
+        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
+            new ParseField(FeatureImportance.CLASS_IMPORTANCE));
+    }
+
+    public static FeatureImportance fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final Map<String, Double> classImportance;
+    private final double importance;
+    private final String featureName;
+
+    public FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
+        this.featureName = Objects.requireNonNull(featureName);
+        this.importance = importance;
+        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
+    }
+
+    public Map<String, Double> getClassImportance() {
+        return classImportance;
+    }
+
+    public double getImportance() {
+        return importance;
+    }
+
+    public String getFeatureName() {
+        return featureName;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(FEATURE_NAME, featureName);
+        builder.field(IMPORTANCE, importance);
+        if (classImportance != null && classImportance.isEmpty() == false) {
+            builder.startObject(CLASS_IMPORTANCE);
+            for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
+                builder.field(entry.getKey(), entry.getValue());
+            }
+            builder.endObject();
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        FeatureImportance that = (FeatureImportance) object;
+        return Objects.equals(featureName, that.featureName)
+            && Objects.equals(importance, that.importance)
+            && Objects.equals(classImportance, that.classImportance);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(featureName, importance, classImportance);
+    }
+}

+ 116 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/results/TopClassEntry.java

@@ -0,0 +1,116 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.results;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParseException;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class TopClassEntry implements ToXContentObject {
+
+    public static final ParseField CLASS_NAME = new ParseField("class_name");
+    public static final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
+    public static final ParseField CLASS_SCORE = new ParseField("class_score");
+
+    public static final String NAME = "top_class";
+
+    private static final ConstructingObjectParser<TopClassEntry, Void> PARSER =
+        new ConstructingObjectParser<>(NAME, true, a -> new TopClassEntry(a[0], (Double) a[1], (Double) a[2]));
+
+    static {
+        PARSER.declareField(constructorArg(), (p, n) -> {
+            Object o;
+            XContentParser.Token token = p.currentToken();
+            if (token == XContentParser.Token.VALUE_STRING) {
+                o = p.text();
+            } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
+                o = p.booleanValue();
+            } else if (token == XContentParser.Token.VALUE_NUMBER) {
+                o = p.doubleValue();
+            } else {
+                throw new XContentParseException(p.getTokenLocation(),
+                    "[" + NAME + "] failed to parse field [" + CLASS_NAME + "] value [" + token
+                        + "] is not a string, boolean or number");
+            }
+            return o;
+        }, CLASS_NAME, ObjectParser.ValueType.VALUE);
+        PARSER.declareDouble(constructorArg(), CLASS_PROBABILITY);
+        PARSER.declareDouble(constructorArg(), CLASS_SCORE);
+    }
+
+    public static TopClassEntry fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    private final Object classification;
+    private final double probability;
+    private final double score;
+
+    public TopClassEntry(Object classification, double probability, double score) {
+        this.classification = Objects.requireNonNull(classification);
+        this.probability = probability;
+        this.score = score;
+    }
+
+    public Object getClassification() {
+        return classification;
+    }
+
+    public double getProbability() {
+        return probability;
+    }
+
+    public double getScore() {
+        return score;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(CLASS_NAME.getPreferredName(), classification);
+        builder.field(CLASS_PROBABILITY.getPreferredName(), probability);
+        builder.field(CLASS_SCORE.getPreferredName(), score);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (object == this) { return true; }
+        if (object == null || getClass() != object.getClass()) { return false; }
+        TopClassEntry that = (TopClassEntry) object;
+        return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(classification, probability, score);
+    }
+}

+ 1 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -688,6 +688,7 @@ public class RestHighLevelClientTests extends ESTestCase {
         // Explicitly check for metrics from the analytics module because they aren't in InternalAggregationTestCase
         assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("string_stats")));
         assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("top_metrics")));
+        assertTrue(namedXContents.removeIf(e -> e.name.getPreferredName().equals("inference")));
 
         assertEquals(expectedInternalAggregations + expectedSuggestions, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();

+ 127 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/analytics/InferenceAggIT.java

@@ -0,0 +1,127 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.analytics;
+
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.client.ESRestHighLevelClientTestCase;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.indices.CreateIndexRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
+import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.search.aggregations.bucket.terms.ParsedTerms;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.AvgAggregationBuilder;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+
+public class InferenceAggIT extends ESRestHighLevelClientTestCase {
+
+    public void testInferenceAgg() throws IOException {
+
+        // create a very simple decision tree with a root node and 2 leaves
+        List<String> featureNames = Collections.singletonList("cost");
+        Tree.Builder builder = Tree.builder();
+        builder.setFeatureNames(featureNames);
+        TreeNode.Builder root = builder.addJunction(0, 0, true, 1.0);
+        int leftChild = root.getLeftChild();
+        int rightChild = root.getRightChild();
+        builder.addLeaf(leftChild, 10.0);
+        builder.addLeaf(rightChild, 20.0);
+
+        final String modelId = "simple_regression";
+        putTrainedModel(modelId, featureNames, builder.build());
+
+        final String index = "inference-test-data";
+        indexData(index);
+
+        TermsAggregationBuilder termsAgg = new TermsAggregationBuilder("fruit_type").field("fruit");
+        AvgAggregationBuilder avgAgg = new AvgAggregationBuilder("avg_cost").field("cost");
+        termsAgg.subAggregation(avgAgg);
+
+        Map<String, String> bucketPaths = new HashMap<>();
+        bucketPaths.put("cost", "avg_cost");
+        InferencePipelineAggregationBuilder inferenceAgg = new InferencePipelineAggregationBuilder("infer", modelId,  bucketPaths);
+        termsAgg.subAggregation(inferenceAgg);
+
+        SearchRequest search = new SearchRequest(index);
+        search.source().aggregation(termsAgg);
+        SearchResponse response = highLevelClient().search(search, RequestOptions.DEFAULT);
+        ParsedTerms terms = response.getAggregations().get("fruit_type");
+        List<? extends Terms.Bucket> buckets = terms.getBuckets();
+        {
+            assertThat(buckets.get(0).getKey(), equalTo("apple"));
+            ParsedInference inference = buckets.get(0).getAggregations().get("infer");
+            assertThat((Double) inference.getValue(), closeTo(20.0, 0.01));
+            assertNull(inference.getWarning());
+            assertNull(inference.getFeatureImportance());
+            assertNull(inference.getTopClasses());
+        }
+        {
+            assertThat(buckets.get(1).getKey(), equalTo("banana"));
+            ParsedInference inference = buckets.get(1).getAggregations().get("infer");
+            assertThat((Double) inference.getValue(), closeTo(10.0, 0.01));
+            assertNull(inference.getWarning());
+            assertNull(inference.getFeatureImportance());
+            assertNull(inference.getTopClasses());
+        }
+    }
+
+    private void putTrainedModel(String modelId, List<String> inputFields, Tree tree) throws IOException {
+        TrainedModelDefinition definition = new TrainedModelDefinition.Builder().setTrainedModel(tree).build();
+        TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition)
+            .setModelId(modelId)
+            .setInferenceConfig(new RegressionConfig())
+            .setInput(new TrainedModelInput(inputFields))
+            .setDescription("test model")
+            .build();
+        highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
+    }
+
+    private void indexData(String index) throws IOException {
+        CreateIndexRequest create = new CreateIndexRequest(index);
+        create.mapping("{\"properties\": {\"fruit\": {\"type\": \"keyword\"}," +
+            "\"cost\": {\"type\": \"double\"}}}", XContentType.JSON);
+        highLevelClient().indices().create(create, RequestOptions.DEFAULT);
+        BulkRequest bulk = new BulkRequest(index).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "apple", "cost", "1.2"));
+        bulk.add(new IndexRequest().source(XContentType.JSON, "fruit", "banana", "cost", "0.8"));
+        bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
+        highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
+    }
+}

+ 0 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java

@@ -54,5 +54,4 @@ public class TrainedModelInputTests extends AbstractXContentTestCase<TrainedMode
     protected TrainedModelInput createTestInstance() {
         return createRandomInput();
     }
-
 }

+ 59 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/FeatureImportanceTests.java

@@ -0,0 +1,59 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.results;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class FeatureImportanceTests extends AbstractXContentTestCase<FeatureImportance> {
+
+    @Override
+    protected FeatureImportance createTestInstance() {
+        return new FeatureImportance(
+            randomAlphaOfLength(10),
+            randomDoubleBetween(-10.0, 10.0, false),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomAlphaOfLength(10))
+                    .limit(randomLongBetween(2, 10))
+                    .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
+
+    }
+
+    @Override
+    protected FeatureImportance doParseInstance(XContentParser parser) throws IOException {
+        return FeatureImportance.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> field.equals(FeatureImportance.CLASS_IMPORTANCE);
+    }
+}

+ 50 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/results/TopClassEntryTests.java

@@ -0,0 +1,50 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.results;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class TopClassEntryTests extends AbstractXContentTestCase<TopClassEntry> {
+    @Override
+    protected TopClassEntry createTestInstance() {
+        Object classification;
+        if (randomBoolean()) {
+            classification = randomAlphaOfLength(10);
+        } else if (randomBoolean()) {
+            classification = randomBoolean();
+        } else {
+            classification = randomDouble();
+        }
+        return new TopClassEntry(classification, randomDouble(), randomDouble());
+    }
+
+    @Override
+    protected TopClassEntry doParseInstance(XContentParser parser) throws IOException {
+        return TopClassEntry.fromXContent(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 1 - 0
docs/java-rest/high-level/aggs-builders.asciidoc

@@ -62,6 +62,7 @@ This page lists all the available aggregations with their corresponding `Aggrega
 | Pipeline on                                                                                        | PipelineAggregationBuilder Class                                                                                                                  | Method in PipelineAggregatorBuilders
 | {ref}/search-aggregations-pipeline-avg-bucket-aggregation.html[Avg Bucket]                         | {agg-ref}/pipeline/bucketmetrics/avg/AvgBucketPipelineAggregationBuilder.html[AvgBucketPipelineAggregationBuilder]                                | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#avgBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.avgBucket()]
 | {ref}/search-aggregations-pipeline-derivative-aggregation.html[Derivative]                         | {agg-ref}/pipeline/derivative/DerivativePipelineAggregationBuilder.html[DerivativePipelineAggregationBuilder]                                     | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#derivative-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.derivative()]
+| {ref}/search-aggregations-pipeline-inference-bucket-aggregation.html[Inference]                    | {javadoc-client}/analytics/InferencePipelineAggregationBuilder.html[InferencePipelineAggregationBuilder]                                          | None
 | {ref}/search-aggregations-pipeline-max-bucket-aggregation.html[Max Bucket]                         | {agg-ref}/pipeline/bucketmetrics/max/MaxBucketPipelineAggregationBuilder.html[MaxBucketPipelineAggregationBuilder]                                | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#maxBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.maxBucket()]
 | {ref}/search-aggregations-pipeline-min-bucket-aggregation.html[Min Bucket]                         | {agg-ref}/pipeline/bucketmetrics/min/MinBucketPipelineAggregationBuilder.html[MinBucketPipelineAggregationBuilder]                                | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#minBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.minBucket()]
 | {ref}/search-aggregations-pipeline-sum-bucket-aggregation.html[Sum Bucket]                         | {agg-ref}/pipeline/bucketmetrics/sum/SumBucketPipelineAggregationBuilder.html[SumBucketPipelineAggregationBuilder]                                | {agg-ref}/pipeline/PipelineAggregatorBuilders.html#sumBucket-java.lang.String-java.lang.String-[PipelineAggregatorBuilders.sumBucket()]

+ 0 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/aggs/ParsedInference.java