Selaa lähdekoodia

[ML][Inference] PUT API (#50852)

This adds the `PUT` API for creating trained models that support our format. 

This includes

* HLRC change for the API
* API creation
* Validations of model format and call
Benjamin Trent 5 vuotta sitten
vanhempi
commit
4cecb7a5be
37 muutettua tiedostoa jossa 1638 lisäystä ja 412 poistoa
  1. 11 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java
  2. 44 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java
  3. 66 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java
  4. 63 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java
  5. 81 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java
  6. 68 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java
  7. 12 7
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java
  8. 5 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java
  9. 19 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java
  10. 62 56
      client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java
  11. 91 54
      client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java
  12. 52 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java
  13. 52 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java
  14. 70 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java
  15. 19 16
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java
  16. 8 6
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  17. 1 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
  18. 53 0
      docs/java-rest/high-level/ml/put-trained-model.asciidoc
  19. 2 0
      docs/java-rest/high-level/supported-apis.asciidoc
  20. 137 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java
  21. 99 25
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
  22. 6 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java
  23. 4 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java
  24. 45 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java
  25. 45 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java
  26. 22 18
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java
  27. 0 1
      x-pack/plugin/ml/qa/ml-with-security/build.gradle
  28. 37 44
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java
  29. 66 63
      x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java
  30. 6 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  31. 190 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java
  32. 2 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
  33. 42 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java
  34. 25 81
      x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java
  35. 0 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java
  36. 28 0
      x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json
  37. 105 38
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

+ 11 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

@@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.PutDatafeedRequest;
 import org.elasticsearch.client.ml.PutFilterRequest;
 import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
 import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
 import org.elasticsearch.client.ml.SetUpgradeModeRequest;
 import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -792,6 +793,16 @@ final class MLRequestConverters {
         return new Request(HttpDelete.METHOD_NAME, endpoint);
     }
 
+    static Request putTrainedModel(PutTrainedModelRequest putTrainedModelRequest) throws IOException {
+        String endpoint = new EndpointBuilder()
+            .addPathPartAsIs("_ml", "inference")
+            .addPathPart(putTrainedModelRequest.getTrainedModelConfig().getModelId())
+            .build();
+        Request request = new Request(HttpPut.METHOD_NAME, endpoint);
+        request.setEntity(createEntity(putTrainedModelRequest, REQUEST_BODY_CONTENT_TYPE));
+        return request;
+    }
+
     static Request putFilter(PutFilterRequest putFilterRequest) throws IOException {
         String endpoint = new EndpointBuilder()
             .addPathPartAsIs("_ml")

+ 44 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

@@ -100,6 +100,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
 import org.elasticsearch.client.ml.PutFilterResponse;
 import org.elasticsearch.client.ml.PutJobRequest;
 import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
 import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
 import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
 import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -2340,6 +2342,48 @@ public final class MachineLearningClient {
             Collections.emptySet());
     }
 
+    /**
+     * Put trained model config
+     * <p>
+     * For additional info
+     * see <a href="TODO">
+     *     PUT Trained Model Config documentation</a>
+     *
+     * @param request The {@link PutTrainedModelRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @return {@link PutTrainedModelResponse} response object
+     */
+    public PutTrainedModelResponse putTrainedModel(PutTrainedModelRequest request, RequestOptions options) throws IOException {
+        return restHighLevelClient.performRequestAndParseEntity(request,
+            MLRequestConverters::putTrainedModel,
+            options,
+            PutTrainedModelResponse::fromXContent,
+            Collections.emptySet());
+    }
+
+    /**
+     * Put trained model config asynchronously and notifies listener upon completion
+     * <p>
+     * For additional info
+     * see <a href="TODO">
+     *     PUT Trained Model Config documentation</a>
+     *
+     * @param request The {@link PutTrainedModelRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @param listener Listener to be notified upon request completion
+     * @return cancellable that may be used to cancel the request
+     */
+    public Cancellable putTrainedModelAsync(PutTrainedModelRequest request,
+                                            RequestOptions options,
+                                            ActionListener<PutTrainedModelResponse> listener) {
+        return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+            MLRequestConverters::putTrainedModel,
+            options,
+            PutTrainedModelResponse::fromXContent,
+            listener,
+            Collections.emptySet());
+    }
+
     /**
      * Gets trained model stats
      * <p>

+ 66 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelRequest.java

@@ -0,0 +1,66 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelRequest implements Validatable, ToXContentObject {
+
+    private final TrainedModelConfig config;
+
+    public PutTrainedModelRequest(TrainedModelConfig config) {
+        this.config = config;
+    }
+
+    public TrainedModelConfig getTrainedModelConfig() {
+        return config;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        return config.toXContent(builder, params);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        PutTrainedModelRequest request = (PutTrainedModelRequest) o;
+        return Objects.equals(config, request.config);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(config);
+    }
+
+    @Override
+    public final String toString() {
+        return Strings.toString(config);
+    }
+}

+ 63 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutTrainedModelResponse.java

@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelResponse implements ToXContentObject {
+
+    private final TrainedModelConfig trainedModelConfig;
+
+    public static PutTrainedModelResponse fromXContent(XContentParser parser) throws IOException {
+        return new PutTrainedModelResponse(TrainedModelConfig.PARSER.parse(parser, null).build());
+    }
+
+    public PutTrainedModelResponse(TrainedModelConfig trainedModelConfig) {
+        this.trainedModelConfig = trainedModelConfig;
+    }
+
+    public TrainedModelConfig getResponse() {
+        return trainedModelConfig;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return trainedModelConfig.toXContent(builder, params);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        PutTrainedModelResponse response = (PutTrainedModelResponse) o;
+        return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(trainedModelConfig);
+    }
+}

+ 81 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressor.java

@@ -0,0 +1,81 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.CheckedFunction;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.Streams;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.GZIPOutputStream;
+
+/**
+ * Collection of helper methods. Similar to CompressedXContent, but this utilizes GZIP.
+ */
+public final class InferenceToXContentCompressor {
+    private static final int BUFFER_SIZE = 4096;
+    private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
+
+    private InferenceToXContentCompressor() {}
+
+    public static <T extends ToXContentObject> String deflate(T objectToCompress) throws IOException {
+        BytesReference reference = XContentHelper.toXContent(objectToCompress, XContentType.JSON, false);
+        return deflate(reference);
+    }
+
+    public static <T> T inflate(String compressedString,
+                         CheckedFunction<XContentParser, T, IOException> parserFunction,
+                         NamedXContentRegistry xContentRegistry) throws IOException {
+        try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
+            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+            inflate(compressedString, MAX_INFLATED_BYTES),
+            XContentType.JSON)) {
+            return parserFunction.apply(parser);
+        }
+    }
+
+    static BytesReference inflate(String compressedString, long streamSize) throws IOException {
+        byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
+        InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
+        InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
+        return Streams.readFully(inflateStream);
+    }
+
+    private static String deflate(BytesReference reference) throws IOException {
+        BytesStreamOutput out = new BytesStreamOutput();
+        try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
+            reference.writeTo(compressedOutput);
+        }
+        return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
+    }
+}

+ 68 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/SimpleBoundedInputStream.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Objects;
+
+/**
+ * This is a pared down bounded input stream.
+ * Only read is specifically enforced.
+ */
+final class SimpleBoundedInputStream extends InputStream {
+
+    private final InputStream in;
+    private final long maxBytes;
+    private long numBytes;
+
+    SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
+        this.in = Objects.requireNonNull(inputStream, "inputStream");
+        if (maxBytes < 0) {
+            throw new IllegalArgumentException("[maxBytes] must be greater than or equal to 0");
+        }
+        this.maxBytes = maxBytes;
+    }
+
+
+    /**
+     * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
+     * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
+     * @throws IOException on failure
+     */
+    @Override
+    public int read() throws IOException {
+        // We have reached the maximum, signal stream completion.
+        if (numBytes >= maxBytes) {
+            return -1;
+        }
+        numBytes++;
+        return in.read();
+    }
+
+    /**
+     * Delegates `close` to the wrapped InputStream
+     * @throws IOException on failure
+     */
+    @Override
+    public void close() throws IOException {
+        in.close();
+    }
+}

+ 12 - 7
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

