Browse Source

[ML] add new multi custom processor for data frame analytics and model inference (#67362)

This adds the multi custom feature processor to data frame analytics and inference.

The `multi_encoding` processor allows custom processors to be chained together and use the outputs from one processor as the inputs to another.
Benjamin Trent 4 years ago
parent
commit
cb34ca601c
22 changed files with 778 additions and 22 deletions
  1. 3 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java
  2. 119 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/Multi.java
  3. 32 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java
  4. 11 3
      client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java
  5. 7 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java
  6. 6 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/FrequencyEncodingTests.java
  7. 88 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/MultiTests.java
  8. 3 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java
  9. 7 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/OneHotEncodingTests.java
  10. 6 2
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncodingTests.java
  11. 10 1
      server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
  12. 7 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java
  13. 257 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/Multi.java
  14. 3 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java
  15. 5 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java
  16. 166 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/MultiTests.java
  17. 5 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java
  18. 17 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java
  19. 5 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java
  20. 1 3
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java
  21. 19 2
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java
  22. 1 1
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

+ 3 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java

@@ -19,6 +19,7 @@
 package org.elasticsearch.client.ml.inference;
 
 import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
+import org.elasticsearch.client.ml.inference.preprocessing.Multi;
 import org.elasticsearch.client.ml.inference.preprocessing.NGram;
 import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
@@ -60,6 +61,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             CustomWordEmbedding::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME),
             NGram::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(Multi.NAME),
+            Multi::fromXContent));
 
         // Model
         namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));

+ 119 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/Multi.java

@@ -0,0 +1,119 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference.preprocessing;
+
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+/**
+ * Multi-PreProcessor for chaining together multiple processors
+ */
+public class Multi implements PreProcessor {
+
+    public static final String NAME = "multi_encoding";
+    public static final ParseField PROCESSORS = new ParseField("processors");
+    public static final ParseField CUSTOM = new ParseField("custom");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<Multi, Void> PARSER = new ConstructingObjectParser<>(
+        NAME,
+        true,
+        a -> new Multi((List<PreProcessor>)a[0], (Boolean)a[1]));
+    static {
+        PARSER.declareNamedObjects(ConstructingObjectParser.constructorArg(),
+            (p, c, n) -> p.namedObject(PreProcessor.class, n, null),
+            (_unused) -> {/* Does not matter client side*/ },
+            PROCESSORS);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
+    }
+
+    public static Multi fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final List<PreProcessor> processors;
+    private final Boolean custom;
+
+    Multi(List<PreProcessor> processors, Boolean custom) {
+        this.processors = Objects.requireNonNull(processors, PROCESSORS.getPreferredName());
+        this.custom = custom;
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PROCESSORS.getPreferredName(), processors);
+        if (custom != null) {
+            builder.field(CUSTOM.getPreferredName(), custom);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Multi multi = (Multi) o;
+        return Objects.equals(multi.processors, processors) && Objects.equals(custom, multi.custom);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(custom, processors);
+    }
+
+    public static Builder builder(List<PreProcessor> processors) {
+        return new Builder(processors);
+    }
+
+    public static class Builder {
+        private final List<PreProcessor> processors;
+        private Boolean custom;
+
+        public Builder(List<PreProcessor> processors) {
+            this.processors = processors;
+        }
+
+        public Builder setCustom(boolean custom) {
+            this.custom = custom;
+            return this;
+        }
+
+        public Multi build() {
+            return new Multi(processors, custom);
+        }
+    }
+
+}

+ 32 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/NGram.java

