瀏覽代碼

[ML] throw when definition is requested for pytorch models (#80310)

Since pytorch models are not built in Elasticsearch, we don't need to provide supplying the definition when retrieving the trained model.

In fact, some of these definitions are so large, that returning them is prohibitive.

related to #80254
Benjamin Trent 4 年之前
父節點
當前提交
8900572f79

+ 12 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -417,6 +417,18 @@ public class PyTorchModelIT extends ESRestTestCase {
         assertThat(ex.getMessage(), containsString("Could not find trained model [missing_model]"));
     }
 
+    public void testGetPytorchModelWithDefinition() throws IOException {
+        String model = "should-fail-get";
+        createTrainedModel(model);
+        putVocabulary(List.of("once", "twice"), model);
+        putModelDefinition(model);
+        Exception ex = expectThrows(
+            Exception.class,
+            () -> client().performRequest(new Request("GET", "_ml/trained_models/" + model + "?include=definition"))
+        );
+        assertThat(ex.getMessage(), containsString("[should-fail-get] is type [pytorch] and does not support retrieving the definition"));
+    }
+
     public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
         String model = "not-deployed";
         createTrainedModel(model);

+ 58 - 54
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -24,9 +24,7 @@ import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexAction;
 import org.elasticsearch.action.index.IndexRequest;
-import org.elasticsearch.action.search.MultiSearchAction;
 import org.elasticsearch.action.search.MultiSearchRequest;
-import org.elasticsearch.action.search.MultiSearchRequestBuilder;
 import org.elasticsearch.action.search.MultiSearchResponse;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
@@ -602,26 +600,17 @@ public class TrainedModelProvider {
         }, finalListener::onFailure);
 
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelId));
-        MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch()
-            .add(
-                client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
-                    .setQuery(queryBuilder)
-                    // use sort to get the last
-                    .addSort("_index", SortOrder.DESC)
-                    .setSize(1)
-                    .request()
-            );
-
-        if (includes.isIncludeModelDefinition()) {
-            multiSearchRequestBuilder.add(
-                ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS)
-            );
-        }
+        SearchRequest trainedModelConfigSearch = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
+            .setQuery(queryBuilder)
+            // use sort to get the last
+            .addSort("_index", SortOrder.DESC)
+            .setSize(1)
+            .request();
 
-        ActionListener<MultiSearchResponse> multiSearchResponseActionListener = ActionListener.wrap(multiSearchResponse -> {
+        ActionListener<SearchResponse> trainedModelSearchHandler = ActionListener.wrap(modelSearchResponse -> {
             TrainedModelConfig.Builder builder;
             try {
-                builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseModelConfigLenientlyFromSource);
+                builder = handleHits(modelSearchResponse.getHits().getHits(), modelId, this::parseModelConfigLenientlyFromSource).get(0);
             } catch (ResourceNotFoundException ex) {
                 getTrainedModelListener.onFailure(
                     new ResourceNotFoundException(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))
@@ -631,46 +620,58 @@ public class TrainedModelProvider {
                 getTrainedModelListener.onFailure(ex);
                 return;
             }
-
-            if (includes.isIncludeModelDefinition()) {
-                try {
-                    List<TrainedModelDefinitionDoc> docs = handleSearchItems(
-                        multiSearchResponse.getResponses()[1],
+            if (includes.isIncludeModelDefinition() == false) {
+                getTrainedModelListener.onResponse(builder);
+                return;
+            }
+            if (builder.getModelType() == TrainedModelType.PYTORCH && includes.isIncludeModelDefinition()) {
+                finalListener.onFailure(
+                    ExceptionsHelper.badRequestException(
+                        "[{}] is type [{}] and does not support retrieving the definition",
                         modelId,
-                        (bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
-                            bytes,
-                            resourceId,
-                            xContentRegistry
-                        )
-                    );
+                        builder.getModelType()
+                    )
+                );
+                return;
+            }
+            executeAsyncWithOrigin(
+                client,
+                ML_ORIGIN,
+                SearchAction.INSTANCE,
+                ChunkedTrainedModelRestorer.buildSearch(client, modelId, InferenceIndexConstants.INDEX_PATTERN, MAX_NUM_DEFINITION_DOCS),
+                ActionListener.wrap(definitionSearchResponse -> {
                     try {
-                        BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
-                        builder.setDefinitionFromBytes(compressedData);
-                    } catch (ElasticsearchException elasticsearchException) {
-                        getTrainedModelListener.onFailure(elasticsearchException);
+                        List<TrainedModelDefinitionDoc> docs = handleHits(
+                            definitionSearchResponse.getHits().getHits(),
+                            modelId,
+                            (bytes, resourceId) -> ChunkedTrainedModelRestorer.parseModelDefinitionDocLenientlyFromSource(
+                                bytes,
+                                resourceId,
+                                xContentRegistry
+                            )
+                        );
+                        try {
+                            BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
+                            builder.setDefinitionFromBytes(compressedData);
+                        } catch (ElasticsearchException elasticsearchException) {
+                            getTrainedModelListener.onFailure(elasticsearchException);
+                            return;
+                        }
+
+                    } catch (ResourceNotFoundException ex) {
+                        getTrainedModelListener.onFailure(
+                            new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
+                        );
+                        return;
+                    } catch (Exception ex) {
+                        getTrainedModelListener.onFailure(ex);
                         return;
                     }
-
-                } catch (ResourceNotFoundException ex) {
-                    getTrainedModelListener.onFailure(
-                        new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))
-                    );
-                    return;
-                } catch (Exception ex) {
-                    getTrainedModelListener.onFailure(ex);
-                    return;
-                }
-            }
-            getTrainedModelListener.onResponse(builder);
+                    getTrainedModelListener.onResponse(builder);
+                }, getTrainedModelListener::onFailure)
+            );
         }, getTrainedModelListener::onFailure);
-
-        executeAsyncWithOrigin(
-            client,
-            ML_ORIGIN,
-            MultiSearchAction.INSTANCE,
-            multiSearchRequestBuilder.request(),
-            multiSearchResponseActionListener
-        );
+        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, trainedModelConfigSearch, trainedModelSearchHandler);
     }
 
     public void getTrainedModels(
@@ -1204,6 +1205,9 @@ public class TrainedModelProvider {
         String resourceId,
         CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently
     ) throws Exception {
+        if (hits.length == 0) {
+            throw new ResourceNotFoundException(resourceId);
+        }
         List<T> results = new ArrayList<>(hits.length);
         String initialIndex = hits[0].getIndex();
         for (SearchHit hit : hits) {