Browse Source

[ML][Inference] don't return inflated definition when storing trained models (#52573)

When `PUT` is called to store a trained model, it is useful to return the newly created model config. But, it is NOT useful to return the inflated definition.

These definitions can be large, and returning the inflated definition causes undue work on the server and client side.
Benjamin Trent 5 years ago
parent
commit
1c1d45130c

+ 2 - 0
docs/java-rest/high-level/ml/put-trained-model.asciidoc

@@ -46,6 +46,8 @@ include::../execution.asciidoc[]
 ==== Response
 
 The returned +{response}+ contains the newly created trained model.
+The +{response}+ will omit the model definition as a precaution against
+streaming large model definitions back to the client.
 
 ["source","java",subs="attributes,callouts,macros"]
 --------------------------------------------------

+ 4 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -280,7 +280,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
         // We don't store the definition in the same document as the configuration
         if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
-            if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
+            if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
                 builder.field(DEFINITION.getPreferredName(), definition);
             } else {
                 builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
@@ -371,6 +371,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             this.tags = config.getTags();
             this.metadata = config.getMetadata();
             this.input = config.getInput();
+            this.estimatedOperations = config.estimatedOperations;
+            this.estimatedHeapMemory = config.estimatedHeapMemory;
+            this.licenseLevel = config.licenseLevel.description();
         }
 
         public Builder setModelId(String modelId) {

+ 6 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

@@ -143,21 +143,21 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
             "platinum");
 
         BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
-        assertThat(reference.utf8ToString(), containsString("\"definition\""));
+        assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
 
         reference = XContentHelper.toXContent(config,
             XContentType.JSON,
             new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")),
             false);
         assertThat(reference.utf8ToString(), not(containsString("definition")));
+        assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
 
         reference = XContentHelper.toXContent(config,
             XContentType.JSON,
-            new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "false")),
+            new ToXContent.MapParams(Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, "true")),
             false);
-        assertThat(reference.utf8ToString(), not(containsString("\"definition\"")));
-        assertThat(reference.utf8ToString(), containsString("compressed_definition"));
-        assertThat(reference.utf8ToString(), containsString(lazyModelDefinition.getCompressedString()));
+        assertThat(reference.utf8ToString(), containsString("\"definition\""));
+        assertThat(reference.utf8ToString(), not(containsString("compressed_definition")));
     }
 
     public void testParseWithBothDefinitionAndCompressedSupplied() throws IOException {
@@ -180,7 +180,7 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
         BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
         Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();
 
-        objectMap.put(TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName(), lazyModelDefinition.getCompressedString());
+        objectMap.put(TrainedModelConfig.DEFINITION.getPreferredName(), config.getModelDefinition());
 
         try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(objectMap);
             XContentParser parser = XContentType.JSON

+ 1 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -93,6 +93,7 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
         assertThat(response, containsString("\"estimated_heap_memory_usage\""));
         assertThat(response, containsString("\"definition\""));
+        assertThat(response, not(containsString("\"compressed_definition\"")));
         assertThat(response, containsString("\"count\":1"));
 
         getModel = client().performRequest(new Request("GET",

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

@@ -108,7 +108,10 @@ public class TransportPutTrainedModelAction extends TransportMasterNodeAction<Re
 
         ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
             r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
-                storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
+                bool -> {
+                    TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
+                    listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
+                },
                 listener::onFailure
             )),
             listener::onFailure

+ 32 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

@@ -8,8 +8,14 @@ package org.elasticsearch.xpack.ml.rest.inference;
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.cluster.metadata.MetaData;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.BytesRestResponse;
+import org.elasticsearch.rest.RestChannel;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.RestResponse;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
@@ -18,7 +24,9 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
 import static java.util.Arrays.asList;
@@ -34,6 +42,8 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
             new Route(GET, MachineLearning.BASE_PATH + "inference"));
     }
 
+    private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
+        Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
     @Override
     public String getName() {
         return "ml_get_trained_models_action";
@@ -56,7 +66,9 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
                 restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
         }
         request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
-        return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
+        return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
+            request,
+            new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
     }
 
     @Override
@@ -64,4 +76,23 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
         return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
     }
 
+    private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
+        private final Map<String, String> defaultToXContentParamValues;
+
+        private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
+            super(channel);
+            this.defaultToXContentParamValues = defaultToXContentParamValues;
+        }
+
+        @Override
+        public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
+            assert response.isFragment() == false; //would be nice if we could make default methods final
+            Map<String, String> params = new HashMap<>(channel.request().params());
+            defaultToXContentParamValues.forEach((k, v) ->
+                params.computeIfAbsent(k, defaultToXContentParamValues::get)
+            );
+            response.toXContent(builder, new ToXContent.MapParams(params));
+            return new BytesRestResponse(getStatus(response), builder);
+        }
+    }
 }

+ 50 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -460,3 +460,53 @@ setup:
               }
             }
           }
+---
+"Test put model":
+  - do:
+      ml.put_trained_model:
+        model_id: my-regression-model
+        body: >
+          {
+            "description": "model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "ensemble": {
+                    "target_type": "regression",
+                    "trained_models": [
+                      {
+                        "tree": {
+                          "feature_names": ["field1", "field2"],
+                          "tree_structure": [
+                             {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
+                             {"node_index": 1, "leaf_value": 0},
+                             {"node_index": 2, "leaf_value": 1}
+                          ],
+                          "target_type": "regression"
+                        }
+                      },
+                      {
+                        "tree": {
+                          "feature_names": ["field1", "field2"],
+                          "tree_structure": [
+                             {"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
+                             {"node_index": 1, "leaf_value": 0},
+                             {"node_index": 2, "leaf_value": 1}
+                          ],
+                          "target_type": "regression"
+                        }
+                      }
+                    ]
+                  }
+               }
+            }
+          }
+  - match: { model_id: my-regression-model }
+  - match: { estimated_operations: 6 }
+  - is_false: definition
+  - is_false: compressed_definition
+  - is_true: license_level
+  - is_true: create_time
+  - is_true: version
+  - is_true: estimated_heap_memory_usage_bytes