@@ -24,8 +24,12 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.function.IntFunction;
+import java.util.stream.IntStream;
 
 
 /**
@@ -134,6 +138,10 @@ public class NGram implements PreProcessor {
         return custom;
     }
 
+    public List<String> outputFields() {
+        return allPossibleNGramOutputFeatureNames();
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -152,6 +160,30 @@ public class NGram implements PreProcessor {
         return Objects.hash(field, featurePrefix, start, length, custom, nGrams);
     }
 
+    private String nGramFeature(int nGram, int pos) {
+        return featurePrefix
+            + "."
+            + nGram
+            + pos;
+    }
+
+    private List<String> allPossibleNGramOutputFeatureNames() {
+        int totalNgrams = 0;
+        for (int nGram : nGrams) {
+            totalNgrams += (length - (nGram - 1));
+        }
+        if (totalNgrams <= 0) {
+            return Collections.emptyList();
+        }
+        List<String> ngramOutputs = new ArrayList<>(totalNgrams);
+
+        for (int nGram : nGrams) {
+            IntFunction<String> func = i -> nGramFeature(nGram, i);
+            IntStream.range(0, (length - (nGram - 1))).mapToObj(func).forEach(ngramOutputs::add);
+        }
+        return ngramOutputs;
+    }
+
     public static Builder builder(String field) {
         return new Builder(field);
     }

+ 11 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -76,6 +76,7 @@ import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetec
 import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStats;
 import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.client.ml.inference.preprocessing.Multi;
 import org.elasticsearch.client.ml.inference.preprocessing.NGram;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
@@ -707,7 +708,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(75, namedXContents.size());
+        assertEquals(76, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -792,9 +793,16 @@ public class RestHighLevelClientTests extends ESTestCase {
                 registeredMetricName(Regression.NAME, MeanSquaredLogarithmicErrorMetric.NAME),
                 registeredMetricName(Regression.NAME, HuberMetric.NAME),
                 registeredMetricName(Regression.NAME, RSquaredMetric.NAME)));
-        assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
+        assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class));
         assertThat(names,
-            hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME, CustomWordEmbedding.NAME, NGram.NAME));
+            hasItems(
+                FrequencyEncoding.NAME,
+                OneHotEncoding.NAME,
+                TargetMeanEncoding.NAME,
+                CustomWordEmbedding.NAME,
+                NGram.NAME,
+                Multi.NAME
+            ));
         assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class));
         assertThat(names, hasItems(Tree.NAME, Ensemble.NAME, LangIdentNeuralNetwork.NAME));
         assertEquals(Integer.valueOf(4),

+ 7 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

@@ -19,6 +19,8 @@
 package org.elasticsearch.client.ml.inference;
 
 import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.client.ml.inference.preprocessing.MultiTests;
+import org.elasticsearch.client.ml.inference.preprocessing.NGramTests;
 import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
 import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
 import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
@@ -66,9 +68,12 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
         return new TrainedModelDefinition.Builder()
             .setPreProcessors(
                 randomBoolean() ? null :
-                    Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                    Stream.generate(() -> randomFrom(
+                        FrequencyEncodingTests.createRandom(),
                         OneHotEncodingTests.createRandom(),
-                        TargetMeanEncodingTests.createRandom()))
+                        TargetMeanEncodingTests.createRandom(),
+                        NGramTests.createRandom(),
+                        MultiTests.createRandom()))
                         .limit(numberOfProcessors)
                         .collect(Collectors.toList()))
             .setTrainedModel(randomFrom(TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 6, targetType),

+ 6 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/FrequencyEncodingTests.java

@@ -50,14 +50,19 @@ public class FrequencyEncodingTests extends AbstractXContentTestCase<FrequencyEn
     }
 
     public static FrequencyEncoding createRandom() {
+        return createRandom(randomAlphaOfLength(10));
+    }
+
+    public static FrequencyEncoding createRandom(String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
         }
-        return new FrequencyEncoding(randomAlphaOfLength(10),
+        return new FrequencyEncoding(inputField,
             randomAlphaOfLength(10),
             valueMap,
             randomBoolean() ? null : randomBoolean());
     }
+
 }

+ 88 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/MultiTests.java

@@ -0,0 +1,88 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference.preprocessing;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+
+public class MultiTests extends AbstractXContentTestCase<Multi> {
+
+    @Override
+    protected Multi doParseInstance(XContentParser parser) throws IOException {
+        return Multi.fromXContent(parser);
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> !field.isEmpty();
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Multi createTestInstance() {
+        return createRandom();
+    }
+
+    public static Multi createRandom() {
+        final List<PreProcessor> processors;
+        Boolean isCustom = randomBoolean() ? null : randomBoolean();
+        if (isCustom == null || isCustom == false) {
+            NGram nGram = new NGram(randomAlphaOfLength(10), Arrays.asList(1, 2), 0, 10, isCustom, "f");
+            List<PreProcessor> preProcessorList = new ArrayList<>();
+            preProcessorList.add(nGram);
+            Stream.generate(() -> randomFrom(
+                FrequencyEncodingTests.createRandom(randomFrom(nGram.outputFields())),
+                TargetMeanEncodingTests.createRandom(randomFrom(nGram.outputFields())),
+                OneHotEncodingTests.createRandom(randomFrom(nGram.outputFields()))
+            )).limit(randomIntBetween(1, 10)).forEach(preProcessorList::add);
+            processors = preProcessorList;
+        } else {
+            processors = Stream.generate(
+                () -> randomFrom(
+                    FrequencyEncodingTests.createRandom(),
+                    TargetMeanEncodingTests.createRandom(),
+                    OneHotEncodingTests.createRandom(),
+                    NGramTests.createRandom()
+                )
+            ).limit(randomIntBetween(1, 10)).collect(Collectors.toList());
+        }
+        return new Multi(processors, isCustom);
+    }
+
+}

+ 3 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/NGramTests.java

@@ -44,10 +44,11 @@ public class NGramTests extends AbstractXContentTestCase<NGram> {
     }
 
     public static NGram createRandom() {
+        int length = randomIntBetween(1, 10);
         return new NGram(randomAlphaOfLength(10),
-            IntStream.range(1, 5).limit(5).boxed().collect(Collectors.toList()),
+            IntStream.range(1, Math.min(5, length + 1)).limit(5).boxed().collect(Collectors.toList()),
             randomBoolean() ? null : randomIntBetween(0, 10),
-            randomBoolean() ? null : randomIntBetween(1, 10),
+            randomBoolean() ? null : length,
             randomBoolean() ? null : randomBoolean(),
             randomBoolean() ? null : randomAlphaOfLength(10));
     }

+ 7 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/OneHotEncodingTests.java

@@ -50,12 +50,18 @@ public class OneHotEncodingTests extends AbstractXContentTestCase<OneHotEncoding
     }
 
     public static OneHotEncoding createRandom() {
+        return createRandom(randomAlphaOfLength(10));
+    }
+
+    public static OneHotEncoding createRandom(String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, String> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
         }
-        return new OneHotEncoding(randomAlphaOfLength(10), valueMap, randomBoolean() ? null : randomBoolean());
+        return new OneHotEncoding(inputField,
+            valueMap,
+            randomBoolean() ? null : randomBoolean());
     }
 
 }

+ 6 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncodingTests.java

@@ -50,16 +50,20 @@ public class TargetMeanEncodingTests extends AbstractXContentTestCase<TargetMean
     }
 
     public static TargetMeanEncoding createRandom() {
+        return createRandom(randomAlphaOfLength(10));
+    }
+
+    public static TargetMeanEncoding createRandom(String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
         }
-        return new TargetMeanEncoding(randomAlphaOfLength(10),
+        return new TargetMeanEncoding(inputField,
             randomAlphaOfLength(10),
             valueMap,
             randomDoubleBetween(0.0, 1.0, false),
-            randomBoolean() ? null : randomBoolean());
+            true);
     }
 
 }

+ 10 - 1
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -632,11 +632,20 @@ public abstract class StreamInput extends InputStream {
      * If the returned map contains any entries it will be mutable. If it is empty it might be immutable.
      */
     public <K, V> Map<K, V> readMap(Writeable.Reader<K> keyReader, Writeable.Reader<V> valueReader) throws IOException {
+        return readMap(keyReader, valueReader, HashMap::new);
+    }
+
+    public <K, V> Map<K, V> readOrderedMap(Writeable.Reader<K> keyReader, Writeable.Reader<V> valueReader) throws IOException {
+        return readMap(keyReader, valueReader, LinkedHashMap::new);
+    }
+
+    private <K, V> Map<K, V> readMap(Writeable.Reader<K> keyReader, Writeable.Reader<V> valueReader, IntFunction<Map<K, V>> constructor)
+        throws IOException {
         int size = readArraySize();
         if (size == 0) {
             return Collections.emptyMap();
         }
-        Map<K, V> map = new HashMap<>(size);
+        Map<K, V> map = constructor.apply(size);
         for (int i = 0; i < size; i++) {
             K key = keyReader.read(this);
             V value = valueReader.read(this);

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

@@ -11,6 +11,7 @@ import org.elasticsearch.plugins.spi.NamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.CustomWordEmbedding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.Multi;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@@ -67,6 +68,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             (p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, NGram.NAME,
             (p, c) -> NGram.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
+        namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, Multi.NAME,
+            (p, c) -> Multi.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
 
         // PreProcessing Strict
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
@@ -79,6 +82,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             (p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
         namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, NGram.NAME,
             (p, c) -> NGram.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
+        namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, Multi.NAME,
+            (p, c) -> Multi.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
 
         // Model Lenient
         namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));
@@ -161,6 +166,8 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
             CustomWordEmbedding::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, NGram.NAME.getPreferredName(),
             NGram::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(PreProcessor.class, Multi.NAME.getPreferredName(),
+            Multi::new));
 
         // Model
         namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new));

