|
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.persistence;
|
|
|
import org.apache.logging.log4j.LogManager;
|
|
|
import org.apache.logging.log4j.Logger;
|
|
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
|
|
+import org.elasticsearch.ElasticsearchException;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.ResourceAlreadyExistsException;
|
|
|
import org.elasticsearch.ResourceNotFoundException;
|
|
@@ -31,6 +32,7 @@ import org.elasticsearch.common.Nullable;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
|
+import org.elasticsearch.common.io.Streams;
|
|
|
import org.elasticsearch.common.regex.Regex;
|
|
|
import org.elasticsearch.common.util.set.Sets;
|
|
|
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
|
@@ -39,6 +41,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
|
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
|
|
+import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
|
import org.elasticsearch.index.IndexNotFoundException;
|
|
@@ -64,8 +67,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.io.InputStream;
|
|
|
+import java.net.URL;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.Comparator;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.LinkedHashSet;
|
|
|
import java.util.List;
|
|
@@ -78,6 +83,10 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FA
|
|
|
|
|
|
public class TrainedModelProvider {
|
|
|
|
|
|
+ public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
|
|
|
+ private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
|
|
|
+ private static final String MODEL_RESOURCE_FILE_EXT = ".json";
|
|
|
+
|
|
|
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
|
|
private final Client client;
|
|
|
private final NamedXContentRegistry xContentRegistry;
|
|
@@ -91,6 +100,12 @@ public class TrainedModelProvider {
|
|
|
|
|
|
public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
|
|
|
ActionListener<Boolean> listener) {
|
|
|
+ if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
|
|
|
+ listener.onFailure(new ResourceAlreadyExistsException(
|
|
|
+ Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
try {
|
|
|
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
|
|
|
} catch (IOException ex) {
|
|
@@ -184,6 +199,16 @@ public class TrainedModelProvider {
|
|
|
|
|
|
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
|
|
|
|
|
+ if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
|
|
+ try {
|
|
|
+ listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
|
|
|
+ return;
|
|
|
+ } catch (ElasticsearchException ex) {
|
|
|
+ listener.onFailure(ex);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
|
|
.idsQuery()
|
|
|
.addIds(modelId));
|
|
@@ -267,11 +292,29 @@ public class TrainedModelProvider {
|
|
|
.addSort("_index", SortOrder.DESC)
|
|
|
.setQuery(queryBuilder)
|
|
|
.request();
|
|
|
+ List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
|
|
|
+ Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
|
|
|
+ Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
|
|
|
+ for(String modelId : modelsAsResource) {
|
|
|
+ try {
|
|
|
+ configs.add(loadModelFromResource(modelId, true));
|
|
|
+ } catch (ElasticsearchException ex) {
|
|
|
+ listener.onFailure(ex);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (modelsInIndex.isEmpty()) {
|
|
|
+ configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
|
|
+ listener.onResponse(configs);
|
|
|
+ return;
|
|
|
+ }
|
|
|
|
|
|
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
|
|
searchResponse -> {
|
|
|
- Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
|
|
|
- List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
|
|
|
+ Set<String> observedIds = new HashSet<>(
|
|
|
+ searchResponse.getHits().getHits().length + modelsAsResource.size(),
|
|
|
+ 1.0f);
|
|
|
+ observedIds.addAll(modelsAsResource);
|
|
|
for(SearchHit searchHit : searchResponse.getHits().getHits()) {
|
|
|
try {
|
|
|
if (observedIds.contains(searchHit.getId()) == false) {
|
|
@@ -294,6 +337,8 @@ public class TrainedModelProvider {
|
|
|
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
|
|
return;
|
|
|
}
|
|
|
+ // Ensure sorted even with the injection of locally resourced models
|
|
|
+ configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
|
|
listener.onResponse(configs);
|
|
|
},
|
|
|
listener::onFailure
|
|
@@ -303,6 +348,10 @@ public class TrainedModelProvider {
|
|
|
}
|
|
|
|
|
|
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
|
|
|
+ if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
|
|
+ listener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, modelId)));
|
|
|
+ return;
|
|
|
+ }
|
|
|
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
|
|
|
|
|
|
request.indices(InferenceIndexConstants.INDEX_PATTERN);
|
|
@@ -359,8 +408,8 @@ public class TrainedModelProvider {
|
|
|
searchRequest,
|
|
|
ActionListener.<SearchResponse>wrap(
|
|
|
response -> {
|
|
|
- Set<String> foundResourceIds = new LinkedHashSet<>();
|
|
|
- long totalHitCount = response.getHits().getTotalHits().value;
|
|
|
+ Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
|
|
|
+ long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
|
|
|
for (SearchHit hit : response.getHits().getHits()) {
|
|
|
Map<String, Object> docSource = hit.getSourceAsMap();
|
|
|
if (docSource == null) {
|
|
@@ -385,6 +434,37 @@ public class TrainedModelProvider {
|
|
|
|
|
|
}
|
|
|
|
|
|
+ TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
|
|
+ URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
|
|
|
+ if (resource == null) {
|
|
|
+ logger.error("[{}] presumed stored as a resource but not found", modelId);
|
|
|
+ throw new ResourceNotFoundException(
|
|
|
+ Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId));
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ BytesReference bytes = Streams.readFully(getClass()
|
|
|
+ .getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT));
|
|
|
+ try (XContentParser parser =
|
|
|
+ XContentHelper.createParser(xContentRegistry,
|
|
|
+ LoggingDeprecationHandler.INSTANCE,
|
|
|
+ bytes,
|
|
|
+ XContentType.JSON)) {
|
|
|
+ TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
|
|
|
+ if (nullOutDefinition) {
|
|
|
+ builder.clearDefinition();
|
|
|
+ }
|
|
|
+ return builder.build();
|
|
|
+ } catch (IOException ioEx) {
|
|
|
+ logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
|
|
|
+ throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
|
|
|
+ }
|
|
|
+ } catch (IOException ex) {
|
|
|
+ String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage();
|
|
|
+ logger.error(msg, ex);
|
|
|
+ throw ExceptionsHelper.serverError(msg, ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
|
|
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
|
|
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
|
|
@@ -413,6 +493,29 @@ public class TrainedModelProvider {
|
|
|
return boolQuery;
|
|
|
}
|
|
|
|
|
|
+ private Set<String> matchedResourceIds(String[] tokens) {
|
|
|
+ if (Strings.isAllOrWildcard(tokens)) {
|
|
|
+ return new HashSet<>(MODELS_STORED_AS_RESOURCE);
|
|
|
+ }
|
|
|
+
|
|
|
+ Set<String> matchedModels = new HashSet<>();
|
|
|
+
|
|
|
+ for (String token : tokens) {
|
|
|
+ if (Regex.isSimpleMatchPattern(token)) {
|
|
|
+ for (String modelId : MODELS_STORED_AS_RESOURCE) {
|
|
|
+ if(Regex.simpleMatch(token, modelId)) {
|
|
|
+ matchedModels.add(modelId);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if (MODELS_STORED_AS_RESOURCE.contains(token)) {
|
|
|
+ matchedModels.add(token);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return matchedModels;
|
|
|
+ }
|
|
|
+
|
|
|
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
|
|
String resourceId,
|
|
|
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|