@@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.time.Instant;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -111,7 +112,7 @@ public class TrainedModelConfig implements ToXContentObject {
         this.modelId = modelId;
         this.createdBy = createdBy;
         this.version = version;
-        this.createTime = Instant.ofEpochMilli(createTime.toEpochMilli());
+        this.createTime = createTime == null ? null : Instant.ofEpochMilli(createTime.toEpochMilli());
         this.definition = definition;
         this.compressedDefinition = compressedDefinition;
         this.description = description;
@@ -293,12 +294,12 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
-        public Builder setCreatedBy(String createdBy) {
+        private Builder setCreatedBy(String createdBy) {
             this.createdBy = createdBy;
             return this;
         }
 
-        public Builder setVersion(Version version) {
+        private Builder setVersion(Version version) {
             this.version = version;
             return this;
         }
@@ -312,7 +313,7 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
-        public Builder setCreateTime(Instant createTime) {
+        private Builder setCreateTime(Instant createTime) {
             this.createTime = createTime;
             return this;
         }
@@ -322,6 +323,10 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
+        public Builder setTags(String... tags) {
+            return setTags(Arrays.asList(tags));
+        }
+
         public Builder setMetadata(Map<String, Object> metadata) {
             this.metadata = metadata;
             return this;
@@ -347,17 +352,17 @@ public class TrainedModelConfig implements ToXContentObject {
             return this;
         }
 
-        public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
+        private Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
             this.estimatedHeapMemory = estimatedHeapMemory;
             return this;
         }
 
-        public Builder setEstimatedOperations(Long estimatedOperations) {
+        private Builder setEstimatedOperations(Long estimatedOperations) {
             this.estimatedOperations = estimatedOperations;
             return this;
         }
 
-        public Builder setLicenseLevel(String licenseLevel) {
+        private Builder setLicenseLevel(String licenseLevel) {
             this.licenseLevel = licenseLevel;
             return this;
         }

+ 5 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java

@@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
@@ -48,6 +49,10 @@ public class TrainedModelInput implements ToXContentObject {
         this.fieldNames = fieldNames;
     }
 
+    public TrainedModelInput(String... fieldNames) {
+        this(Arrays.asList(fieldNames));
+    }
+
     public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
         return PARSER.parse(parser, null);
     }

+ 19 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

@@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.PutDatafeedRequest;
 import org.elasticsearch.client.ml.PutFilterRequest;
 import org.elasticsearch.client.ml.PutJobRequest;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
 import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
 import org.elasticsearch.client.ml.SetUpgradeModeRequest;
 import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest;
@@ -91,6 +92,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider;
 import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.Detector;
 import org.elasticsearch.client.ml.job.config.Job;
@@ -874,6 +878,20 @@ public class MLRequestConvertersTests extends ESTestCase {
         assertNull(request.getEntity());
     }
 
+    public void testPutTrainedModel() throws IOException {
+        TrainedModelConfig trainedModelConfig = TrainedModelConfigTests.createTestTrainedModelConfig();
+        PutTrainedModelRequest putTrainedModelRequest = new PutTrainedModelRequest(trainedModelConfig);
+
+        Request request = MLRequestConverters.putTrainedModel(putTrainedModelRequest);
+
+        assertEquals(HttpPut.METHOD_NAME, request.getMethod());
+        assertThat(request.getEndpoint(), equalTo("/_ml/inference/" + trainedModelConfig.getModelId()));
+        try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) {
+            TrainedModelConfig parsedTrainedModelConfig = TrainedModelConfig.PARSER.apply(parser, null).build();
+            assertThat(parsedTrainedModelConfig, equalTo(trainedModelConfig));
+        }
+    }
+
     public void testPutFilter() throws IOException {
         MlFilter filter = MlFilterTests.createRandomBuilder("foo").build();
         PutFilterRequest putFilterRequest = new PutFilterRequest(filter);
@@ -1046,6 +1064,7 @@ public class MLRequestConvertersTests extends ESTestCase {
         namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
         namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers());
         namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers());
+        namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
         return new NamedXContentRegistry(namedXContent);
     }
 

+ 62 - 56
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -101,6 +101,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
 import org.elasticsearch.client.ml.PutFilterResponse;
 import org.elasticsearch.client.ml.PutJobRequest;
 import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
 import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
 import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
 import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -146,9 +148,12 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
 import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
 import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
 import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
 import org.elasticsearch.client.ml.inference.TrainedModelStats;
 import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
@@ -162,14 +167,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter;
 import org.elasticsearch.client.ml.job.process.ModelSnapshot;
 import org.elasticsearch.client.ml.job.stats.JobStats;
 import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -178,11 +181,9 @@ import org.elasticsearch.search.SearchHit;
 import org.junit.After;
 
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -190,7 +191,6 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
 
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anyOf;
@@ -2192,6 +2192,50 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         }
     }
 
+    public void testPutTrainedModel() throws Exception {
+        String modelId = "test-put-trained-model";
+        String modelIdCompressed = "test-put-trained-model-compressed-definition";
+
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+        TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition)
+            .setModelId(modelId)
+            .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+            .setDescription("test model")
+            .build();
+        PutTrainedModelResponse putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+            machineLearningClient::putTrainedModel,
+            machineLearningClient::putTrainedModelAsync);
+        TrainedModelConfig createdModel = putTrainedModelResponse.getResponse();
+        assertThat(createdModel.getModelId(), equalTo(modelId));
+
+        definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+        trainedModelConfig = TrainedModelConfig.builder()
+            .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition))
+            .setModelId(modelIdCompressed)
+            .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+            .setDescription("test model")
+            .build();
+        putTrainedModelResponse = execute(new PutTrainedModelRequest(trainedModelConfig),
+            machineLearningClient::putTrainedModel,
+            machineLearningClient::putTrainedModelAsync);
+        createdModel = putTrainedModelResponse.getResponse();
+        assertThat(createdModel.getModelId(), equalTo(modelIdCompressed));
+
+        GetTrainedModelsResponse getTrainedModelsResponse = execute(
+            new GetTrainedModelsRequest(modelIdCompressed).setDecompressDefinition(true).setIncludeDefinition(true),
+            machineLearningClient::getTrainedModels,
+            machineLearningClient::getTrainedModelsAsync);
+
+        assertThat(getTrainedModelsResponse.getCount(), equalTo(1L));
+        assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(1));
+        assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getCompressedDefinition(), is(nullValue()));
+        assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getDefinition(), is(not(nullValue())));
+        assertThat(getTrainedModelsResponse.getTrainedModels().get(0).getModelId(), equalTo(modelIdCompressed));
+    }
+
     public void testGetTrainedModelsStats() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
         String modelIdPrefix = "a-get-trained-model-stats-";
@@ -2474,56 +2518,13 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     private void putTrainedModel(String modelId) throws IOException {
         TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
-        highLevelClient().index(
-            new IndexRequest(".ml-inference-000001")
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .source(modelConfigString(modelId), XContentType.JSON)
-                .id(modelId),
-            RequestOptions.DEFAULT);
-
-        highLevelClient().index(
-            new IndexRequest(".ml-inference-000001")
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
-                .id("trained_model_definition_doc-" + modelId + "-0"),
-            RequestOptions.DEFAULT);
-    }
-
-    private String compressDefinition(TrainedModelDefinition definition) throws IOException {
-        BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
-        BytesStreamOutput out = new BytesStreamOutput();
-        try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
-            reference.writeTo(compressedOutput);
-        }
-        return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
-    }
-
-    private static String modelConfigString(String modelId) {
-        return "{\n" +
-            "  \"doc_type\": \"trained_model_config\",\n" +
-            "  \"model_id\": \"" + modelId + "\",\n" +
-            "  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
-            "  \"description\": \"test model\",\n" +
-            "  \"version\": \"7.6.0\",\n" +
-            "  \"license_level\": \"platinum\",\n" +
-            "  \"created_by\": \"ml_test\",\n" +
-            "  \"estimated_heap_memory_usage_bytes\": 0," +
-            "  \"estimated_operations\": 0," +
-            "  \"created_time\": 0\n" +
-            "}";
-    }
-
-    private static String modelDocString(String compressedDefinition, String modelId) {
-        return "" +
-            "{" +
-            "\"model_id\": \"" + modelId + "\",\n" +
-            "\"doc_num\": 0,\n" +
-            "\"doc_type\": \"trained_model_definition_doc\",\n" +
-            "  \"compression_version\": " + 1 + ",\n" +
-            "  \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
-            "  \"definition_length\": " + compressedDefinition.length() + ",\n" +
-            "\"definition\": \"" + compressedDefinition + "\"\n" +
-            "}";
+        TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition)
+            .setModelId(modelId)
+            .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+            .setDescription("test model")
+            .build();
+        highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
     }
 
     private void waitForJobToClose(String jobId) throws Exception {
@@ -2768,4 +2769,9 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         mlInfoResponse = machineLearningClient.getMlInfo(new MlInfoRequest(), RequestOptions.DEFAULT);
         assertThat(mlInfoResponse.getInfo().get("upgrade_mode"), equalTo(false));
     }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
 }

+ 91 - 54
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -114,6 +114,8 @@ import org.elasticsearch.client.ml.PutFilterRequest;
 import org.elasticsearch.client.ml.PutFilterResponse;
 import org.elasticsearch.client.ml.PutJobRequest;
 import org.elasticsearch.client.ml.PutJobResponse;
+import org.elasticsearch.client.ml.PutTrainedModelRequest;
+import org.elasticsearch.client.ml.PutTrainedModelResponse;
 import org.elasticsearch.client.ml.RevertModelSnapshotRequest;
 import org.elasticsearch.client.ml.RevertModelSnapshotResponse;
 import org.elasticsearch.client.ml.SetUpgradeModeRequest;
@@ -162,10 +164,14 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
 import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
 import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
 import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
+import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.client.ml.inference.TrainedModelConfig;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
 import org.elasticsearch.client.ml.inference.TrainedModelStats;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.client.ml.job.config.AnalysisConfig;
 import org.elasticsearch.client.ml.job.config.AnalysisLimits;
 import org.elasticsearch.client.ml.job.config.DataDescription;
@@ -186,12 +192,11 @@ import org.elasticsearch.client.ml.job.results.Influencer;
 import org.elasticsearch.client.ml.job.results.OverallBucket;
 import org.elasticsearch.client.ml.job.stats.JobStats;
 import org.elasticsearch.common.bytes.BytesReference;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentFactory;
-import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -202,12 +207,10 @@ import org.elasticsearch.tasks.TaskId;
 import org.junit.After;
 
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Arrays;
-import java.util.Base64;
 import java.util.Collections;
 import java.util.Date;
 import java.util.HashMap;
@@ -216,7 +219,6 @@ import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
-import java.util.zip.GZIPOutputStream;
 
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.closeTo;
@@ -3625,6 +3627,79 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
         }
     }
 
+    public void testPutTrainedModel() throws Exception {
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+        // tag::put-trained-model-config
+        TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition) // <1>
+            .setCompressedDefinition(InferenceToXContentCompressor.deflate(definition)) // <2>
+            .setModelId("my-new-trained-model") // <3>
+            .setInput(new TrainedModelInput("col1", "col2", "col3", "col4")) // <4>
+            .setDescription("test model") // <5>
+            .setMetadata(new HashMap<>()) // <6>
+            .setTags("my_regression_models") // <7>
+            .build();
+        // end::put-trained-model-config
+
+        trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition)
+            .setModelId("my-new-trained-model")
+            .setInput(new TrainedModelInput("col1", "col2", "col3", "col4"))
+            .setDescription("test model")
+            .setMetadata(new HashMap<>())
+            .setTags("my_regression_models")
+            .build();
+
+        RestHighLevelClient client = highLevelClient();
+        {
+            // tag::put-trained-model-request
+            PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); // <1>
+            // end::put-trained-model-request
+
+            // tag::put-trained-model-execute
+            PutTrainedModelResponse response = client.machineLearning().putTrainedModel(request, RequestOptions.DEFAULT);
+            // end::put-trained-model-execute
+
+            // tag::put-trained-model-response
+            TrainedModelConfig model = response.getResponse();
+            // end::put-trained-model-response
+
+            assertThat(model.getModelId(), equalTo(trainedModelConfig.getModelId()));
+            highLevelClient().machineLearning()
+                .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+        }
+        {
+            PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig);
+            
+            // tag::put-trained-model-execute-listener
+            ActionListener<PutTrainedModelResponse> listener = new ActionListener<>() {
+                @Override
+                public void onResponse(PutTrainedModelResponse response) {
+                    // <1>
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    // <2>
+                }
+            };
+            // end::put-trained-model-execute-listener
+
+            // Replace the empty listener by a blocking listener in test
+            CountDownLatch latch = new CountDownLatch(1);
+            listener = new LatchedActionListener<>(listener, latch);
+
+            // tag::put-trained-model-execute-async
+            client.machineLearning().putTrainedModelAsync(request, RequestOptions.DEFAULT, listener); // <1>
+            // end::put-trained-model-execute-async
+
+            assertTrue(latch.await(30L, TimeUnit.SECONDS));
+
+            highLevelClient().machineLearning()
+                .deleteTrainedModel(new DeleteTrainedModelRequest("my-new-trained-model"), RequestOptions.DEFAULT);
+        }
+    }
+
     public void testGetTrainedModelsStats() throws Exception {
         putTrainedModel("my-trained-model");
         RestHighLevelClient client = highLevelClient();
@@ -4088,57 +4163,19 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
     }
 
     private void putTrainedModel(String modelId) throws IOException {
-        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
-        highLevelClient().index(
-            new IndexRequest(".ml-inference-000001")
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .source(modelConfigString(modelId), XContentType.JSON)
-                .id(modelId),
-            RequestOptions.DEFAULT);
-
-        highLevelClient().index(
-            new IndexRequest(".ml-inference-000001")
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .source(modelDocString(compressDefinition(definition), modelId), XContentType.JSON)
-                .id("trained_model_definition_doc-" + modelId + "-0"),
-            RequestOptions.DEFAULT);
-    }
-
-    private String compressDefinition(TrainedModelDefinition definition) throws IOException {
-        BytesReference reference = XContentHelper.toXContent(definition, XContentType.JSON, false);
-        BytesStreamOutput out = new BytesStreamOutput();
-        try (OutputStream compressedOutput = new GZIPOutputStream(out, 4096)) {
-            reference.writeTo(compressedOutput);
-        }
-        return new String(Base64.getEncoder().encode(BytesReference.toBytes(out.bytes())), StandardCharsets.UTF_8);
-    }
-
-    private static String modelConfigString(String modelId) {
-        return "{\n" +
-            "  \"doc_type\": \"trained_model_config\",\n" +
-            "  \"model_id\": \"" + modelId + "\",\n" +
-            "  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
-            "  \"description\": \"test model for\",\n" +
-            "  \"version\": \"7.6.0\",\n" +
-            "  \"license_level\": \"platinum\",\n" +
-            "  \"created_by\": \"ml_test\",\n" +
-            "  \"estimated_heap_memory_usage_bytes\": 0," +
-            "  \"estimated_operations\": 0," +
-            "  \"created_time\": 0\n" +
-            "}";
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
+        TrainedModelConfig trainedModelConfig = TrainedModelConfig.builder()
+            .setDefinition(definition)
+            .setModelId(modelId)
+            .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3", "col4")))
+            .setDescription("test model")
+            .build();
+        highLevelClient().machineLearning().putTrainedModel(new PutTrainedModelRequest(trainedModelConfig), RequestOptions.DEFAULT);
     }
 
-    private static String modelDocString(String compressedDefinition, String modelId) {
-        return "" +
-            "{" +
-            "\"model_id\": \"" + modelId + "\",\n" +
-            "\"doc_num\": 0,\n" +
-            "\"doc_type\": \"trained_model_definition_doc\",\n" +
-            "  \"compression_version\": " + 1 + ",\n" +
-            "  \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
-            "  \"definition_length\": " + compressedDefinition.length() + ",\n" +
-            "\"definition\": \"" + compressedDefinition + "\"\n" +
-            "}";
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
     }
 
     private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG =

+ 52 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionRequestTests.java

@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionRequestTests extends AbstractXContentTestCase<PutTrainedModelRequest> {
+
+    @Override
+    protected PutTrainedModelRequest createTestInstance() {
+        return new PutTrainedModelRequest(TrainedModelConfigTests.createTestTrainedModelConfig());
+    }
+
+    @Override
+    protected PutTrainedModelRequest doParseInstance(XContentParser parser) throws IOException {
+        return new PutTrainedModelRequest(TrainedModelConfig.PARSER.apply(parser, null).build());
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+}

+ 52 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutTrainedModelActionResponseTests.java

@@ -0,0 +1,52 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+import java.io.IOException;
+
+public class PutTrainedModelActionResponseTests extends AbstractXContentTestCase<PutTrainedModelResponse> {
+
+    @Override
+    protected PutTrainedModelResponse createTestInstance() {
+        return new PutTrainedModelResponse(TrainedModelConfigTests.createTestTrainedModelConfig());
+    }
+
+    @Override
+    protected PutTrainedModelResponse doParseInstance(XContentParser parser) throws IOException {
+        return new PutTrainedModelResponse(TrainedModelConfig.PARSER.apply(parser, null).build());
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+}

+ 70 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/InferenceToXContentCompressorTests.java

@@ -0,0 +1,70 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.inference;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class InferenceToXContentCompressorTests extends ESTestCase {
+
+    public void testInflateAndDeflate() throws IOException {
+        for(int i = 0; i < 10; i++) {
+            TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+            String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+            TrainedModelDefinition inflatedDefinition = InferenceToXContentCompressor.inflate(firstDeflate,
+                parser -> TrainedModelDefinition.fromXContent(parser).build(),
+                xContentRegistry());
+
+            // Did we inflate to the same object?
+            assertThat(inflatedDefinition, equalTo(definition));
+        }
+    }
+
+    public void testInflateTooLargeStream() throws IOException {
+        TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
+        String firstDeflate = InferenceToXContentCompressor.deflate(definition);
+        BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
+        assertThat(inflatedBytes.length(), equalTo(10));
+        try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+            LoggingDeprecationHandler.INSTANCE,
+            inflatedBytes,
+            XContentType.JSON)) {
+            expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser));
+        }
+    }
+
+    public void testInflateGarbage() {
+        expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+}