+ 257 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/Multi.java

@@ -0,0 +1,257 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.preprocessing;
+
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import org.apache.lucene.util.RamUsageEstimator;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+/**
+ * Multi-PreProcessor for chaining together multiple processors
+ */
+public class Multi implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
+
+    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Multi.class);
+    public static final ParseField NAME = new ParseField("multi_encoding");
+    public static final ParseField PROCESSORS = new ParseField("processors");
+    public static final ParseField CUSTOM = new ParseField("custom");
+
+    private static final ObjectParser<Multi.Builder, PreProcessorParseContext> STRICT_PARSER = createParser(false);
+    private static final ObjectParser<Multi.Builder, PreProcessorParseContext> LENIENT_PARSER = createParser(true);
+
+    private static ObjectParser<Multi.Builder, PreProcessorParseContext> createParser(boolean lenient) {
+        ObjectParser<Multi.Builder, PreProcessorParseContext> parser = new ObjectParser<>(
+            NAME.getPreferredName(),
+            lenient,
+            Multi.Builder::new
+        );
+        parser.declareNamedObjects(Multi.Builder::setProcessors,
+            (p, c, n) -> lenient ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
+            (multiBuilder) -> multiBuilder.setOrdered(true),
+            PROCESSORS);
+        parser.declareBoolean(Multi.Builder::setCustom, CUSTOM);
+        return parser;
+    }
+
+    public static Multi fromXContentStrict(XContentParser parser, PreProcessorParseContext context) {
+        return STRICT_PARSER.apply(parser, context == null ?  PreProcessorParseContext.DEFAULT : context).build();
+    }
+
+    public static Multi fromXContentLenient(XContentParser parser, PreProcessorParseContext context) {
+        return LENIENT_PARSER.apply(parser, context == null ?  PreProcessorParseContext.DEFAULT : context).build();
+    }
+
+    private final PreProcessor[] processors;
+    private final boolean custom;
+    private final Map<String, String> outputFields;
+    private final String[] inputFields;
+
+    public Multi(PreProcessor[] processors, Boolean custom) {
+        this.processors = ExceptionsHelper.requireNonNull(processors, PROCESSORS);
+        if (this.processors.length < 2) {
+            throw new IllegalArgumentException("processors must be an array of objects with at least length 2");
+        }
+        this.custom = custom != null && custom;
+        Set<String> consumedOutputFields = new HashSet<>();
+        List<String> inputFields = new ArrayList<>(processors[0].inputFields());
+        Map<String, String> originatingOutputFields = new LinkedHashMap<>();
+        for (String outputField : processors[0].outputFields()) {
+            originatingOutputFields.put(outputField, processors[0].getOutputFieldType(outputField));
+        }
+        for (int i = 1; i < processors.length; i++) {
+            final PreProcessor processor = processors[i];
+            for (String inputField : processor.inputFields()) {
+                if (originatingOutputFields.containsKey(inputField) == false) {
+                    inputFields.add(inputField);
+                } else {
+                    consumedOutputFields.add(inputField);
+                }
+            }
+            for (String outputField : processor.outputFields()) {
+                originatingOutputFields.put(outputField, processor.getOutputFieldType(outputField));
+            }
+        }
+        Map<String, String> outputFields = new LinkedHashMap<>();
+        for (var outputField : originatingOutputFields.entrySet()) {
+            if (consumedOutputFields.contains(outputField.getKey()) == false) {
+                outputFields.put(outputField.getKey(), outputField.getValue());
+            }
+        }
+        this.outputFields = outputFields;
+        this.inputFields = inputFields.toArray(String[]::new);
+        if (this.custom == false && this.inputFields.length > 1) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "[custom] cannot be false as [%s] is unable to accurately determine" +
+                        " field reverse encoding for input fields [%s] and output fields %s",
+                    NAME.getPreferredName(),
+                    Strings.arrayToCommaDelimitedString(this.inputFields),
+                    this.outputFields.keySet()
+                )
+            );
+        }
+    }
+
+    public Multi(StreamInput in) throws IOException {
+        this.processors = in.readNamedWriteableList(PreProcessor.class).toArray(PreProcessor[]::new);
+        this.custom = in.readBoolean();
+        this.outputFields = in.readOrderedMap(StreamInput::readString, StreamInput::readString);
+        this.inputFields = in.readStringArray();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeNamedWriteableList(Arrays.asList(processors));
+        out.writeBoolean(custom);
+        out.writeMap(outputFields, StreamOutput::writeString, StreamOutput::writeString);
+        out.writeStringArray(inputFields);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    @Override
+    public List<String> inputFields() {
+        return Arrays.asList(inputFields);
+    }
+
+    @Override
+    public List<String> outputFields() {
+        return new ArrayList<>(outputFields.keySet());
+    }
+
+    @Override
+    public void process(Map<String, Object> fields) {
+        for (PreProcessor processor : processors) {
+            processor.process(fields);
+        }
+    }
+
+    @Override
+    public Map<String, String> reverseLookup() {
+        if (inputFields.length > 1) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "[%s] is unable to accurately determine field reverse encoding for input fields [%s] and output fields %s",
+                    NAME.getPreferredName(),
+                    Strings.arrayToCommaDelimitedString(inputFields),
+                    outputFields.keySet()
+                )
+            );
+        }
+        return outputFields.keySet().stream().collect(Collectors.toMap(Function.identity(), _unused -> inputFields[0]));
+    }
+
+    @Override
+    public String getOutputFieldType(String outputField) {
+        return outputFields.get(outputField);
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        long size = SHALLOW_SIZE;
+        size += RamUsageEstimator.sizeOf(processors);
+        size += RamUsageEstimator.sizeOf(inputFields);
+        size += RamUsageEstimator.sizeOfMap(outputFields, 0);
+        return size;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        NamedXContentObjectHelper.writeNamedObjects(builder, params, true, PROCESSORS.getPreferredName(), Arrays.asList(processors));
+        builder.field(CUSTOM.getPreferredName(), custom);
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean isCustom() {
+        return custom;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Multi multi = (Multi) o;
+        return Arrays.equals(multi.processors, processors) && custom == multi.custom;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(custom, Arrays.hashCode(processors));
+    }
+
+    static class Builder {
+        private boolean ordered;
+        private List<PreProcessor> processors;
+        private boolean custom;
+
+        public Builder setOrdered(boolean ordered) {
+            this.ordered = ordered;
+            return this;
+        }
+
+        public Builder setProcessors(List<PreProcessor> processors) {
+            this.processors = processors;
+            return this;
+        }
+
+        public Builder setCustom(boolean custom) {
+            this.custom = custom;
+            return this;
+        }
+
+        Multi build() {
+            if (ordered == false) {
+                throw new IllegalArgumentException("processors must be an array of objects");
+            }
+            if (processors.size() < 2) {
+                throw new IllegalArgumentException("processors must be an array of objects with at least length 2");
+            }
+            return new Multi(processors.toArray(PreProcessor[]::new), custom);
+        }
+    }
+
+}

