Browse Source

[ML][Inference] Adding preprocessors to definition object (#47320)

* [ML][Inference] Adding preprocessors to definition object

* Update TrainedModelConfig.java
Benjamin Trent 6 years ago
parent
commit
e2395addb0

+ 9 - 21
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

@@ -20,7 +20,6 @@ package org.elasticsearch.client.ml.inference;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.client.common.TimeUtil;
-import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ObjectParser;
@@ -31,7 +30,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import java.io.IOException;
 import java.time.Instant;
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -64,9 +62,8 @@ public class TrainedModelConfig implements ToXContentObject {
         PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
         PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
         PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
-        PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
-            (p, c, n) -> p.namedObject(TrainedModel.class, n, null),
-            (modelDocBuilder) -> { /* Noop does not matter client side */ },
+        PARSER.declareObject(TrainedModelConfig.Builder::setDefinition,
+            (p, c) -> TrainedModelDefinition.fromXContent(p),
             DEFINITION);
     }
 
@@ -82,7 +79,7 @@ public class TrainedModelConfig implements ToXContentObject {
     private final Long modelVersion;
     private final String modelType;
     private final Map<String, Object> metadata;
-    private final TrainedModel definition;
+    private final TrainedModelDefinition definition;
 
     TrainedModelConfig(String modelId,
                        String createdBy,
@@ -91,7 +88,7 @@ public class TrainedModelConfig implements ToXContentObject {
                        Instant createdTime,
                        Long modelVersion,
                        String modelType,
-                       TrainedModel definition,
+                       TrainedModelDefinition definition,
                        Map<String, Object> metadata) {
         this.modelId = modelId;
         this.createdBy = createdBy;
@@ -136,7 +133,7 @@ public class TrainedModelConfig implements ToXContentObject {
         return metadata;
     }
 
-    public TrainedModel getDefinition() {
+    public TrainedModelDefinition getDefinition() {
         return definition;
     }
 
@@ -169,11 +166,7 @@ public class TrainedModelConfig implements ToXContentObject {
             builder.field(MODEL_TYPE.getPreferredName(), modelType);
         }
         if (definition != null) {
-            NamedXContentObjectHelper.writeNamedObjects(builder,
-                params,
-                false,
-                DEFINITION.getPreferredName(),
-                Collections.singletonList(definition));
+            builder.field(DEFINITION.getPreferredName(), definition);
         }
         if (metadata != null) {
             builder.field(METADATA.getPreferredName(), metadata);
@@ -227,7 +220,7 @@ public class TrainedModelConfig implements ToXContentObject {
         private Long modelVersion;
         private String modelType;
         private Map<String, Object> metadata;
-        private TrainedModel definition;
+        private TrainedModelDefinition.Builder definition;
 
         public Builder setModelId(String modelId) {
             this.modelId = modelId;
@@ -273,16 +266,11 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
-        public Builder setDefinition(TrainedModel definition) {
+        public Builder setDefinition(TrainedModelDefinition.Builder definition) {
             this.definition = definition;
             return this;
         }
 
-        private Builder setDefinition(List<TrainedModel> definition) {
-            assert definition.size() == 1;
-            return setDefinition(definition.get(0));
-        }
-
         public TrainedModelConfig build() {
             return new TrainedModelConfig(
                 modelId,
@@ -292,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject {
                 createdTime,
                 modelVersion,
                 modelType,
-                definition,
+                definition == null ? null : definition.build(),
                 metadata);
         }
     }

+ 137 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java

@@ -0,0 +1,137 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class TrainedModelDefinition implements ToXContentObject {
+
+    public static final String NAME = "trained_model_doc";
+
+    public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
+    public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
+
+    public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
+            true,
+            TrainedModelDefinition.Builder::new);
+    static { // registers the two named-object fields of a definition on the shared lenient PARSER
+        PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
+            (p, c, n) -> p.namedObject(TrainedModel.class, n, null),
+            (modelDocBuilder) -> { /* No-op: this callback does not matter client side */ },
+            TRAINED_MODEL);
+        PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
+            (p, c, n) -> p.namedObject(PreProcessor.class, n, null),
+            (trainedModelDefBuilder) -> {/* No-op: this callback does not matter client side */ },
+            PREPROCESSORS);
+    }
+
+    public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    private final TrainedModel trainedModel;
+    private final List<PreProcessor> preProcessors;
+
+    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
+        this.trainedModel = trainedModel;
+        this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        NamedXContentObjectHelper.writeNamedObjects(builder,
+            params,
+            false,
+            TRAINED_MODEL.getPreferredName(),
+            Collections.singletonList(trainedModel));
+        NamedXContentObjectHelper.writeNamedObjects(builder,
+            params,
+            true,
+            PREPROCESSORS.getPreferredName(),
+            preProcessors);
+        builder.endObject();
+        return builder;
+    }
+
+    public TrainedModel getTrainedModel() {
+        return trainedModel;
+    }
+
+    public List<PreProcessor> getPreProcessors() {
+        return preProcessors;
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelDefinition that = (TrainedModelDefinition) o;
+        return Objects.equals(trainedModel, that.trainedModel) &&
+            Objects.equals(preProcessors, that.preProcessors) ;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(trainedModel, preProcessors);
+    }
+
+    public static class Builder {
+
+        private List<PreProcessor> preProcessors;
+        private TrainedModel trainedModel;
+
+        public Builder setPreProcessors(List<PreProcessor> preProcessors) {
+            this.preProcessors = preProcessors;
+            return this;
+        }
+
+        public Builder setTrainedModel(TrainedModel trainedModel) {
+            this.trainedModel = trainedModel;
+            return this;
+        }
+
+        private Builder setTrainedModel(List<TrainedModel> trainedModel) {
+            assert trainedModel.size() == 1;
+            return setTrainedModel(trainedModel.get(0));
+        }
+
+        public TrainedModelDefinition build() {
+            return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
+        }
+    }
+
+}

+ 2 - 2
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/PreProcessor.java

@@ -18,13 +18,13 @@
  */
 package org.elasticsearch.client.ml.inference.preprocessing;
 
-import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.client.ml.inference.NamedXContentObject;
 
 
 /**
  * Describes a pre-processor for a defined machine learning model
  */
-public interface PreProcessor extends ToXContentObject {
+public interface PreProcessor extends NamedXContentObject {
 
     /**
      * @return The name of the pre-processor

+ 1 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

@@ -19,7 +19,6 @@
 package org.elasticsearch.client.ml.inference;
 
 import org.elasticsearch.Version;
-import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
@@ -61,7 +60,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
             Instant.ofEpochMilli(randomNonNegativeLong()),
             randomBoolean() ? null : randomNonNegativeLong(),
             randomAlphaOfLength(10),
-            randomFrom(TreeTests.createRandom()),
+            randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
             randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
     }
 

+ 83 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

@@ -0,0 +1,83 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TrainedModelDefinitionTests extends AbstractXContentTestCase<TrainedModelDefinition> {
+
+    @Override
+    protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
+        return TrainedModelDefinition.fromXContent(parser).build();
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return field -> !field.isEmpty();
+    }
+
+    public static TrainedModelDefinition.Builder createRandomBuilder() {
+        int numberOfProcessors = randomIntBetween(1, 10);
+        return new TrainedModelDefinition.Builder()
+            .setPreProcessors(
+                randomBoolean() ? null :
+                    Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                        OneHotEncodingTests.createRandom(),
+                        TargetMeanEncodingTests.createRandom()))
+                        .limit(numberOfProcessors)
+                        .collect(Collectors.toList()))
+            .setTrainedModel(randomFrom(TreeTests.createRandom()));
+    }
+
+    @Override
+    protected TrainedModelDefinition createTestInstance() {
+        return createRandomBuilder().build();
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(namedXContent);
+    }
+
+}

+ 12 - 33
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -17,18 +17,13 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.common.time.TimeUtils;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
-import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
 
 import java.io.IOException;
 import java.time.Instant;
 import java.util.Collections;
-import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -65,11 +60,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION);
         parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE);
         parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
-        parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition,
-            (p, c, n) -> ignoreUnknownFields ?
-                p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
-                p.namedObject(StrictlyParsedTrainedModel.class, n, null),
-            (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
+        parser.declareObject(TrainedModelConfig.Builder::setDefinition,
+            (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields),
             DEFINITION);
         return parser;
     }
@@ -94,7 +86,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     // TODO how to reference and store large models that will not be executed in Java???
     // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something
     // TODO Should this be lazily parsed when loading via the index???
-    private final TrainedModel definition;
+    private final TrainedModelDefinition definition;
     TrainedModelConfig(String modelId,
                        String createdBy,
                        Version version,
@@ -102,7 +94,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                        Instant createdTime,
                        Long modelVersion,
                        String modelType,
-                       TrainedModel definition,
+                       TrainedModelDefinition definition,
                        Map<String, Object> metadata) {
         this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
         this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
@@ -123,7 +115,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         createdTime = in.readInstant();
         modelVersion = in.readVLong();
         modelType = in.readString();
-        definition = in.readOptionalNamedWriteable(TrainedModel.class);
+        definition = in.readOptionalWriteable(TrainedModelDefinition::new);
         metadata = in.readMap();
     }
 
@@ -160,7 +152,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     }
 
     @Nullable
-    public TrainedModel getDefinition() {
+    public TrainedModelDefinition getDefinition() {
         return definition;
     }
 
@@ -177,7 +169,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         out.writeInstant(createdTime);
         out.writeVLong(modelVersion);
         out.writeString(modelType);
-        out.writeOptionalNamedWriteable(definition);
+        out.writeOptionalWriteable(definition);
         out.writeMap(metadata);
     }
 
@@ -194,11 +186,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         builder.field(MODEL_VERSION.getPreferredName(), modelVersion);
         builder.field(MODEL_TYPE.getPreferredName(), modelType);
         if (definition != null) {
-            NamedXContentObjectHelper.writeNamedObjects(builder,
-                params,
-                false,
-                DEFINITION.getPreferredName(),
-                Collections.singletonList(definition));
+            builder.field(DEFINITION.getPreferredName(), definition);
         }
         if (metadata != null) {
             builder.field(METADATA.getPreferredName(), metadata);
@@ -241,7 +229,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             modelVersion);
     }
 
-
     public static class Builder {
 
         private String modelId;
@@ -252,7 +239,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         private Long modelVersion;
         private String modelType;
         private Map<String, Object> metadata;
-        private TrainedModel definition;
+        private TrainedModelDefinition.Builder definition;
 
         public Builder setModelId(String modelId) {
             this.modelId = modelId;
@@ -298,19 +285,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
-        public Builder setDefinition(TrainedModel definition) {
+        public Builder setDefinition(TrainedModelDefinition.Builder definition) {
             this.definition = definition;
             return this;
         }
 
-        private Builder setDefinition(List<TrainedModel> definition) {
-            if (definition.size() != 1) {
-                throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
-                    DEFINITION.getPreferredName());
-            }
-            return setDefinition(definition.get(0));
-        }
-
         // TODO move to REST level instead of here in the builder
         public void validate() {
             // We require a definition to be available until we support other means of supplying the definition
@@ -352,7 +331,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                 createdTime,
                 modelVersion,
                 modelType,
-                definition,
+                definition == null ? null : definition.build(),
                 metadata);
         }
 
@@ -365,7 +344,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
                 Instant.now(),
                 modelVersion,
                 modelType,
-                definition,
+                definition == null ? null : definition.build(),
                 metadata);
         }
     }

+ 176 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

@@ -0,0 +1,176 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+public class TrainedModelDefinition implements ToXContentObject, Writeable {
+
+    public static final String NAME = "trained_model_doc";
+
+    public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
+    public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
+
+    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+    public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
+    public static final ObjectParser<TrainedModelDefinition.Builder, Void> STRICT_PARSER = createParser(false);
+
+    private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(boolean ignoreUnknownFields) {
+        ObjectParser<TrainedModelDefinition.Builder, Void> parser = new ObjectParser<>(NAME, // one parser per leniency mode
+            ignoreUnknownFields,
+            TrainedModelDefinition.Builder::builderForParser); // order is unconfirmed until the ordered callback below runs
+        parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel,
+            (p, c, n) -> ignoreUnknownFields ?
+                p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
+                p.namedObject(StrictlyParsedTrainedModel.class, n, null),
+            (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ },
+            TRAINED_MODEL);
+        parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
+            (p, c, n) -> ignoreUnknownFields ?
+                p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
+                p.namedObject(StrictlyParsedPreProcessor.class, n, null),
+            (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), // lets build() accept 2+ preprocessors
+            PREPROCESSORS);
+        return parser;
+    }
+
+    public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
+        return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
+    }
+
+    private final TrainedModel trainedModel;
+    private final List<PreProcessor> preProcessors;
+
+    TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
+        this.trainedModel = trainedModel;
+        this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
+    }
+
+    public TrainedModelDefinition(StreamInput in) throws IOException {
+        this.trainedModel = in.readNamedWriteable(TrainedModel.class);
+        this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeNamedWriteable(trainedModel);
+        out.writeNamedWriteableList(preProcessors);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        NamedXContentObjectHelper.writeNamedObjects(builder,
+            params,
+            false,
+            TRAINED_MODEL.getPreferredName(),
+            Collections.singletonList(trainedModel));
+        NamedXContentObjectHelper.writeNamedObjects(builder,
+            params,
+            true,
+            PREPROCESSORS.getPreferredName(),
+            preProcessors);
+        builder.endObject();
+        return builder;
+    }
+
+    public TrainedModel getTrainedModel() {
+        return trainedModel;
+    }
+
+    public List<PreProcessor> getPreProcessors() {
+        return preProcessors;
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelDefinition that = (TrainedModelDefinition) o;
+        return Objects.equals(trainedModel, that.trainedModel) &&
+            Objects.equals(preProcessors, that.preProcessors) ;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(trainedModel, preProcessors);
+    }
+
+    public static class Builder {
+
+        private List<PreProcessor> preProcessors;
+        private TrainedModel trainedModel;
+        private boolean processorsInOrder;
+
+        private static Builder builderForParser() {
+            return new Builder(false);
+        }
+
+        private Builder(boolean processorsInOrder) {
+            this.processorsInOrder = processorsInOrder;
+        }
+
+        public Builder() {
+            this(true);
+        }
+
+        public Builder setPreProcessors(List<PreProcessor> preProcessors) {
+            this.preProcessors = preProcessors;
+            return this;
+        }
+
+        public Builder setTrainedModel(TrainedModel trainedModel) {
+            this.trainedModel = trainedModel;
+            return this;
+        }
+
+        private Builder setTrainedModel(List<TrainedModel> trainedModel) { // list overload: unwraps the single element
+            if (trainedModel.size() != 1) { // an empty list or several models is reported as a bad request
+                throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
+                    TRAINED_MODEL.getPreferredName());
+            }
+            return setTrainedModel(trainedModel.get(0));
+        }
+
+        private void setProcessorsInOrder(boolean value) {
+            this.processorsInOrder = value;
+        }
+
+        public TrainedModelDefinition build() { // checks preprocessor ordering, then constructs the definition
+            if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) { // 0 or 1 entry needs no order
+                throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
+            }
+            return new TrainedModelDefinition(this.trainedModel, this.preProcessors); // a null list becomes an empty list there
+        }
+    }
+
+}

+ 10 - 7
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

@@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractSerializingTestCase;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.MlStrings;
 import org.junit.Before;
@@ -65,7 +64,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
             Instant.ofEpochMilli(randomNonNegativeLong()),
             randomBoolean() ? null : randomNonNegativeLong(),
             randomAlphaOfLength(10),
-            randomBoolean() ? null : randomFrom(TreeTests.createRandom()),
+            randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
             randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
     }
 
@@ -97,14 +96,18 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
     public void testValidateWithInvalidID() {
         String modelId = "InvalidID-";
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
-            () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
+            () -> TrainedModelConfig.builder()
+                .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+                .setModelId(modelId).validate());
         assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
     }
 
     public void testValidateWithLongID() {
         String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
-            () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate());
+            () -> TrainedModelConfig.builder()
+                .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+                .setModelId(modelId).validate());
         assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
     }
 
@@ -112,21 +115,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
         String modelId = "simplemodel";
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,
             () -> TrainedModelConfig.builder()
-                .setDefinition(randomFrom(TreeTests.createRandom()))
+                .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setCreatedTime(Instant.now())
                 .setModelId(modelId).validate());
         assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation"));
 
         ex = expectThrows(ElasticsearchException.class,
             () -> TrainedModelConfig.builder()
-                .setDefinition(randomFrom(TreeTests.createRandom()))
+                .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setVersion(Version.CURRENT)
                 .setModelId(modelId).validate());
         assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
 
         ex = expectThrows(ElasticsearchException.class,
             () -> TrainedModelConfig.builder()
-                .setDefinition(randomFrom(TreeTests.createRandom()))
+                .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setCreatedBy("ml_user")
                 .setModelId(modelId).validate());
         assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));

+ 91 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.inference;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
+
+    // Parsing mode handed to TrainedModelDefinition.fromXContent; re-rolled before each test.
+    private boolean lenient;
+
+    /**
+     * Creates a builder holding the model from {@code TreeTests.createRandom()} and, when the
+     * random coin flip says so, a list of 1 to 10 pre-processors picked at random from the
+     * frequency, one-hot and target-mean encodings. Otherwise the pre-processor list is null.
+     *
+     * @return a randomly populated, not yet built, definition builder
+     */
+    public static TrainedModelDefinition.Builder createRandomBuilder() {
+        // Drawn up front, even if no pre-processors end up being generated.
+        final int preProcessorCount = randomIntBetween(1, 10);
+        return new TrainedModelDefinition.Builder()
+            .setPreProcessors(
+                randomBoolean()
+                    ? null
+                    : Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
+                            OneHotEncodingTests.createRandom(),
+                            TargetMeanEncodingTests.createRandom()))
+                        .limit(preProcessorCount)
+                        .collect(Collectors.toList()))
+            .setTrainedModel(randomFrom(TreeTests.createRandom()));
+    }
+
+    /** Picks, at random, whether this test run parses leniently or strictly. */
+    @Before
+    public void chooseStrictOrLenient() {
+        this.lenient = randomBoolean();
+    }
+
+    /** Returns a freshly built random definition. */
+    @Override
+    protected TrainedModelDefinition createTestInstance() {
+        TrainedModelDefinition.Builder randomBuilder = createRandomBuilder();
+        return randomBuilder.build();
+    }
+
+    /** Parses a definition from x-content using the parsing mode chosen for this test. */
+    @Override
+    protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
+        return TrainedModelDefinition.fromXContent(parser, lenient).build();
+    }
+
+    /** Unknown fields are reported as supported exactly when the lenient mode was chosen. */
+    @Override
+    protected boolean supportsUnknownFields() {
+        return lenient;
+    }
+
+    /**
+     * Matches every non-empty field path.
+     * NOTE(review): presumably this confines random field insertion to the root object - confirm
+     * against the base test case.
+     */
+    @Override
+    protected Predicate<String> getRandomFieldsExcludeFilter() {
+        return path -> path.isEmpty() == false;
+    }
+
+    /** Reads an instance back through the stream-input constructor. */
+    @Override
+    protected Writeable.Reader<TrainedModelDefinition> instanceReader() {
+        return in -> new TrainedModelDefinition(in);
+    }
+
+    /** Registry made of the ML inference named x-content parsers followed by the search module's. */
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        List<NamedXContentRegistry.Entry> parsers =
+            new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+        parsers.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
+        return new NamedXContentRegistry(parsers);
+    }
+
+    /** Registry made of the ML inference named writeables. */
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        List<NamedWriteableRegistry.Entry> writeables =
+            new ArrayList<>(new MlInferenceNamedXContentProvider().getNamedWriteables());
+        return new NamedWriteableRegistry(writeables);
+    }
+
+}

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

@@ -11,7 +11,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
@@ -93,7 +93,7 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
     private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) {
         return TrainedModelConfig.builder()
             .setCreatedBy("ml_test")
-            .setDefinition(TreeTests.createRandom())
+            .setDefinition(TrainedModelDefinitionTests.createRandomBuilder())
             .setDescription("trained model config for test")
             .setModelId(modelId)
             .setModelType("binary_decision_tree")