1
0
Эх сурвалжийг харах

[ML] Include model definition install status for Pytorch models (#95271)

Adds a new include flag definition_status to the GET trained models API.
When present, the trained model configuration returned in the response
will have the new boolean field fully_defined, which indicates whether the
full model definition exists.
David Kyle 2 жил өмнө
parent
commit
6de8469a51

+ 5 - 0
docs/changelog/95271.yaml

@@ -0,0 +1,5 @@
+pr: 95271
+summary: Include model definition install status for Pytorch models
+area: Machine Learning
+type: enhancement
+issues: []

+ 7 - 0
docs/reference/ml/trained-models/apis/get-trained-models.asciidoc

@@ -76,6 +76,8 @@ options are:
     it was specified by the user or tuned during hyperparameter optimization.
  - `total_feature_importance`: Includes the total {feat-imp} for the training
    data set.
+  - `definition_status`: Includes the field `fully_defined` indicating whether
+    the full model definition is present.
 The baseline and total {feat-imp} values are returned in the `metadata` field
 in the response body.
 
@@ -937,6 +939,11 @@ The input field names for the model definition.
 An array of input field names for the model.
 =====
 
+`fully_defined`::
+(boolean)
+True if the full model definition is present.
+This field is only present if `include=definition_status` was specified in the request.
+
 // Begin location
 `location`::
 (Optional, object)

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

@@ -41,6 +41,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
         static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance";
         static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
         static final String HYPERPARAMETERS = "hyperparameters";
+        static final String DEFINITION_STATUS = TrainedModelConfig.DEFINITION_STATUS;
 
         private static final Set<String> KNOWN_INCLUDES;
         static {
@@ -49,6 +50,7 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
             includes.add(TOTAL_FEATURE_IMPORTANCE);
             includes.add(FEATURE_IMPORTANCE_BASELINE);
             includes.add(HYPERPARAMETERS);
+            includes.add(DEFINITION_STATUS);
             KNOWN_INCLUDES = Collections.unmodifiableSet(includes);
         }
 
@@ -103,6 +105,10 @@ public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Re
             return this.includes.contains(HYPERPARAMETERS);
         }
 
+        public boolean isIncludeDefinitionStatus() {
+            return this.includes.contains(DEFINITION_STATUS);
+        }
+
         @Override
         public boolean equals(Object o) {
             if (this == o) return true;

+ 12 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -71,6 +71,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline";
     public static final String HYPERPARAMETERS = "hyperparameters";
     public static final String MODEL_ALIASES = "model_aliases";
+    public static final String DEFINITION_STATUS = "definition_status";
 
     private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
     private static final String MODEL_SIZE_HUMAN = "model_size";
@@ -187,6 +188,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
     private final LazyModelDefinition definition;
     private final TrainedModelLocation location;
     private final ModelPackageConfig modelPackageConfig;
+    private Boolean fullDefinition;
 
     TrainedModelConfig(
         String modelId,
@@ -266,8 +268,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         }
         if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
             modelPackageConfig = in.readOptionalWriteable(ModelPackageConfig::new);
+            fullDefinition = in.readOptionalBoolean();
         } else {
             modelPackageConfig = null;
+            fullDefinition = null;
         }
     }
 
@@ -395,6 +399,10 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         return Optional.ofNullable(inferenceConfig).map(InferenceConfig::isAllocateOnly).orElse(false);
     }
 
+    public void setFullDefinition(boolean fullDefinition) {
+        this.fullDefinition = fullDefinition;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(modelId);
@@ -423,6 +431,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
 
         if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
             out.writeOptionalWriteable(modelPackageConfig);
+            out.writeOptionalBoolean(fullDefinition);
         }
     }
 
@@ -483,6 +492,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
         if (location != null) {
             writeNamedObject(builder, params, LOCATION.getPreferredName(), location);
         }
+        if (params.paramAsBoolean(DEFINITION_STATUS, false) && fullDefinition != null) {
+            builder.field("fully_defined", fullDefinition);
+        }
         builder.endObject();
         return builder;
     }

+ 82 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

@@ -9,40 +9,54 @@ package org.elasticsearch.xpack.ml.action;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.client.internal.OriginSettingClient;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.core.Tuple;
+import org.elasticsearch.index.IndexNotFoundException;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+
 public class TransportGetTrainedModelsAction extends HandledTransportAction<Request, Response> {
 
     private final TrainedModelProvider provider;
     private final ClusterService clusterService;
+    private final Client client;
 
     @Inject
     public TransportGetTrainedModelsAction(
         TransportService transportService,
         ActionFilters actionFilters,
         ClusterService clusterService,
+        Client client,
         TrainedModelProvider trainedModelProvider
     ) {
         super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new);
         this.provider = trainedModelProvider;
         this.clusterService = clusterService;
+        this.client = client;
     }
 
     @Override
@@ -51,6 +65,34 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
 
         Response.Builder responseBuilder = Response.builder();
 
+        ActionListener<List<TrainedModelConfig>> getModelDefinitionStatusListener = ActionListener.wrap(configs -> {
+            if (request.getIncludes().isIncludeDefinitionStatus() == false) {
+                listener.onResponse(responseBuilder.setModels(configs).build());
+                return;
+            }
+
+            assert configs.size() <= 1;
+            if (configs.isEmpty()) {
+                listener.onResponse(responseBuilder.setModels(configs).build());
+                return;
+            }
+
+            if (configs.get(0).getModelType() != TrainedModelType.PYTORCH) {
+                listener.onFailure(ExceptionsHelper.badRequestException("Definition status is only relevant to PyTorch model types"));
+                return;
+            }
+
+            definitionStatus(
+                new OriginSettingClient(client, ML_ORIGIN),
+                configs.get(0).getModelId(),
+                configs.get(0).getLocation().getResourceName(),
+                ActionListener.wrap(isDownloaded -> {
+                    configs.get(0).setFullDefinition(isDownloaded);
+                    listener.onResponse(responseBuilder.setModels(configs).build());
+                }, listener::onFailure)
+            );
+        }, listener::onFailure);
+
         ActionListener<Tuple<Long, Map<String, Set<String>>>> idExpansionListener = ActionListener.wrap(totalAndIds -> {
             responseBuilder.setTotalCount(totalAndIds.v1());
 
@@ -64,6 +106,15 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                 return;
             }
 
+            if (request.getIncludes().isIncludeDefinitionStatus() && totalAndIds.v2().size() > 1) {
+                listener.onFailure(
+                    ExceptionsHelper.badRequestException(
+                        "Getting the model download status is not supported when getting more than one model"
+                    )
+                );
+                return;
+            }
+
             if (request.getIncludes().isIncludeModelDefinition()) {
                 Map.Entry<String, Set<String>> modelIdAndAliases = totalAndIds.v2().entrySet().iterator().next();
                 provider.getTrainedModel(
@@ -72,8 +123,8 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                     request.getIncludes(),
                     parentTaskId,
                     ActionListener.wrap(
-                        config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
-                        listener::onFailure
+                        config -> getModelDefinitionStatusListener.onResponse(Collections.singletonList(config)),
+                        getModelDefinitionStatusListener::onFailure
                     )
                 );
             } else {
@@ -82,7 +133,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
                     request.getIncludes(),
                     request.isAllowNoResources(),
                     parentTaskId,
-                    ActionListener.wrap(configs -> listener.onResponse(responseBuilder.setModels(configs).build()), listener::onFailure)
+                    getModelDefinitionStatusListener
                 );
             }
         }, listener::onFailure);
@@ -97,4 +148,32 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
         );
     }
 
+    private void definitionStatus(Client client, String modelId, String index, ActionListener<Boolean> listener) {
+        client.prepareSearch(index)
+            .setQuery(
+                QueryBuilders.constantScoreQuery(
+                    QueryBuilders.boolQuery()
+                        .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
+                        .filter(
+                            QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)
+                        )
+                        .filter(QueryBuilders.termQuery(TrainedModelDefinitionDoc.EOS.getPreferredName(), true))
+                )
+            )
+            .setFetchSource(false)
+            .setSize(1)
+            .setTrackTotalHits(false)
+            // eos field is not mapped, use a runtime mapping
+            .setRuntimeMappings(Map.of("eos", Map.of("type", "boolean")))
+            .execute(ActionListener.wrap(response -> {
+                listener.onResponse(response.getHits().getHits().length > 0);
+            }, e -> {
+                // if no parts have been uploaded the index may not exist
+                if (e instanceof IndexNotFoundException) {
+                    listener.onResponse(false);
+                } else {
+                    listener.onFailure(e);
+                }
+            }));
+    }
 }

+ 9 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestGetTrainedModelsAction.java

@@ -106,7 +106,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
         return channel -> new RestCancellableNodeClient(client, restRequest.getHttpChannel()).execute(
             GetTrainedModelsAction.INSTANCE,
             request,
-            new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES)
+            new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES, includes)
         );
     }
 
@@ -117,10 +117,16 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
 
     private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
         private final Map<String, String> defaultToXContentParamValues;
+        private final Set<String> includes;
 
-        private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
+        private RestToXContentListenerWithDefaultValues(
+            RestChannel channel,
+            Map<String, String> defaultToXContentParamValues,
+            Set<String> includes
+        ) {
             super(channel);
             this.defaultToXContentParamValues = defaultToXContentParamValues;
+            this.includes = includes;
         }
 
         @Override
@@ -128,6 +134,7 @@ public class RestGetTrainedModelsAction extends BaseRestHandler {
             assert response.isFragment() == false; // would be nice if we could make default methods final
             Map<String, String> params = new HashMap<>(channel.request().params());
             defaultToXContentParamValues.forEach((k, v) -> params.computeIfAbsent(k, defaultToXContentParamValues::get));
+            includes.forEach(include -> params.put(include, "true"));
             response.toXContent(builder, new ToXContent.MapParams(params));
             return new RestResponse(getStatus(response), builder);
         }

+ 8 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -379,3 +379,11 @@ setup:
         model_alias: "pytorch"
         model_id: "another_test_model"
         reassign: true
+
+---
+"Test include model definition status":
+  - do:
+      ml.get_trained_models:
+        model_id: test_model
+        include: definition_status
+  - match: { trained_model_configs.0.fully_defined: true }

+ 36 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -1178,3 +1178,39 @@ setup:
               }
             }
           }
+
+---
+"Test include model definition status":
+  - do:
+      ml.put_trained_model:
+        model_id: model-without-definition
+        body: >
+          {
+            "model_type": "pytorch",
+            "inference_config": {
+              "ner": {
+              }
+            }
+          }
+
+  - do:
+      ml.get_trained_models:
+        model_id: model-without-definition
+        include: definition_status
+  - match: { count: 1 }
+  - match: { trained_model_configs.0.fully_defined: false }
+  - do:
+      ml.get_trained_models:
+        model_id: model-without-definition
+  - match: { count: 1 }
+  - match: { trained_model_configs.0.fully_defined: null }
+  - do:
+      catch: /Getting the model download status is not supported when getting more than one model/
+      ml.get_trained_models:
+        model_id: _all
+        include: definition_status
+  - do:
+      catch: /Definition status is only relevant to PyTorch model types/
+      ml.get_trained_models:
+        model_id: a-regression-model-0
+        include: definition_status