+ 3 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/NGram.java

@@ -114,6 +114,9 @@ public class NGram implements LenientlyParsedPreProcessor, StrictlyParsedPreProc
         this.field = ExceptionsHelper.requireNonNull(field, FIELD);
         this.featurePrefix = ExceptionsHelper.requireNonNull(featurePrefix, FEATURE_PREFIX);
         this.nGrams = ExceptionsHelper.requireNonNull(nGrams, NGRAMS);
+        if (nGrams.length == 0) {
+            throw ExceptionsHelper.badRequestException("[{}] must not be empty", NGRAMS.getPreferredName());
+        }
         if (Arrays.stream(this.nGrams).anyMatch(i -> i < MIN_GRAM || i > MAX_GRAM)) {
             throw ExceptionsHelper.badRequestException(
                 "[{}] is invalid [{}]; minimum supported value is [{}]; maximum supported value is [{}]",

+ 5 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java

@@ -39,12 +39,16 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
     }
 
     public static FrequencyEncoding createRandom(Boolean isCustom) {
+        return createRandom(isCustom, randomAlphaOfLength(10));
+    }
+
+    public static FrequencyEncoding createRandom(Boolean isCustom, String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
         }
-        return new FrequencyEncoding(randomAlphaOfLength(10),
+        return new FrequencyEncoding(inputField,
             randomAlphaOfLength(10),
             valueMap,
             isCustom);

+ 166 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/MultiTests.java

@@ -0,0 +1,166 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference.preprocessing;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+import java.util.stream.Stream;
+
+import org.elasticsearch.common.collect.MapBuilder;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+public class MultiTests extends PreProcessingTests<Multi> {
+
+    @Override
+    protected Multi doParseInstance(XContentParser parser) throws IOException {
+        return lenient ?
+            Multi.fromXContentLenient(parser, PreProcessor.PreProcessorParseContext.DEFAULT) :
+            Multi.fromXContentStrict(parser, PreProcessor.PreProcessorParseContext.DEFAULT);
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> !field.isEmpty();
+    }
+
+    @Override
+    protected Multi createTestInstance() {
+        return createRandom();
+    }
+
+    public static Multi createRandom() {
+        return createRandom(randomBoolean() ? null : randomBoolean());
+    }
+
+    public static Multi createRandom(Boolean isCustom) {
+        final PreProcessor[] processors;
+        if (isCustom == null || isCustom == false) {
+            NGram nGram = NGramTests.createRandom(isCustom);
+            List<PreProcessor> preProcessorList = new ArrayList<>();
+            preProcessorList.add(nGram);
+            Stream.generate(() -> randomFrom(
+                FrequencyEncodingTests.createRandom(isCustom, randomFrom(nGram.outputFields())),
+                TargetMeanEncodingTests.createRandom(isCustom, randomFrom(nGram.outputFields())),
+                OneHotEncodingTests.createRandom(isCustom, randomFrom(nGram.outputFields()))
+            )).limit(randomIntBetween(1, 10)).forEach(preProcessorList::add);
+            processors = preProcessorList.toArray(PreProcessor[]::new);
+        } else {
+            processors = randomArray(
+                2,
+                10,
+                PreProcessor[]::new,
+                () -> randomFrom(
+                    FrequencyEncodingTests.createRandom(isCustom),
+                    TargetMeanEncodingTests.createRandom(isCustom),
+                    OneHotEncodingTests.createRandom(isCustom),
+                    NGramTests.createRandom(isCustom)
+                )
+            );
+        }
+        return new Multi(processors, isCustom);
+    }
+
+    @Override
+    protected Writeable.Reader<Multi> instanceReader() {
+        return Multi::new;
+    }
+
+    public void testReverseLookup() {
+        String field = "text";
+        NGram nGram = new NGram(field, Collections.singletonList(1), 0, 2, null, "f");
+        OneHotEncoding oneHotEncoding = new OneHotEncoding("f.10",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_a")
+                .put("b", "has_b")
+                .map(),
+            true);
+        Multi multi = new Multi(new PreProcessor[]{nGram, oneHotEncoding}, true);
+        assertThat(multi.reverseLookup(), allOf(hasEntry("has_a", field), hasEntry("has_b", field), hasEntry("f.11", field)));
+
+        OneHotEncoding oneHotEncodingOutside = new OneHotEncoding("some_other",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_3_a")
+                .put("b", "has_3_b")
+                .map(),
+            true);
+        multi = new Multi(new PreProcessor[]{nGram, oneHotEncoding, oneHotEncodingOutside}, true);
+        expectThrows(IllegalArgumentException.class, multi::reverseLookup);
+    }
+
+    public void testProcessWithFieldPresent() {
+        String field = "text";
+        NGram nGram = new NGram(field, Collections.singletonList(1), 0, 2, null, "f");
+        OneHotEncoding oneHotEncoding1 = new OneHotEncoding("f.10",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_a")
+                .put("b", "has_b")
+                .map(),
+            true);
+        OneHotEncoding oneHotEncoding2 = new OneHotEncoding("f.11",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_2_a")
+                .put("b", "has_2_b")
+                .map(),
+            true);
+        Multi multi = new Multi(new PreProcessor[]{nGram, oneHotEncoding1, oneHotEncoding2}, true);
+        Map<String, Object> fields = randomFieldValues("text", "cat");
+        multi.process(fields);
+        assertThat(fields, hasEntry("has_a", 0));
+        assertThat(fields, hasEntry("has_b", 0));
+        assertThat(fields, hasEntry("has_2_a", 1));
+        assertThat(fields, hasEntry("has_2_b", 0));
+    }
+
+    public void testInputOutputFields() {
+        String field = "text";
+        NGram nGram = new NGram(field, Collections.singletonList(1), 0, 3, null, "f");
+        OneHotEncoding oneHotEncoding1 = new OneHotEncoding("f.10",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_a")
+                .put("b", "has_b")
+                .map(),
+            true);
+        OneHotEncoding oneHotEncoding2 = new OneHotEncoding("f.11",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_2_a")
+                .put("b", "has_2_b")
+                .map(),
+            true);
+        OneHotEncoding oneHotEncoding3 = new OneHotEncoding("some_other",
+            MapBuilder.<String, String>newMapBuilder()
+                .put("a", "has_3_a")
+                .put("b", "has_3_b")
+                .map(),
+            true);
+        Multi multi = new Multi(new PreProcessor[]{nGram, oneHotEncoding1, oneHotEncoding2, oneHotEncoding3}, true);
+        assertThat(multi.inputFields(), contains(field, "some_other"));
+        assertThat(multi.outputFields(),
+            contains(
+                "f.12",
+                "has_a",
+                "has_b",
+                "has_2_a",
+                "has_2_b",
+                "has_3_a",
+                "has_3_b")
+        );
+        assertThat(multi.getOutputFieldType("f.12"), equalTo("text"));
+        for (String fieldName : new String[]{"has_a", "has_b", "has_2_a", "has_2_b", "has_3_a", "has_3_b"}) {
+            assertThat(multi.getOutputFieldType(fieldName), equalTo("integer"));
+        }
+    }
+
+}

+ 5 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java

@@ -39,12 +39,16 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
     }
 
     public static OneHotEncoding createRandom(Boolean isCustom) {
+        return createRandom(isCustom, randomAlphaOfLength(10));
+    }
+
+    public static OneHotEncoding createRandom(Boolean isCustom, String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, String> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
         }
-        return new OneHotEncoding(randomAlphaOfLength(10),
+        return new OneHotEncoding(inputField,
             valueMap,
             isCustom);
     }

+ 17 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java

@@ -5,12 +5,17 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.preprocessing;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.hamcrest.Matcher;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Predicate;
 
@@ -31,6 +36,18 @@ public abstract class PreProcessingTests<T extends PreProcessor> extends Abstrac
         return lenient;
     }
 
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
+        entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
     @Override
     protected Predicate<String> getRandomFieldsExcludeFilter() {
         return field -> !field.isEmpty();

+ 5 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java

@@ -34,18 +34,21 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
         return createRandom();
     }
 
-
     public static TargetMeanEncoding createRandom() {
         return createRandom(randomBoolean() ? randomBoolean() : null);
     }
 
     public static TargetMeanEncoding createRandom(Boolean isCustom) {
+        return createRandom(isCustom, randomAlphaOfLength(10));
+    }
+
+    public static TargetMeanEncoding createRandom(Boolean isCustom, String inputField) {
         int valuesSize = randomIntBetween(1, 10);
         Map<String, Double> valueMap = new HashMap<>();
         for (int i = 0; i < valuesSize; i++) {
             valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
         }
-        return new TargetMeanEncoding(randomAlphaOfLength(10),
+        return new TargetMeanEncoding(inputField,
             randomAlphaOfLength(10),
             valueMap,
             randomDoubleBetween(0.0, 1.0, false),

+ 1 - 3
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -21,7 +21,6 @@ import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -34,7 +33,6 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigUpdate;
-import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
@@ -357,7 +355,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
 
         GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
-            new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
+            new GetTrainedModelsAction.Request(jobId + "*", Collections.emptyList(), Collections.singleton("definition"))).actionGet();
         assertThat(response.getResources().results().size(), equalTo(1));
         TrainedModelConfig modelConfig = response.getResources().results().get(0);
         modelConfig.ensureParsedDefinition(xContentRegistry());

+ 19 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalysisCustomFeatureIT.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.get.GetResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -27,18 +28,24 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.Multi;
 import org.elasticsearch.xpack.core.ml.inference.preprocessing.NGram;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
 import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.everyItem;
 import static org.hamcrest.Matchers.hasKey;
@@ -111,7 +118,17 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
                 42L,
                 null,
                 null,
-                Collections.singletonList(new NGram(TEXT_FIELD, "f", new int[]{1, 2}, 0, 2, true))))
+                Arrays.asList(
+                    new NGram(TEXT_FIELD, "f", new int[]{1}, 0, 2, true),
+                    new Multi(new PreProcessor[]{
+                        new NGram(TEXT_FIELD, "ngram", new int[]{2}, 0, 3, true),
+                        new FrequencyEncoding("ngram.20",
+                            "frequency",
+                            MapBuilder.<String, Double>newMapBuilder().put("ca", 5.0).put("do", 1.0).map(), true),
+                        new OneHotEncoding("ngram.21", MapBuilder.<String, String>newMapBuilder().put("at", "is_cat").map(), true)
+                    },
+                        true)
+                    )))
             .setAnalyzedFields(new FetchSourceContext(true, new String[]{TEXT_FIELD, NUMERICAL_FIELD}, new String[]{}))
             .build();
         putAnalytics(config);
@@ -130,7 +147,7 @@ public class DataFrameAnalysisCustomFeatureIT extends MlNativeDataFrameAnalytics
             @SuppressWarnings("unchecked")
             List<Map<String, Object>> importanceArray = (List<Map<String, Object>>)resultsObject.get("feature_importance");
             assertThat(importanceArray.stream().map(m -> m.get("feature_name").toString()).collect(Collectors.toSet()),
-                everyItem(startsWith("f.")));
+                everyItem(anyOf(startsWith("f."), startsWith("ngram"), equalTo("is_cat"), equalTo("frequency"))));
         }
 
         assertProgressComplete(jobId);

+ 1 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -661,7 +661,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "Started writing results",
             "Finished analysis");
         GetTrainedModelsAction.Response response = client().execute(GetTrainedModelsAction.INSTANCE,
-            new GetTrainedModelsAction.Request(jobId + "*", true, Collections.emptyList())).actionGet();
+            new GetTrainedModelsAction.Request(jobId + "*", Collections.emptyList(), Collections.singleton("definition"))).actionGet();
         assertThat(response.getResources().results().size(), equalTo(1));
         TrainedModelConfig modelConfig = response.getResources().results().get(0);
         modelConfig.ensureParsedDefinition(xContentRegistry());