+ 19 - 16
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

@@ -37,6 +37,24 @@ import java.util.stream.Stream;
 
 public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedModelConfig> {
 
+    public static TrainedModelConfig createTestTrainedModelConfig() {
+        return new TrainedModelConfig(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            Version.CURRENT,
+            randomBoolean() ? null : randomAlphaOfLength(100),
+            Instant.ofEpochMilli(randomNonNegativeLong()),
+            randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
+            randomBoolean() ? null : randomAlphaOfLength(100),
+            randomBoolean() ? null :
+                Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
+            randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
+            randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
+            randomBoolean() ? null : randomNonNegativeLong(),
+            randomBoolean() ? null : randomNonNegativeLong(),
+            randomBoolean() ? null : randomFrom("platinum", "basic"));
+    }
+
     @Override
     protected TrainedModelConfig doParseInstance(XContentParser parser) throws IOException {
         return TrainedModelConfig.fromXContent(parser);
@@ -54,22 +72,7 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
 
     @Override
     protected TrainedModelConfig createTestInstance() {
-        return new TrainedModelConfig(
-            randomAlphaOfLength(10),
-            randomAlphaOfLength(10),
-            Version.CURRENT,
-            randomBoolean() ? null : randomAlphaOfLength(100),
-            Instant.ofEpochMilli(randomNonNegativeLong()),
-            randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
-            randomBoolean() ? null : randomAlphaOfLength(100),
-            randomBoolean() ? null :
-                Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
-            randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
-            randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
-            randomBoolean() ? null : randomNonNegativeLong(),
-            randomBoolean() ? null : randomNonNegativeLong(),
-            randomBoolean() ? null : randomFrom("platinum", "basic"));
-
+        return createTestTrainedModelConfig();
     }
 
     @Override

+ 8 - 6
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -67,15 +67,17 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
             .collect(Collectors.toList());
         int numberOfModels = randomIntBetween(1, 10);
         List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
-            .limit(numberOfFeatures)
+            .limit(numberOfModels)
             .collect(Collectors.toList());
-        OutputAggregator outputAggregator = null;
-        if (randomBoolean()) {
-            List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
-            outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights));
+        List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
+        List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
+            new LogisticRegression(weights)));
+        if (targetType.equals(TargetType.REGRESSION)) {
+            possibleAggregators.add(new WeightedSum(weights));
         }
+        OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
         List<String> categoryLabels = null;
-        if (randomBoolean()) {
+        if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
         return new Ensemble(featureNames,

+ 1 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java

@@ -84,7 +84,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
             childNodes = nextNodes;
         }
         List<String> categoryLabels = null;
-        if (randomBoolean()) {
+        if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
             categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
         }
         return builder.setClassificationLabels(categoryLabels)

+ 53 - 0
docs/java-rest/high-level/ml/put-trained-model.asciidoc

@@ -0,0 +1,53 @@
+--
+:api: put-trained-model
+:request: PutTrainedModelRequest
+:response: PutTrainedModelResponse
+--
+[role="xpack"]
+[id="{upid}-{api}"]
+=== Put Trained Model API
+
+Creates a new trained model for inference.
+The API accepts a +{request}+ object as a request and returns a +{response}+.
+
+[id="{upid}-{api}-request"]
+==== Put Trained Model request
+
+A +{request}+ requires the following argument:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-request]
+--------------------------------------------------
+<1> The configuration of the {infer} Trained Model to create
+
+[id="{upid}-{api}-config"]
+==== Trained Model configuration
+
+The `TrainedModelConfig` object contains all the details about the trained model
+configuration and contains the following arguments:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-config]
+--------------------------------------------------
+<1> The {infer} definition for the model
+<2> Optionally, if the {infer} definition is large, you may choose to compress it for transport.
+    Do not supply both the compressed and uncompressed definitions.
+<3> The unique model id
+<4> The input field names for the model definition
+<5> Optionally, a human-readable description
+<6> Optionally, an object map contain metadata about the model
+<7> Optionally, an array of tags to organize the model
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Response
+
+The returned +{response}+ contains the newly created trained model.
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------

+ 2 - 0
docs/java-rest/high-level/supported-apis.asciidoc

@@ -304,6 +304,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
 * <<{upid}-evaluate-data-frame>>
 * <<{upid}-explain-data-frame-analytics>>
 * <<{upid}-get-trained-models>>
+* <<{upid}-put-trained-model>>
 * <<{upid}-get-trained-models-stats>>
 * <<{upid}-delete-trained-model>>
 * <<{upid}-put-filter>>
@@ -359,6 +360,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
 include::ml/evaluate-data-frame.asciidoc[]
 include::ml/explain-data-frame-analytics.asciidoc[]
 include::ml/get-trained-models.asciidoc[]
+include::ml/put-trained-model.asciidoc[]
 include::ml/get-trained-models-stats.asciidoc[]
 include::ml/delete-trained-model.asciidoc[]
 include::ml/put-filter.asciidoc[]

+ 137 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+
+import java.io.IOException;
+import java.util.Objects;
+
+
+public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {
+
+    public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
+    public static final String NAME = "cluster:monitor/xpack/ml/inference/put";
+    private PutTrainedModelAction() {
+        super(NAME, Response::new);
+    }
+
+    public static class Request extends AcknowledgedRequest<Request> {
+
+        public static Request parseRequest(String modelId, XContentParser parser) {
+            TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
+
+            if (builder.getModelId() == null) {
+                builder.setModelId(modelId).build();
+            } else if (!Strings.isNullOrEmpty(modelId) && !modelId.equals(builder.getModelId())) {
+                // If we have model_id in both URI and body, they must be identical
+                throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID,
+                    TrainedModelConfig.MODEL_ID.getPreferredName(),
+                    builder.getModelId(),
+                    modelId));
+            }
+            // Validations are done against the builder so we can build the full config object.
+            // This allows us to not worry about serializing a builder class between nodes.
+            return new Request(builder.validate(true).build());
+        }
+
+        private final TrainedModelConfig config;
+
+        public Request(TrainedModelConfig config) {
+            this.config = config;
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.config = new TrainedModelConfig(in);
+        }
+
+        public TrainedModelConfig getTrainedModelConfig() {
+            return config;
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            config.writeTo(out);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(config, request.config);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(config);
+        }
+
+        @Override
+        public final String toString() {
+            return Strings.toString(config);
+        }
+    }
+
+    public static class Response extends ActionResponse implements ToXContentObject {
+
+        private final TrainedModelConfig trainedModelConfig;
+
+        public Response(TrainedModelConfig trainedModelConfig) {
+            this.trainedModelConfig = trainedModelConfig;
+        }
+
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            trainedModelConfig = new TrainedModelConfig(in);
+        }
+
+        public TrainedModelConfig getResponse() {
+            return trainedModelConfig;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            trainedModelConfig.writeTo(out);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return trainedModelConfig.toXContent(builder, params);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return Objects.equals(trainedModelConfig, response.trainedModelConfig);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(trainedModelConfig);
+        }
+    }
+}

+ 99 - 25
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.inference;
 
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
@@ -34,6 +35,9 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.action.ValidateActions.addValidationError;
 
 
 public class TrainedModelConfig implements ToXContentObject, Writeable {
@@ -352,13 +356,31 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         private Long estimatedHeapMemory;
         private Long estimatedOperations;
         private LazyModelDefinition definition;
-        private String licenseLevel = License.OperationMode.PLATINUM.description();
+        private String licenseLevel;
+
+        public Builder() {}
+
+        public Builder(TrainedModelConfig config) {
+            this.modelId = config.getModelId();
+            this.createdBy = config.getCreatedBy();
+            this.version = config.getVersion();
+            this.createTime = config.getCreateTime();
+            this.definition = config.definition == null ? null : new LazyModelDefinition(config.definition);
+            this.description = config.getDescription();
+            this.tags = config.getTags();
+            this.metadata = config.getMetadata();
+            this.input = config.getInput();
+        }
 
         public Builder setModelId(String modelId) {
             this.modelId = modelId;
             return this;
         }
 
+        public String getModelId() {
+            return this.modelId;
+        }
+
         public Builder setCreatedBy(String createdBy) {
             this.createdBy = createdBy;
             return this;
@@ -466,51 +488,96 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
-        // TODO move to REST level instead of here in the builder
-        public void validate() {
-            // We require a definition to be available here even though it will be stored in a different doc
-            ExceptionsHelper.requireNonNull(definition, DEFINITION);
-            ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
+        public Builder validate() {
+            return validate(false);
+        }
 
-            if (MlStrings.isValidId(modelId) == false) {
-                throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, MODEL_ID.getPreferredName(), modelId));
+        /**
+         * Runs validations against the builder.
+         * @return The current builder object if validations are successful
+         * @throws ActionRequestValidationException when there are validation failures.
+         */
+        public Builder validate(boolean forCreation) {
+            // We require a definition to be available here even though it will be stored in a different doc
+            ActionRequestValidationException validationException = null;
+            if (definition == null) {
+                validationException = addValidationError("[" + DEFINITION.getPreferredName() + "] must not be null.", validationException);
+            }
+            if (modelId == null) {
+                validationException = addValidationError("[" + MODEL_ID.getPreferredName() + "] must not be null.", validationException);
             }
 
-            if (MlStrings.hasValidLengthForId(modelId) == false) {
-                throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.ID_TOO_LONG,
-                    MODEL_ID.getPreferredName(),
+            if (modelId != null && MlStrings.isValidId(modelId) == false) {
+                validationException = addValidationError(Messages.getMessage(Messages.INVALID_ID,
+                    TrainedModelConfig.MODEL_ID.getPreferredName(),
+                    modelId),
+                    validationException);
+            }
+            if (modelId != null && MlStrings.hasValidLengthForId(modelId) == false) {
+                validationException = addValidationError(Messages.getMessage(Messages.ID_TOO_LONG,
+                    TrainedModelConfig.MODEL_ID.getPreferredName(),
                     modelId,
-                    MlStrings.ID_LENGTH_LIMIT));
+                    MlStrings.ID_LENGTH_LIMIT), validationException);
+            }
+            List<String> badTags = tags.stream()
+                .filter(tag -> (MlStrings.isValidId(tag) && MlStrings.hasValidLengthForId(tag)) == false)
+                .collect(Collectors.toList());
+            if (badTags.isEmpty() == false) {
+                validationException = addValidationError(Messages.getMessage(Messages.INFERENCE_INVALID_TAGS,
+                    badTags,
+                    MlStrings.ID_LENGTH_LIMIT),
+                    validationException);
+            }
+
+            for(String tag : tags) {
+                if (tag.equals(modelId)) {
+                    validationException = addValidationError("none of the tags must equal the model_id", validationException);
+                    break;
+                }
+            }
+            if (forCreation) {
+                validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
+                validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
+                validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
+                validationException = checkIllegalSetting(estimatedHeapMemory,
+                    ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
+                    validationException);
+                validationException = checkIllegalSetting(estimatedOperations,
+                    ESTIMATED_OPERATIONS.getPreferredName(),
+                    validationException);
+                validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
             }
 
-            checkIllegalSetting(version, VERSION.getPreferredName());
-            checkIllegalSetting(createdBy, CREATED_BY.getPreferredName());
-            checkIllegalSetting(createTime, CREATE_TIME.getPreferredName());
-            checkIllegalSetting(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
-            checkIllegalSetting(estimatedOperations, ESTIMATED_OPERATIONS.getPreferredName());
-            checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName());
+            if (validationException != null) {
+                throw validationException;
+            }
+
+            return this;
         }
 
-        private static void checkIllegalSetting(Object value, String setting) {
+        private static ActionRequestValidationException checkIllegalSetting(Object value,
+                                                                            String setting,
+                                                                            ActionRequestValidationException validationException) {
             if (value != null) {
-                throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", setting);
+                return addValidationError("illegal to set [" + setting + "] at inference model creation", validationException);
             }
+            return validationException;
         }
 
         public TrainedModelConfig build() {
             return new TrainedModelConfig(
                 modelId,
-                createdBy,
-                version,
+                createdBy == null ? "user" : createdBy,
+                version == null ? Version.CURRENT : version,
                 description,
                 createTime == null ? Instant.now() : createTime,
                 definition,
                 tags,
                 metadata,
                 input,
-                estimatedHeapMemory,
-                estimatedOperations,
-                licenseLevel);
+                estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
+                estimatedOperations == null ? 0 : estimatedOperations,
+                licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
         }
     }
 
@@ -531,6 +598,13 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return new LazyModelDefinition(input.readString(), null);
         }
 
+        private LazyModelDefinition(LazyModelDefinition definition) {
+            if (definition != null) {
+                this.compressedString = definition.compressedString;
+                this.parsedDefinition = definition.parsedDefinition;
+            }
+        }
+
         private LazyModelDefinition(String compressedString, TrainedModelDefinition trainedModelDefinition) {
             if (compressedString == null && trainedModelDefinition == null) {
                 throw new IllegalArgumentException("unexpected null model definition");

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

@@ -179,6 +179,12 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable, Acco
             this(true);
         }
 
+        public Builder(TrainedModelDefinition definition) {
+            this(true);
+            this.preProcessors = new ArrayList<>(definition.getPreProcessors());
+            this.trainedModel = definition.trainedModel;
+        }
+
         public Builder setPreProcessors(List<PreProcessor> preProcessors) {
             this.preProcessors = preProcessors;
             return this;

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -95,6 +95,10 @@ public final class Messages {
     public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
         "Getting model definition is not supported when getting more than one model";
     public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
+    public static final String INFERENCE_INVALID_TAGS = "Invalid tags {0}; must only can contain lowercase alphanumeric (a-z and 0-9), " +
+        "hyphens or underscores, must start and end with alphanumeric, and must be less than {1} characters.";
+    public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids.";
+    public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags.";
 
     public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
     public static final String JOB_AUDIT_CREATED = "Job created";

+ 45 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
+
+    @Override
+    protected Request createTestInstance() {
+        String modelId = randomAlphaOfLength(10);
+        return new Request(TrainedModelConfigTests.createTestInstance(modelId)
+            .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+            .build());
+    }
+
+    @Override
+    protected Writeable.Reader<Request> instanceReader() {
+        return (in) -> {
+            Request request = new Request(in);
+            request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry());
+            return request;
+        };
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
+}

+ 45 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionResponseTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
+
+public class PutTrainedModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
+
+    @Override
+    protected Response createTestInstance() {
+        String modelId = randomAlphaOfLength(10);
+        return new Response(TrainedModelConfigTests.createTestInstance(modelId)
+            .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
+            .build());
+    }
+
+    @Override
+    protected Writeable.Reader<Response> instanceReader() {
+        return (in) -> {
+            Response response = new Response(in);
+            response.getResponse().ensureParsedDefinition(xContentRegistry());
+            return response;
+        };
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+
+    @Override
+    protected NamedWriteableRegistry getNamedWriteableRegistry() {
+        return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables());
+    }
+}

+ 22 - 18
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

@@ -5,8 +5,8 @@
  */
 package org.elasticsearch.xpack.core.ml.inference;
 
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -56,14 +56,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
         return TrainedModelConfig.builder()
             .setInput(TrainedModelInputTests.createRandomInput())
             .setMetadata(randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)))
-            .setCreateTime(Instant.ofEpochMilli(randomNonNegativeLong()))
+            .setCreateTime(Instant.ofEpochMilli(randomLongBetween(Instant.MIN.getEpochSecond(), Instant.MAX.getEpochSecond())))
             .setVersion(Version.CURRENT)
             .setModelId(modelId)
             .setCreatedBy(randomAlphaOfLength(10))
             .setDescription(randomBoolean() ? null : randomAlphaOfLength(100))
             .setEstimatedHeapMemory(randomNonNegativeLong())
             .setEstimatedOperations(randomNonNegativeLong())
-            .setLicenseLevel(License.OperationMode.PLATINUM.description())
+            .setLicenseLevel(randomFrom(License.OperationMode.PLATINUM.description(),
+                License.OperationMode.GOLD.description(),
+                License.OperationMode.BASIC.description()))
             .setTags(tags);
     }
 
@@ -191,50 +193,52 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
     }
 
     public void testValidateWithNullDefinition() {
-        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate());
-        assertThat(ex.getMessage(), equalTo("[definition] must not be null."));
+        ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
+            () -> TrainedModelConfig.builder().validate());
+        assertThat(ex.getMessage(), containsString("[definition] must not be null."));
     }
 
     public void testValidateWithInvalidID() {
         String modelId = "InvalidID-";
-        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+        ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
             () -> TrainedModelConfig.builder()
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
+        assertThat(ex.getMessage(), containsString(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId)));
     }
 
     public void testValidateWithLongID() {
         String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining());
