
[ML][Inference] Add support for models shipped as resources (#50680)

This adds support for models that are shipped as resources in the ML plugin, the first of which is the `lang_ident` model.
Benjamin Trent 5 years ago
parent
commit
d1f317be22
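
For context before the file-by-file diff: the core of the change is the new TrainedModelProvider#loadModelFromResource (see the TrainedModelProvider.java hunks below), which reads a model such as lang_ident_model_1 as a JSON classpath resource, parses it with the ML NamedXContentRegistry, and optionally drops the parsed definition. The following is a minimal sketch of that flow; it is not part of the commit itself, the class name ResourceModelLoaderSketch is illustrative, error handling and logging are omitted, and only APIs that appear in the diff below are used.

    // Sketch only: mirrors TrainedModelProvider#loadModelFromResource from this commit.
    import org.elasticsearch.common.bytes.BytesReference;
    import org.elasticsearch.common.io.Streams;
    import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
    import org.elasticsearch.common.xcontent.NamedXContentRegistry;
    import org.elasticsearch.common.xcontent.XContentHelper;
    import org.elasticsearch.common.xcontent.XContentParser;
    import org.elasticsearch.common.xcontent.XContentType;
    import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;

    import java.io.InputStream;

    final class ResourceModelLoaderSketch {

        // Resource path used by the commit for the bundled language identification model.
        private static final String RESOURCE_PATH =
            "/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json";

        static TrainedModelConfig load(NamedXContentRegistry registry, boolean includeDefinition) throws Exception {
            try (InputStream stream = ResourceModelLoaderSketch.class.getResourceAsStream(RESOURCE_PATH)) {
                // Read the bundled JSON fully into memory.
                BytesReference bytes = Streams.readFully(stream);
                try (XContentParser parser = XContentHelper.createParser(
                        registry, LoggingDeprecationHandler.INSTANCE, bytes, XContentType.JSON)) {
                    TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
                    if (includeDefinition == false) {
                        // Drop the (large) model definition when the caller did not ask for it.
                        builder.clearDefinition();
                    }
                    return builder.build();
                }
            }
        }
    }

In the provider changes below, getTrainedModel short-circuits to this resource path for any id in MODELS_STORED_AS_RESOURCE, while storeTrainedModel and deleteTrainedModel reject those ids outright, so the bundled models behave as read-only.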

+ 5 - 5
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -2168,8 +2168,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             GetTrainedModelsResponse getTrainedModelsResponse = execute(
                 GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(),
                 machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
-            assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels));
-            assertThat(getTrainedModelsResponse.getCount(), equalTo(5L));
+            assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels + 1));
+            assertThat(getTrainedModelsResponse.getCount(), equalTo(5L + 1));
         }
         {
             GetTrainedModelsResponse getTrainedModelsResponse = execute(
@@ -2192,7 +2192,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
 
     public void testGetTrainedModelsStats() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
-        String modelIdPrefix = "get-trained-model-stats-";
+        String modelIdPrefix = "a-get-trained-model-stats-";
         int numberOfModels = 5;
         for (int i = 0; i < numberOfModels; ++i) {
             String modelId = modelIdPrefix + i;
@@ -2224,8 +2224,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
             GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
                 GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
                 machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
-            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
-            assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
+            assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels + 1));
+            assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L + 1));
             assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
             assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
         }

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -98,6 +99,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 pipelineCount = in.readVInt();
             }
 
+            public String getModelId() {
+                return modelId;
+            }
+
             @Override
             public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                 builder.startObject();
@@ -186,6 +191,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                         0 :
                         ingestStats.getPipelineStats().size()));
                 });
+                trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
                 return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
             }
         }

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

@@ -409,6 +409,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder clearDefinition() {
+            this.definition = null;
+            return this;
+        }
+
         private Builder setLazyDefinition(TrainedModelDefinition.Builder parsedTrainedModel) {
             if (parsedTrainedModel == null) {
                 return this;

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

@@ -87,10 +87,12 @@ public final class Messages {
     public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
         "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
     public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
+    public static final String INFERENCE_CANNOT_DELETE_MODEL =
+        "Unable to delete model [{0}]";
     public static final String MODEL_DEFINITION_TRUNCATED =
         "Model definition truncated. Unable to deserialize trained model definition [{0}]";
     public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
-    public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
+    public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
         "Getting model definition is not supported when getting more than one model";
     public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
 

+ 26 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@@ -232,6 +232,32 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
             containsString("Could not find trained model [test_classification_missing]"));
     }
 
+    public void testSimulateLangIdent() {
+        String source = "{\n" +
+            "  \"pipeline\": {\n" +
+            "    \"processors\": [\n" +
+            "      {\n" +
+            "        \"inference\": {\n" +
+            "          \"inference_config\": {\"classification\":{}},\n" +
+            "          \"model_id\": \"lang_ident_model_1\",\n" +
+            "          \"field_mappings\": {}\n" +
+            "        }\n" +
+            "      }\n" +
+            "    ]\n" +
+            "  },\n" +
+            "  \"docs\": [\n" +
+            "    {\"_source\": {\n" +
+            "      \"text\": \"this is some plain text.\"\n" +
+            "    }}]\n" +
+            "}";
+
+        SimulatePipelineResponse response = client().admin().cluster()
+            .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
+                XContentType.JSON).get();
+        SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
+        assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
+    }
+
     private Map<String, Object> generateSourceDoc() {
         return new HashMap<>(){{
             put("col1", randomFrom("female", "male"));

+ 21 - 21
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@@ -60,8 +60,8 @@ public class TrainedModelIT extends ESRestTestCase {
     }
 
     public void testGetTrainedModels() throws IOException {
-        String modelId = "test_regression_model";
-        String modelId2 = "test_regression_model-2";
+        String modelId = "a_test_regression_model";
+        String modelId2 = "a_test_regression_model-2";
         Request model1 = new Request("PUT",
             InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
         model1.setJsonEntity(buildRegressionModel(modelId));
@@ -84,36 +84,36 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
         String response = EntityUtils.toString(getModel.getEntity());
 
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
         assertThat(response, containsString("\"count\":1"));
 
         getModel = client().performRequest(new Request("GET",
-            MachineLearning.BASE_PATH + "inference/test_regression*"));
+            MachineLearning.BASE_PATH + "inference/a_test_regression*"));
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-        assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
         assertThat(response, not(containsString("\"definition\"")));
         assertThat(response, containsString("\"count\":2"));
 
         getModel = client().performRequest(new Request("GET",
-            MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
+            MachineLearning.BASE_PATH + "inference/a_test_regression_model?human=true&include_model_definition=true"));
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
         assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
         assertThat(response, containsString("\"estimated_heap_memory_usage\""));
         assertThat(response, containsString("\"definition\""));
         assertThat(response, containsString("\"count\":1"));
 
         getModel = client().performRequest(new Request("GET",
-            MachineLearning.BASE_PATH + "inference/test_regression_model?decompress_definition=false&include_model_definition=true"));
+            MachineLearning.BASE_PATH + "inference/a_test_regression_model?decompress_definition=false&include_model_definition=true"));
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
         assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
         assertThat(response, containsString("\"compressed_definition\""));
         assertThat(response, not(containsString("\"definition\"")));
@@ -121,17 +121,17 @@ public class TrainedModelIT extends ESRestTestCase {
 
         ResponseException responseException = expectThrows(ResponseException.class, () ->
             client().performRequest(new Request("GET",
-                MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
+                MachineLearning.BASE_PATH + "inference/a_test_regression*?human=true&include_model_definition=true")));
         assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
-            containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
+            containsString(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED));
 
         getModel = client().performRequest(new Request("GET",
-            MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
+            MachineLearning.BASE_PATH + "inference/a_test_regression_model,a_test_regression_model-2"));
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-        assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
         assertThat(response, containsString("\"count\":2"));
 
         getModel = client().performRequest(new Request("GET",
@@ -149,17 +149,17 @@ public class TrainedModelIT extends ESRestTestCase {
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"count\":2"));
-        assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
-        assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\"")));
+        assertThat(response, containsString("\"count\":3"));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
+        assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model-2\"")));
 
         getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
         assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
 
         response = EntityUtils.toString(getModel.getEntity());
-        assertThat(response, containsString("\"count\":2"));
-        assertThat(response, not(containsString("\"model_id\":\"test_regression_model\"")));
-        assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
+        assertThat(response, containsString("\"count\":3"));
+        assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model\"")));
+        assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
     }
 
     public void testDeleteTrainedModels() throws IOException {

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java

@@ -50,7 +50,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
 
                 if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
                     listener.onFailure(
-                        ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)
+                        ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)
                     );
                     return;
                 }

+ 107 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.persistence;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.ResourceNotFoundException;
@@ -31,6 +32,7 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
@@ -39,6 +41,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.IndexNotFoundException;
@@ -64,8 +67,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.net.URL;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -78,6 +83,10 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FA
 
 public class TrainedModelProvider {
 
+    public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
+    private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
+    private static final String MODEL_RESOURCE_FILE_EXT = ".json";
+
     private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
     private final Client client;
     private final NamedXContentRegistry xContentRegistry;
@@ -91,6 +100,12 @@ public class TrainedModelProvider {
 
     public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
                                   ActionListener<Boolean> listener) {
+        if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
+            listener.onFailure(new ResourceAlreadyExistsException(
+                Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
+            return;
+        }
+
         try {
             trainedModelConfig.ensureParsedDefinition(xContentRegistry);
         } catch (IOException ex) {
@@ -184,6 +199,16 @@ public class TrainedModelProvider {
 
     public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
 
+        if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
+            try {
+                listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
+                return;
+            } catch (ElasticsearchException ex) {
+                listener.onFailure(ex);
+                return;
+            }
+        }
+
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
             .idsQuery()
             .addIds(modelId));
@@ -267,11 +292,29 @@ public class TrainedModelProvider {
             .addSort("_index", SortOrder.DESC)
             .setQuery(queryBuilder)
             .request();
+        List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
+        Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
+        Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
+        for(String modelId : modelsAsResource) {
+            try {
+                configs.add(loadModelFromResource(modelId, true));
+            } catch (ElasticsearchException ex) {
+                listener.onFailure(ex);
+                return;
+            }
+        }
+        if (modelsInIndex.isEmpty()) {
+            configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
+            listener.onResponse(configs);
+            return;
+        }
 
         ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
             searchResponse -> {
-                Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
-                List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
+                Set<String> observedIds = new HashSet<>(
+                    searchResponse.getHits().getHits().length + modelsAsResource.size(),
+                    1.0f);
+                observedIds.addAll(modelsAsResource);
                 for(SearchHit searchHit : searchResponse.getHits().getHits()) {
                     try {
                         if (observedIds.contains(searchHit.getId()) == false) {
@@ -294,6 +337,8 @@ public class TrainedModelProvider {
                     listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
                     return;
                 }
+                // Ensure sorted even with the injection of locally resourced models
+                configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
                 listener.onResponse(configs);
             },
             listener::onFailure
@@ -303,6 +348,10 @@ public class TrainedModelProvider {
     }
 
     public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
+        if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
+            listener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, modelId)));
+            return;
+        }
         DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
 
         request.indices(InferenceIndexConstants.INDEX_PATTERN);
@@ -359,8 +408,8 @@ public class TrainedModelProvider {
             searchRequest,
             ActionListener.<SearchResponse>wrap(
                 response -> {
-                    Set<String> foundResourceIds = new LinkedHashSet<>();
-                    long totalHitCount = response.getHits().getTotalHits().value;
+                    Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
+                    long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
                     for (SearchHit hit : response.getHits().getHits()) {
                         Map<String, Object> docSource = hit.getSourceAsMap();
                         if (docSource == null) {
@@ -385,6 +434,37 @@ public class TrainedModelProvider {
 
     }
 
+    TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
+        URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
+        if (resource == null) {
+            logger.error("[{}] presumed stored as a resource but not found", modelId);
+            throw new ResourceNotFoundException(
+                Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId));
+        }
+        try {
+            BytesReference bytes = Streams.readFully(getClass()
+                .getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT));
+            try (XContentParser parser =
+                     XContentHelper.createParser(xContentRegistry,
+                         LoggingDeprecationHandler.INSTANCE,
+                         bytes,
+                         XContentType.JSON)) {
+                TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
+                if (nullOutDefinition) {
+                    builder.clearDefinition();
+                }
+                return builder.build();
+            } catch (IOException ioEx) {
+                logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
+                throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
+            }
+        } catch (IOException ex) {
+            String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage();
+            logger.error(msg, ex);
+            throw ExceptionsHelper.serverError(msg, ex);
+        }
+    }
+
     private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
         BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
             .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
@@ -413,6 +493,29 @@ public class TrainedModelProvider {
         return boolQuery;
     }
 
+    private Set<String> matchedResourceIds(String[] tokens) {
+        if (Strings.isAllOrWildcard(tokens)) {
+            return new HashSet<>(MODELS_STORED_AS_RESOURCE);
+        }
+
+        Set<String> matchedModels = new HashSet<>();
+
+        for (String token : tokens) {
+            if (Regex.isSimpleMatchPattern(token)) {
+                for (String modelId : MODELS_STORED_AS_RESOURCE) {
+                    if(Regex.simpleMatch(token, modelId)) {
+                        matchedModels.add(modelId);
+                    }
+                }
+            } else {
+                if (MODELS_STORED_AS_RESOURCE.contains(token)) {
+                    matchedModels.add(token);
+                }
+            }
+        }
+        return matchedModels;
+    }
+
     private static <T> T handleSearchItem(MultiSearchResponse.Item item,
                                           String resourceId,
                                           CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {

+ 74 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.inference.persistence;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+
+public class TrainedModelProviderTests extends ESTestCase {
+
+    public void testDeleteModelStoredAsResource() {
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        PlainActionFuture<Boolean> future = new PlainActionFuture<>();
+        // Should be OK as we don't make any client calls
+        trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future);
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
+        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, "lang_ident_model_1")));
+    }
+
+    public void testPutModelThatExistsAsResource() {
+        TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build();
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        PlainActionFuture<Boolean> future = new PlainActionFuture<>();
+        trainedModelProvider.storeTrainedModel(config, future);
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
+        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, "lang_ident_model_1")));
+    }
+
+    public void testGetModelThatExistsAsResource() throws Exception {
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
+            PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
+            trainedModelProvider.getTrainedModel(modelId, true, future);
+            TrainedModelConfig configWithDefinition = future.actionGet();
+
+            assertThat(configWithDefinition.getModelId(), equalTo(modelId));
+            assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
+
+            PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
+            trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
+            TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
+
+            assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
+            assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
+        }
+    }
+
+    public void testGetModelThatExistsAsResourceButIsMissing() {
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        ElasticsearchException ex = expectThrows(ElasticsearchException.class,
+            () -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean()));
+        assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model")));
+    }
+
+    @Override
+    protected NamedXContentRegistry xContentRegistry() {
+        return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
+    }
+}

+ 10 - 20
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java

@@ -5,10 +5,9 @@
  */
 package org.elasticsearch.xpack.ml.inference.trainedmodels.langident;
 
-import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -16,22 +15,26 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
-import java.io.IOException;
-import java.nio.file.Files;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.Matchers.closeTo;
-
+import static org.mockito.Mockito.mock;
 
 public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
 
     public void testLangInference() throws Exception {
-        TrainedModelConfig config = getLangIdentModel();
+        TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
+        PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
+        // Should be OK as we don't make any client calls
+        trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
+        TrainedModelConfig config = future.actionGet();
 
+        config.ensureParsedDefinition(xContentRegistry());
         TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
         List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
         ClassificationConfig classificationConfig = new ClassificationConfig(1);
@@ -53,19 +56,6 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
         }
     }
 
-    private TrainedModelConfig getLangIdentModel() throws IOException {
-        String path = "/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json";
-        try(XContentParser parser =
-                XContentType.JSON.xContent().createParser(
-                    NamedXContentRegistry.EMPTY,
-                    DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
-                    Files.newInputStream(getDataPath(path)))) {
-            TrainedModelConfig config = TrainedModelConfig.fromXContent(parser, true).build();
-            config.ensureParsedDefinition(xContentRegistry());
-            return config;
-        }
-    }
-
     @Override
     protected NamedXContentRegistry xContentRegistry() {
         return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());

+ 8 - 15
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -1,18 +1,3 @@
----
-"Test get-all given no trained models exist":
-
-  - do:
-      ml.get_trained_models:
-        model_id: "_all"
-  - match: { count: 0 }
-  - match: { trained_model_configs: [] }
-
-  - do:
-      ml.get_trained_models:
-        model_id: "*"
-  - match: { count: 0 }
-  - match: { trained_model_configs: [] }
-
 ---
 "Test get given missing trained model":
 
@@ -111,3 +96,11 @@
       catch: conflict
       ml.delete_trained_model:
         model_id: "used-regression-model"
+---
+"Test get pre-packaged trained models":
+  - do:
+      ml.get_trained_models:
+        model_id: "_all"
+        allow_no_match: false
+  - match: { count: 1 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }

+ 24 - 30
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml

@@ -5,17 +5,15 @@ setup:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
-        id: trained_model_config-unused-regression-model1-0
+        id: trained_model_config-a-unused-regression-model1-0
         index: .ml-inference-000001
         body: >
           {
-            "model_id": "unused-regression-model1",
+            "model_id": "a-unused-regression-model1",
             "created_by": "ml_tests",
             "version": "8.0.0",
             "description": "empty model for tests",
             "create_time": 0,
-            "model_version": 0,
-            "model_type": "local",
             "doc_type": "trained_model_config"
           }
 
@@ -23,34 +21,30 @@ setup:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
-        id: trained_model_config-unused-regression-model-0
+        id: trained_model_config-a-unused-regression-model-0
         index: .ml-inference-000001
         body: >
           {
-            "model_id": "unused-regression-model",
+            "model_id": "a-unused-regression-model",
             "created_by": "ml_tests",
             "version": "8.0.0",
             "description": "empty model for tests",
             "create_time": 0,
-            "model_version": 0,
-            "model_type": "local",
             "doc_type": "trained_model_config"
           }
   - do:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
       index:
-        id: trained_model_config-used-regression-model-0
+        id: trained_model_config-a-used-regression-model-0
         index: .ml-inference-000001
         body: >
           {
-            "model_id": "used-regression-model",
+            "model_id": "a-used-regression-model",
             "created_by": "ml_tests",
             "version": "8.0.0",
             "description": "empty model for tests",
             "create_time": 0,
-            "model_version": 0,
-            "model_type": "local",
             "doc_type": "trained_model_config"
           }
 
@@ -69,7 +63,7 @@ setup:
             "processors": [
               {
                 "inference" : {
-                  "model_id" : "used-regression-model",
+                  "model_id" : "a-used-regression-model",
                   "inference_config": {"regression": {}},
                   "target_field": "regression_field",
                   "field_mappings": {}
@@ -87,7 +81,7 @@ setup:
             "processors": [
               {
                 "inference" : {
-                  "model_id" : "used-regression-model",
+                  "model_id" : "a-used-regression-model",
                   "inference_config": {"regression": {}},
                   "target_field": "regression_field",
                   "field_mappings": {}
@@ -125,18 +119,18 @@ setup:
 
   - do:
       ml.get_trained_models_stats:
-        model_id: "unused-regression-model"
+        model_id: "a-unused-regression-model"
 
   - match: { count: 1 }
 
   - do:
       ml.get_trained_models_stats:
         model_id: "_all"
-  - match: { count: 3 }
-  - match: { trained_model_stats.0.model_id: unused-regression-model }
+  - match: { count: 4 }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 0 }
   - is_false: trained_model_stats.0.ingest
-  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
+  - match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
   - match: { trained_model_stats.1.pipeline_count: 0 }
   - is_false: trained_model_stats.1.ingest
   - match: { trained_model_stats.2.pipeline_count: 2 }
@@ -145,11 +139,11 @@ setup:
   - do:
       ml.get_trained_models_stats:
         model_id: "*"
-  - match: { count: 3 }
-  - match: { trained_model_stats.0.model_id: unused-regression-model }
+  - match: { count: 4 }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 0 }
   - is_false: trained_model_stats.0.ingest
-  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
+  - match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
   - match: { trained_model_stats.1.pipeline_count: 0 }
   - is_false: trained_model_stats.1.ingest
   - match: { trained_model_stats.2.pipeline_count: 2 }
@@ -157,40 +151,40 @@ setup:
 
   - do:
       ml.get_trained_models_stats:
-        model_id: "unused-regression-model*"
+        model_id: "a-unused-regression-model*"
   - match: { count: 2 }
-  - match: { trained_model_stats.0.model_id: unused-regression-model }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 0 }
   - is_false: trained_model_stats.0.ingest
-  - match: { trained_model_stats.1.model_id: unused-regression-model1 }
+  - match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
   - match: { trained_model_stats.1.pipeline_count: 0 }
   - is_false: trained_model_stats.1.ingest
 
   - do:
       ml.get_trained_models_stats:
-        model_id: "unused-regression-model*"
+        model_id: "a-unused-regression-model*"
         size: 1
   - match: { count: 2 }
-  - match: { trained_model_stats.0.model_id: unused-regression-model }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 0 }
   - is_false: trained_model_stats.0.ingest
 
   - do:
       ml.get_trained_models_stats:
-        model_id: "unused-regression-model*"
+        model_id: "a-unused-regression-model*"
         from: 1
         size: 1
   - match: { count: 2 }
-  - match: { trained_model_stats.0.model_id: unused-regression-model1 }
+  - match: { trained_model_stats.0.model_id: a-unused-regression-model1 }
   - match: { trained_model_stats.0.pipeline_count: 0 }
   - is_false: trained_model_stats.0.ingest
 
   - do:
       ml.get_trained_models_stats:
-        model_id: "used-regression-model"
+        model_id: "a-used-regression-model"
 
   - match: { count: 1 }
-  - match: { trained_model_stats.0.model_id: used-regression-model }
+  - match: { trained_model_stats.0.model_id: a-used-regression-model }
   - match: { trained_model_stats.0.pipeline_count: 2 }
   - match:
       trained_model_stats.0.ingest.total: