Browse Source

[ML] adds new for_export flag to GET _ml/inference API (#57351)

Adds a new boolean flag, `for_export` to the `GET _ml/inference/<model_id>` API.

This flag is useful for moving models between clusters.
Benjamin Trent 5 years ago
parent
commit
251b17009a

+ 3 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

@@ -770,6 +770,9 @@ final class MLRequestConverters {
         if (getTrainedModelsRequest.getTags() != null) {
             params.putParam(GetTrainedModelsRequest.TAGS, Strings.collectionToCommaDelimitedString(getTrainedModelsRequest.getTags()));
         }
+        if (getTrainedModelsRequest.getForExport() != null) {
+            params.putParam(GetTrainedModelsRequest.FOR_EXPORT, Boolean.toString(getTrainedModelsRequest.getForExport()));
+        }
         Request request = new Request(HttpGet.METHOD_NAME, endpoint);
         request.addParameters(params.asMap());
         return request;

+ 21 - 1
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetTrainedModelsRequest.java

@@ -34,6 +34,7 @@ public class GetTrainedModelsRequest implements Validatable {
 
     public static final String ALLOW_NO_MATCH = "allow_no_match";
     public static final String INCLUDE_MODEL_DEFINITION = "include_model_definition";
+    public static final String FOR_EXPORT = "for_export";
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
     public static final String TAGS = "tags";
 
@@ -41,6 +42,7 @@ public class GetTrainedModelsRequest implements Validatable {
     private Boolean allowNoMatch;
     private Boolean includeDefinition;
     private Boolean decompressDefinition;
+    private Boolean forExport;
     private PageParams pageParams;
     private List<String> tags;
 
@@ -137,6 +139,23 @@ public class GetTrainedModelsRequest implements Validatable {
         return setTags(Arrays.asList(tags));
     }
 
+    public Boolean getForExport() {
+        return forExport;
+    }
+
+    /**
+     * Setting this flag to `true` removes certain fields from the model definition on retrieval.
+     *
+     * This is useful when getting the model and wanting to put it in another cluster.
+     *
+     * Default value is false.
+     * @param forExport Boolean value indicating if certain fields should be removed from the mode on GET
+     */
+    public GetTrainedModelsRequest setForExport(Boolean forExport) {
+        this.forExport = forExport;
+        return this;
+    }
+
     @Override
     public Optional<ValidationException> validate() {
         if (ids == null || ids.isEmpty()) {
@@ -155,11 +174,12 @@ public class GetTrainedModelsRequest implements Validatable {
             && Objects.equals(allowNoMatch, other.allowNoMatch)
             && Objects.equals(decompressDefinition, other.decompressDefinition)
             && Objects.equals(includeDefinition, other.includeDefinition)
+            && Objects.equals(forExport, other.forExport)
             && Objects.equals(pageParams, other.pageParams);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition);
+        return Objects.hash(ids, allowNoMatch, pageParams, decompressDefinition, includeDefinition, forExport);
     }
 }

+ 2 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -3611,7 +3611,8 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
                 .setIncludeDefinition(false) // <3>
                 .setDecompressDefinition(false) // <4>
                 .setAllowNoMatch(true) // <5>
-                .setTags("regression"); // <6>
+                .setTags("regression") // <6>
+                .setForExport(false); // <7>
             // end::get-trained-models-request
             request.setTags((List<String>)null);
 

+ 3 - 0
docs/java-rest/high-level/ml/get-trained-models.asciidoc

@@ -32,6 +32,9 @@ include-tagged::{doc-tests-file}[{api}-request]
 <6> An optional list of tags used to narrow the model search. A Trained Model
     can have many tags or none. The trained models in the response will
     contain all the provided tags.
+<7> Optional boolean value indicating if certain fields should be removed on
+    retrieval. This is useful for getting the trained model in a format that
+    can then be put into another cluster.
 
 include::../execution.asciidoc[]
 

+ 6 - 0
docs/reference/ml/df-analytics/apis/get-inference-trained-model.asciidoc

@@ -82,6 +82,12 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=size]
 (Optional, string)
 include::{docdir}/ml/ml-shared.asciidoc[tag=tags]
 
+`for_export`::
+(Optional, boolean)
+Indicates if certain fields should be removed from the model configuration on
+retrieval. This allows the model to be in an acceptable format to be retrieved
+and then added to another cluster. Default is false.
+
 [role="child_attributes"]
 [[ml-get-inference-results]]
 ==== {api-response-body-title}

+ 14 - 10
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -49,6 +49,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final String NAME = "trained_model_config";
     public static final int CURRENT_DEFINITION_COMPRESSION_VERSION = 1;
     public static final String DECOMPRESS_DEFINITION = "decompress_definition";
+    public static final String FOR_EXPORT = "for_export";
 
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
 
@@ -304,13 +305,22 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field(MODEL_ID.getPreferredName(), modelId);
-        builder.field(CREATED_BY.getPreferredName(), createdBy);
-        builder.field(VERSION.getPreferredName(), version.toString());
+        // If the model is to be exported for future import to another cluster, these fields are irrelevant.
+        if (params.paramAsBoolean(FOR_EXPORT, false) == false) {
+            builder.field(MODEL_ID.getPreferredName(), modelId);
+            builder.field(CREATED_BY.getPreferredName(), createdBy);
+            builder.field(VERSION.getPreferredName(), version.toString());
+            builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
+            builder.humanReadableField(
+                ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
+                ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
+                new ByteSizeValue(estimatedHeapMemory));
+            builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
+            builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
+        }
         if (description != null) {
             builder.field(DESCRIPTION.getPreferredName(), description);
         }
-        builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
         // We don't store the definition in the same document as the configuration
         if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
             if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
@@ -327,12 +337,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
         }
         builder.field(INPUT.getPreferredName(), input);
-        builder.humanReadableField(
-            ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
-            ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
-            new ByteSizeValue(estimatedHeapMemory));
-        builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
-        builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
         if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
             builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
         }

+ 38 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -40,6 +40,7 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 import static org.hamcrest.Matchers.containsString;
@@ -187,6 +188,43 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(response, containsString("\"definition\""));
     }
 
+    @SuppressWarnings("unchecked")
+    public void testExportImportModel() throws IOException {
+        String modelId = "regression_model_to_export";
+        putRegressionModel(modelId);
+        Response getModel = client().performRequest(new Request("GET",
+            MachineLearning.BASE_PATH + "inference/" + modelId));
+
+        assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
+        String response = EntityUtils.toString(getModel.getEntity());
+        assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
+        assertThat(response, containsString("\"count\":1"));
+
+        getModel = client().performRequest(new Request("GET",
+            MachineLearning.BASE_PATH +
+                "inference/" + modelId +
+                "?include_model_definition=true&decompress_definition=false&for_export=true"));
+        assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
+
+        Map<String, Object> exportedModel = entityAsMap(getModel);
+        Map<String, Object> modelDefinition = ((List<Map<String, Object>>)exportedModel.get("trained_model_configs")).get(0);
+
+        String importedModelId = "regression_model_to_import";
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            builder.map(modelDefinition);
+            Request model = new Request("PUT", "_ml/inference/" + importedModelId);
+            model.setJsonEntity(XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON));
+            assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));
+        }
+        getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference/regression*"));
+
+        assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
+        response = EntityUtils.toString(getModel.getEntity());
+        assertThat(response, containsString("\"model_id\":\"regression_model_to_export\""));
+        assertThat(response, containsString("\"model_id\":\"regression_model_to_import\""));
+        assertThat(response, containsString("\"count\":2"));
+    }
+
     private void putRegressionModel(String modelId) throws IOException {
         try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
             TrainedModelDefinition.Builder definition = new TrainedModelDefinition.Builder()

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

@@ -73,7 +73,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
 
     @Override
     protected Set<String> responseParams() {
-        return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
+        return Set.of(TrainedModelConfig.DECOMPRESS_DEFINITION, TrainedModelConfig.FOR_EXPORT);
     }
 
     private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {

+ 6 - 0
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_trained_models.json

@@ -62,6 +62,12 @@
         "required":false,
         "type":"list",
         "description":"A comma-separated list of tags that the model must have."
+      },
+      "for_export": {
+        "required": false,
+        "type": "boolean",
+        "default": false,
+        "description": "Omits fields that are illegal to set on model PUT"
       }
     }
   }

+ 21 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -818,3 +818,24 @@ setup:
                }
             }
           }
+---
+"Test for_export flag":
+  - do:
+      ml.get_trained_models:
+        model_id: "a-regression-model-1"
+        for_export: true
+        include_model_definition: true
+        decompress_definition: false
+
+  - match: { trained_model_configs.0.description: "empty model for tests" }
+  - is_true:  trained_model_configs.0.compressed_definition
+  - is_true:  trained_model_configs.0.input
+  - is_true:  trained_model_configs.0.inference_config
+  - is_true:  trained_model_configs.0.tags
+  - is_false: trained_model_configs.0.model_id
+  - is_false: trained_model_configs.0.created_by
+  - is_false: trained_model_configs.0.version
+  - is_false: trained_model_configs.0.create_time
+  - is_false: trained_model_configs.0.estimated_heap_memory_usage
+  - is_false: trained_model_configs.0.estimated_operations
+  - is_false: trained_model_configs.0.license_level