-        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+        ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
             () -> TrainedModelConfig.builder()
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
+        assertThat(ex.getMessage(),
+            containsString(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT)));
     }
 
     public void testValidateWithIllegallyUserProvidedFields() {
         String modelId = "simplemodel";
-        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+        ActionRequestValidationException ex = expectThrows(ActionRequestValidationException.class,
             () -> TrainedModelConfig.builder()
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setCreateTime(Instant.now())
-                .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation"));
+                .setModelId(modelId).validate(true));
+        assertThat(ex.getMessage(), containsString("illegal to set [create_time] at inference model creation"));
 
-        ex = expectThrows(ElasticsearchException.class,
+        ex = expectThrows(ActionRequestValidationException.class,
             () -> TrainedModelConfig.builder()
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setVersion(Version.CURRENT)
-                .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation"));
+                .setModelId(modelId).validate(true));
+        assertThat(ex.getMessage(), containsString("illegal to set [version] at inference model creation"));
 
-        ex = expectThrows(ElasticsearchException.class,
+        ex = expectThrows(ActionRequestValidationException.class,
             () -> TrainedModelConfig.builder()
                 .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
                 .setCreatedBy("ml_user")
-                .setModelId(modelId).validate());
-        assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation"));
+                .setModelId(modelId).validate(true));
+        assertThat(ex.getMessage(), containsString("illegal to set [created_by] at inference model creation"));
     }
 
     public void testSerializationWithLazyDefinition() throws IOException {

+ 0 - 1
x-pack/plugin/ml/qa/ml-with-security/build.gradle

@@ -133,7 +133,6 @@ integTest.runner {
     'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
     'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
     'ml/inference_crud/Test delete given used trained model',
-    'ml/inference_crud/Test delete given unused trained model',
     'ml/inference_crud/Test delete with missing model',
     'ml/inference_crud/Test get given missing trained model',
     'ml/inference_crud/Test get given expression without matches and allow_no_match is false',

+ 37 - 44
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -9,15 +9,19 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
 import org.elasticsearch.action.ingest.SimulatePipelineResponse;
 import org.elasticsearch.action.search.SearchRequest;
-import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -34,26 +38,14 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
 
     @Before
     public void createBothModels() throws Exception {
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId("test_classification")
-            .setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId(TrainedModelDefinitionDoc.docId("test_classification", 0))
-            .setSource(buildClassificationModelDoc(), XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId("test_regression")
-            .setSource(REGRESSION_CONFIG, XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId(TrainedModelDefinitionDoc.docId("test_regression", 0))
-            .setSource(buildRegressionModelDoc(), XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
+        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
+        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
+    }
+
+    @After
+    public void deleteBothModels() {
+        client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
+        client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
     }
 
     public void testPipelineCreationAndDeletion() throws Exception {
@@ -391,6 +383,7 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
         "  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
         "  \"description\": \"test model for regression\",\n" +
         "  \"version\": \"8.0.0\",\n" +
+        "  \"definition\": " + REGRESSION_DEFINITION + ","+
         "  \"license_level\": \"platinum\",\n" +
         "  \"created_by\": \"ml_test\",\n" +
         "  \"estimated_heap_memory_usage_bytes\": 0," +
@@ -518,28 +511,27 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
         "  }\n" +
         "}";
 
-    private static String buildClassificationModelDoc() throws IOException {
-        String compressed =
-            InferenceToXContentCompressor.deflate(new BytesArray(CLASSIFICATION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
-        return modelDocString(compressed, "test_classification");
+    private TrainedModelConfig buildClassificationModel() throws IOException {
+        try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+            new BytesArray(CLASSIFICATION_CONFIG),
+            XContentType.JSON)) {
+            return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+        }
     }
 
-    private static String buildRegressionModelDoc() throws IOException {
-        String compressed = InferenceToXContentCompressor.deflate(new BytesArray(REGRESSION_DEFINITION.getBytes(StandardCharsets.UTF_8)));
-        return modelDocString(compressed, "test_regression");
+    private TrainedModelConfig buildRegressionModel() throws IOException {
+        try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
+            DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
+            new BytesArray(REGRESSION_CONFIG),
+            XContentType.JSON)) {
+            return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
+        }
     }
 
-    private static String modelDocString(String compressedDefinition, String modelId) {
-        return "" +
-            "{" +
-            "\"model_id\": \"" + modelId + "\",\n" +
-            "\"doc_num\": 0,\n" +
-            "\"doc_type\": \"trained_model_definition_doc\",\n" +
-            "  \"compression_version\": " + 1 + ",\n" +
-            "  \"total_definition_length\": " + compressedDefinition.length() + ",\n" +
-            "  \"definition_length\": " + compressedDefinition.length() + ",\n" +
-            "\"definition\": \"" + compressedDefinition + "\"\n" +
-            "}";
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
     }
 
     private static final String CLASSIFICATION_CONFIG = "" +
@@ -547,9 +539,10 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
         "  \"model_id\": \"test_classification\",\n" +
         "  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
         "  \"description\": \"test model for classification\",\n" +
+        "  \"definition\": " + CLASSIFICATION_DEFINITION + ","+
         "  \"version\": \"8.0.0\",\n" +
         "  \"license_level\": \"platinum\",\n" +
-        "  \"created_by\": \"benwtrent\",\n" +
+        "  \"created_by\": \"es_test\",\n" +
         "  \"estimated_heap_memory_usage_bytes\": 0," +
         "  \"estimated_operations\": 0," +
         "  \"created_time\": 0\n" +

+ 66 - 63
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -6,10 +6,18 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import org.apache.http.util.EntityUtils;
-import org.elasticsearch.Version;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.client.ml.inference.TrainedModelConfig;
+import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.client.ml.inference.TrainedModelInput;
+import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
+import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -18,26 +26,19 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentType;
-import org.elasticsearch.license.License;
 import org.elasticsearch.test.SecuritySettingsSourceField;
 import org.elasticsearch.test.rest.ESRestTestCase;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
-import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
 import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
-import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
 import org.junit.After;
 
 import java.io.IOException;
-import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 
 import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 import static org.hamcrest.Matchers.containsString;
@@ -62,22 +63,8 @@ public class TrainedModelIT extends ESRestTestCase {
     public void testGetTrainedModels() throws IOException {
         String modelId = "a_test_regression_model";
         String modelId2 = "a_test_regression_model-2";
-        Request model1 = new Request("PUT",
-            InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
-        model1.setJsonEntity(buildRegressionModel(modelId));
-        assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
-        Request modelDefinition1 = new Request("PUT",
-            InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
-        modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
-        assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
-        Request model2 = new Request("PUT",
-            InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
-        model2.setJsonEntity(buildRegressionModel(modelId2));
-        assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
-
-        adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+        putRegressionModel(modelId);
+        putRegressionModel(modelId2);
         Response getModel = client().performRequest(new Request("GET",
             MachineLearning.BASE_PATH + "inference/" + modelId));
 
@@ -164,17 +151,7 @@ public class TrainedModelIT extends ESRestTestCase {
 
     public void testDeleteTrainedModels() throws IOException {
         String modelId = "test_delete_regression_model";
-        Request model1 = new Request("PUT",
-            InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
-        model1.setJsonEntity(buildRegressionModel(modelId));
-        assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
-
-        Request modelDefinition1 = new Request("PUT",
-            InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinitionDoc.docId(modelId, 0));
-        modelDefinition1.setJsonEntity(buildRegressionModelDefinitionDoc(modelId));
-        assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
-
-        adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
+        putRegressionModel(modelId);
 
         Response delModel = client().performRequest(new Request("DELETE",
             MachineLearning.BASE_PATH + "inference/" + modelId));
@@ -208,42 +185,68 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(response, containsString("\"definition\""));
     }
 
-    private static String buildRegressionModel(String modelId) throws IOException {
+    private void putRegressionModel(String modelId) throws IOException {
         try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()
+                .setPreProcessors(Collections.emptyList())
+                .setTrainedModel(buildRegression());
             TrainedModelConfig.builder()
+                .setDefinition(definition)
                 .setModelId(modelId)
                 .setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
-                .setCreatedBy("ml_test")
-                .setVersion(Version.CURRENT)
-                .setCreateTime(Instant.now())
-                .setEstimatedOperations(0)
-                .setLicenseLevel(License.OperationMode.PLATINUM.description())
-                .setEstimatedHeapMemory(0)
-                .build()
-                .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
-            return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
+                .build().toXContent(builder, ToXContent.EMPTY_PARAMS);
+            Request model = new Request("PUT", "_ml/inference/" + modelId);
+            model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
+            assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
         }
     }
 
-    private static String buildRegressionModelDefinitionDoc(String modelId) throws IOException {
-        try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
-            TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
-                .setPreProcessors(Collections.emptyList())
-                .setTrainedModel(LocalModelTests.buildRegression())
-                .build();
-            String compressedString = InferenceToXContentCompressor.deflate(definition);
-            TrainedModelDefinitionDoc doc = new TrainedModelDefinitionDoc.Builder().setDocNum(0)
-                .setCompressedString(compressedString)
-                .setTotalDefinitionLength(compressedString.length())
-                .setDefinitionLength(compressedString.length())
-                .setCompressionVersion(1)
-                .setModelId(modelId).build();
-            doc.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
-            return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
-        }
+    private static TrainedModel buildRegression() {
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5),
+                TreeNode.builder(1).setLeafValue(0.3),
+                TreeNode.builder(2)
+                .setThreshold(0.0)
+                .setSplitFeature(3)
+                .setLeftChild(3)
+                .setRightChild(4),
+                TreeNode.builder(3).setLeafValue(0.1),
+                TreeNode.builder(4).setLeafValue(0.2))
+            .build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(2)
+                .setThreshold(1.0),
+                TreeNode.builder(1).setLeafValue(1.5),
+                TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        Tree tree3 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setNodes(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(1)
+                .setThreshold(0.2),
+                TreeNode.builder(1).setLeafValue(1.5),
+                TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        return Ensemble.builder()
+            .setTargetType(TargetType.REGRESSION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2, tree3))
+            .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5)))
+            .build();
     }
 
-
     @After
     public void clearMlState() throws Exception {
         new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -113,6 +113,7 @@ import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
 import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
 import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
@@ -184,6 +185,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction;
 import org.elasticsearch.xpack.ml.action.TransportPutFilterAction;
 import org.elasticsearch.xpack.ml.action.TransportPutJobAction;
+import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAction;
 import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction;
 import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction;
 import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction;
@@ -276,6 +278,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction;
 import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
 import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@@ -761,7 +764,8 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
             new RestExplainDataFrameAnalyticsAction(restController),
             new RestGetTrainedModelsAction(restController),
             new RestDeleteTrainedModelAction(restController),
-            new RestGetTrainedModelsStatsAction(restController)
+            new RestGetTrainedModelsStatsAction(restController),
+            new RestPutTrainedModelAction(restController)
         );
     }
 
@@ -837,6 +841,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
                 new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
                 new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
                 new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class),
+                new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class),
                 usageAction,
                 infoAction);
     }

+ 190 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

@@ -0,0 +1,190 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.master.TransportMasterNodeAction;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.block.ClusterBlockException;
+import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.license.License;
+import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackField;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Request;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
+
+public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Request, Response> {
+
+    private final TrainedModelProvider trainedModelProvider;
+    private final XPackLicenseState licenseState;
+    private final NamedXContentRegistry xContentRegistry;
+    private final Client client;
+
+    @Inject
+    public TransportPutTrainedModelAction(TransportService transportService, ClusterService clusterService,
+                                          ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
+                                          IndexNameExpressionResolver indexNameExpressionResolver, Client client,
+                                          TrainedModelProvider trainedModelProvider, NamedXContentRegistry xContentRegistry) {
+        super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters, Request::new,
+            indexNameExpressionResolver);
+        this.licenseState = licenseState;
+        this.trainedModelProvider = trainedModelProvider;
+        this.xContentRegistry = xContentRegistry;
+        this.client = client;
+    }
+
+    @Override
+    protected String executor() {
+        return ThreadPool.Names.SAME;
+    }
+
+    @Override
+    protected Response read(StreamInput in) throws IOException {
+        return new Response(in);
+    }
+
+    @Override
+    protected void masterOperation(Task task,
+                                   PutTrainedModelAction.Request request,
+                                   ClusterState state,
+                                   ActionListener<Response> listener) {
+        try {
+            request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
+            request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
+        } catch (IOException ex) {
+            listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
+                ex,
+                request.getTrainedModelConfig().getModelId()));
+            return;
+        } catch (ElasticsearchException ex) {
+            listener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.",
+                ex,
+                request.getTrainedModelConfig().getModelId()));
+            return;
+        }
+
+        TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
+            .setVersion(Version.CURRENT)
+            .setCreateTime(Instant.now())
+            .setCreatedBy("api_user")
+            .setLicenseLevel(License.OperationMode.PLATINUM.description())
+            .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
+            .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
+            .build();
+
+        ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
+            r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
+                storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
+                listener::onFailure
+            )),
+            listener::onFailure
+        );
+
+        ActionListener<Void> modelIdTagCheckListener = ActionListener.wrap(
+            r -> checkTagsAgainstModelIds(request.getTrainedModelConfig().getTags(), tagsModelIdCheckListener),
+            listener::onFailure
+        );
+
+        checkModelIdAgainstTags(request.getTrainedModelConfig().getModelId(), modelIdTagCheckListener);
+    }
+
+    private void checkModelIdAgainstTags(String modelId, ActionListener<Void> listener) {
+        QueryBuilder builder = QueryBuilders.constantScoreQuery(
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId)));
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+        SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+        executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+            ML_ORIGIN,
+            searchRequest,
+            ActionListener.<SearchResponse>wrap(
+                response -> {
+                    if (response.getHits().getTotalHits().value > 0) {
+                        listener.onFailure(
+                            ExceptionsHelper.badRequestException(
+                                Messages.getMessage(Messages.INFERENCE_MODEL_ID_AND_TAGS_UNIQUE, modelId)));
+                        return;
+                    }
+                    listener.onResponse(null);
+                },
+                listener::onFailure
+            ),
+            client::search);
+    }
+
+    private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> listener) {
+        if (tags.isEmpty()) {
+            listener.onResponse(null);
+            return;
+        }
+
+        QueryBuilder builder = QueryBuilders.constantScoreQuery(
+            QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags)));
+        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(builder).size(0).trackTotalHitsUpTo(1);
+        SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN).source(sourceBuilder);
+        executeAsyncWithOrigin(client.threadPool().getThreadContext(),
+            ML_ORIGIN,
+            searchRequest,
+            ActionListener.<SearchResponse>wrap(
+                response -> {
+                    if (response.getHits().getTotalHits().value > 0) {
+                        listener.onFailure(
+                            ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE, tags)));
+                        return;
+                    }
+                    listener.onResponse(null);
+                },
+                listener::onFailure
+            ),
+            client::search);
+    }
+
+    @Override
+    protected ClusterBlockException checkBlock(Request request, ClusterState state) {
+        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
+    }
+
+    @Override
+    protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
+        if (licenseState.isMachineLearningAllowed()) {
+            super.doExecute(task, request, listener);
+        } else {
+            listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
+        }
+    }
+}

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -174,10 +174,12 @@ public class TrainedModelProvider {
             r -> {
                 assert r.getItems().length == 2;
                 if (r.getItems()[0].isFailed()) {
+
                     logger.error(new ParameterizedMessage(
                             "[{}] failed to store trained model config for inference",
                             trainedModelConfig.getModelId()),
                         r.getItems()[0].getFailure().getCause());
+
                     wrappedListener.onFailure(r.getItems()[0].getFailure().getCause());
                     return;
                 }

+ 42 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAction.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.node.NodeClient;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestController;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.ml.MachineLearning;
+
+import java.io.IOException;
+
+public class RestPutTrainedModelAction extends BaseRestHandler {
+
+    public RestPutTrainedModelAction(RestController controller) {
+        controller.registerHandler(RestRequest.Method.PUT,
+            MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}",
+            this);
+    }
+
+    @Override
+    public String getName() {
+        return "xpack_ml_put_trained_model_action";
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+        XContentParser parser = restRequest.contentParser();
+        PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
+        putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));
+
+        return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
+    }
+}

+ 25 - 81
x-pack/plugin/ml/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java

@@ -13,7 +13,6 @@ import org.elasticsearch.action.ingest.SimulatePipelineAction;
 import org.elasticsearch.action.ingest.SimulatePipelineRequest;
 import org.elasticsearch.action.ingest.SimulatePipelineResponse;
 import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.common.bytes.BytesArray;
@@ -32,24 +31,28 @@ import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
 import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
 import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
-import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
-import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
-import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
 import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
 import org.junit.Before;
 
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
 import java.util.Collections;
 
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
@@ -481,12 +484,7 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
             "          \"target_field\": \"regression_value\",\n" +
             "          \"model_id\": \"modelprocessorlicensetest\",\n" +
             "          \"inference_config\": {\"regression\": {}},\n" +
-            "          \"field_mappings\": {\n" +
-            "            \"col1\": \"col1\",\n" +
-            "            \"col2\": \"col2\",\n" +
-            "            \"col3\": \"col3\",\n" +
-            "            \"col4\": \"col4\"\n" +
-            "          }\n" +
+            "          \"field_mappings\": {}\n" +
             "        }\n" +
             "      }]}\n";
         // Creating a pipeline should work
@@ -668,76 +666,22 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
         assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
     }
 
-    private void putInferenceModel(String modelId) throws Exception {
-        String config = "" +
-            "{\n" +
-            "  \"model_id\": \"" + modelId + "\",\n" +
-            "  \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
-            "  \"description\": \"test model for classification\",\n" +
-            "  \"version\": \"8.0.0\",\n" +
-            "  \"created_by\": \"benwtrent\",\n" +
-            "  \"license_level\": \"platinum\",\n" +
-            "  \"estimated_heap_memory_usage_bytes\": 0,\n" +
-            "  \"estimated_operations\": 0,\n" +
-            "  \"created_time\": 0\n" +
-            "}";
-        String definition = "" +
-            "{" +
-            "  \"trained_model\": {\n" +
-            "    \"tree\": {\n" +
-            "      \"feature_names\": [\n" +
-            "        \"col1_male\",\n" +
-            "        \"col1_female\",\n" +
-            "        \"col2_encoded\",\n" +
-            "        \"col3_encoded\",\n" +
-            "        \"col4\"\n" +
-            "      ],\n" +
-            "      \"tree_structure\": [\n" +
-            "        {\n" +
-            "          \"node_index\": 0,\n" +
-            "            \"split_feature\": 0,\n" +
-            "            \"split_gain\": 12.0,\n" +
-            "            \"threshold\": 10.0,\n" +
-            "            \"decision_type\": \"lte\",\n" +
-            "            \"default_left\": true,\n" +
-            "            \"left_child\": 1,\n" +
-            "            \"right_child\": 2\n" +
-            "         },\n" +
-            "         {\n" +
-            "           \"node_index\": 1,\n" +
-            "           \"leaf_value\": 1\n" +
-            "         },\n" +
-            "         {\n" +
-            "           \"node_index\": 2,\n" +
-            "           \"leaf_value\": 2\n" +
-            "         }\n" +
-            "      ],\n" +
-            "     \"target_type\": \"regression\"\n" +
-            "    }\n" +
-            "  }" +
-            "}";
-        String compressedDefinitionString =
-            InferenceToXContentCompressor.deflate(new BytesArray(definition.getBytes(StandardCharsets.UTF_8)));
-        String compressedDefinition = "" +
-            "{" +
-            "  \"model_id\": \"" + modelId + "\",\n" +
-            "  \"doc_type\": \"" + TrainedModelDefinitionDoc.NAME + "\",\n" +
-            "  \"doc_num\": " + 0 + ",\n" +
-            "  \"compression_version\": " + 1 + ",\n" +
-            "  \"total_definition_length\": " + compressedDefinitionString.length() + ",\n" +
-            "  \"definition_length\": " + compressedDefinitionString.length() + ",\n" +
-            "  \"definition\": \"" + compressedDefinitionString + "\"\n" +
-            "}";
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId(modelId)
-            .setSource(config, XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
-        assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME)
-            .setId(TrainedModelDefinitionDoc.docId(modelId, 0))
-            .setSource(compressedDefinition, XContentType.JSON)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .get().status(), equalTo(RestStatus.CREATED));
+    private void putInferenceModel(String modelId) {
+        TrainedModelConfig config = TrainedModelConfig.builder()
+            .setParsedDefinition(
+            new TrainedModelDefinition.Builder()
+            .setTrainedModel(
+            Tree.builder()
+                .setTargetType(TargetType.REGRESSION)
+                .setFeatureNames(Arrays.asList("feature1"))
+                .setNodes(TreeNode.builder(0).setLeafValue(1.0))
+                .build())
+            .setPreProcessors(Collections.emptyList()))
+            .setModelId(modelId)
+            .setDescription("test model for classification")
+            .setInput(new TrainedModelInput(Arrays.asList("feature1")))
+            .build();
+        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
     }
 
     private static OperationMode randomInvalidLicenseType() {

+ 0 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

@@ -199,7 +199,6 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
         return TrainedModelConfig.builder()
             .setCreatedBy("ml_test")
             .setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
-
             .setDescription("trained model config for test")
             .setModelId(modelId)
             .setVersion(Version.CURRENT)

+ 28 - 0
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model.json

@@ -0,0 +1,28 @@
+{
+  "ml.put_trained_model":{
+    "documentation":{
+      "url":"TODO"
+    },
+    "stability":"experimental",
+    "url":{
+      "paths":[
+        {
+          "path":"/_ml/inference/{model_id}",
+          "methods":[
+            "PUT"
+          ],
+          "parts":{
+            "model_id":{
+              "type":"string",
+              "description":"The ID of the trained models to store"
+            }
+          }
+        }
+      ]
+    },
+    "body": {
+      "description":"The trained model configuration",
+      "required":true
+    }
+  }
+}

+ 105 - 38
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -1,3 +1,74 @@
+setup:
+  - skip:
+      features: headers
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: a-regression-model-0
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "tree": {
+                     "feature_names": ["field1", "field2"],
+                     "tree_structure": [
+                        {"node_index": 0, "leaf_value": 1}
+                     ],
+                     "target_type": "regression"
+                  }
+               }
+            }
+          }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: a-regression-model-1
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "tree": {
+                     "feature_names": ["field1", "field2"],
+                     "tree_structure": [
+                        {"node_index": 0, "leaf_value": 1}
+                     ],
+                     "target_type": "regression"
+                  }
+               }
+            }
+          }
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: a-classification-model
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "tree": {
+                     "feature_names": ["field1", "field2"],
+                     "tree_structure": [
+                        {"node_index": 0, "leaf_value": 1}
+                     ],
+                     "target_type": "classification",
+                     "classification_labels": ["no", "yes"]
+                  }
+               }
+            }
+          }
 ---
 "Test get given missing trained model":
 
@@ -24,56 +95,52 @@
   - match: { count: 0 }
   - match: { trained_model_configs: [] }
 ---
-"Test delete given unused trained model":
+"Test get models":
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+  - match: { count: 4 }
+  - match: { trained_model_configs.0.model_id: "a-classification-model" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
+  - match: { trained_model_configs.2.model_id: "a-regression-model-1" }
 
   - do:
-      index:
-        id: trained_model_config-unused-regression-model-0
-        index: .ml-inference-000001
-        body: >
-          {
-            "model_id": "unused-regression-model",
-            "created_by": "ml_tests",
-            "version": "8.0.0",
-            "description": "empty model for tests",
-            "create_time": 0,
-            "model_version": 0,
-            "model_type": "local"
-          }
+      ml.get_trained_models:
+        model_id: "a-regression*"
+  - match: { count: 2 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
+
   - do:
-      indices.refresh: {}
+      ml.get_trained_models:
+        model_id: "*"
+        from: 0
+        size: 2
+  - match: { count: 4 }
+  - match: { trained_model_configs.0.model_id: "a-classification-model" }
+  - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
 
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 1
+        size: 1
+  - match: { count: 4 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+---
+"Test delete given unused trained model":
   - do:
       ml.delete_trained_model:
-        model_id: "unused-regression-model"
+        model_id: "a-classification-model"
   - match: { acknowledged: true }
-
 ---
 "Test delete with missing model":
   - do:
       catch: missing
       ml.delete_trained_model:
         model_id: "missing-trained-model"
-
 ---
 "Test delete given used trained model":
-  - do:
-      index:
-        id: trained_model_config-used-regression-model-0
-        index: .ml-inference-000001
-        body: >
-          {
-            "model_id": "used-regression-model",
-            "created_by": "ml_tests",
-            "version": "8.0.0",
-            "description": "empty model for tests",
-            "create_time": 0,
-            "model_version": 0,
-            "model_type": "local"
-          }
-  - do:
-      indices.refresh: {}
-
   - do:
       ingest.put_pipeline:
         id: "regression-model-pipeline"
@@ -82,7 +149,7 @@
             "processors": [
               {
                 "inference" : {
-                  "model_id" : "used-regression-model",
+                  "model_id" : "a-regression-model-0",
                   "inference_config": {"regression": {}},
                   "target_field": "regression_field",
                   "field_mappings": {}
@@ -95,12 +162,12 @@
   - do:
       catch: conflict
       ml.delete_trained_model:
-        model_id: "used-regression-model"
+        model_id: "a-regression-model-0"
 ---
 "Test get pre-packaged trained models":
   - do:
       ml.get_trained_models:
-        model_id: "_all"
+        model_id: "lang_ident_model_1"
         allow_no_match: false
   - match: { count: 1 }
   - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }