|
@@ -88,9 +88,11 @@ import java.util.Arrays;
|
|
|
import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.Comparator;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.Objects;
|
|
|
import java.util.Set;
|
|
|
import java.util.TreeSet;
|
|
|
import java.util.stream.Collectors;
|
|
@@ -234,14 +236,14 @@ public class TrainedModelProvider {
|
|
|
));
|
|
|
}
|
|
|
|
|
|
- public void getTrainedModelMetadata(String modelId, ActionListener<TrainedModelMetadata> listener) {
|
|
|
+ public void getTrainedModelMetadata(Collection<String> modelIds, ActionListener<Map<String, TrainedModelMetadata>> listener) {
|
|
|
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
|
.setQuery(QueryBuilders.constantScoreQuery(QueryBuilders
|
|
|
.boolQuery()
|
|
|
- .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
|
|
|
+ .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
|
|
|
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(),
|
|
|
TrainedModelMetadata.NAME))))
|
|
|
- .setSize(1)
|
|
|
+ .setSize(10_000)
|
|
|
// First find the latest index
|
|
|
.addSort("_index", SortOrder.DESC)
|
|
|
.request();
|
|
@@ -249,18 +251,20 @@ public class TrainedModelProvider {
|
|
|
searchResponse -> {
|
|
|
if (searchResponse.getHits().getHits().length == 0) {
|
|
|
listener.onFailure(new ResourceNotFoundException(
|
|
|
- Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
|
|
|
+ Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
|
|
|
return;
|
|
|
}
|
|
|
- List<TrainedModelMetadata> metadataList = handleHits(searchResponse.getHits().getHits(),
|
|
|
- modelId,
|
|
|
- this::parseMetadataLenientlyFromSource);
|
|
|
- listener.onResponse(metadataList.get(0));
|
|
|
+ HashMap<String, TrainedModelMetadata> map = new HashMap<>();
|
|
|
+ for (SearchHit hit : searchResponse.getHits().getHits()) {
|
|
|
+ String modelId = TrainedModelMetadata.modelId(Objects.requireNonNull(hit.getId()));
|
|
|
+ map.putIfAbsent(modelId, parseMetadataLenientlyFromSource(hit.getSourceRef(), modelId));
|
|
|
+ }
|
|
|
+ listener.onResponse(map);
|
|
|
},
|
|
|
e -> {
|
|
|
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
|
|
|
listener.onFailure(new ResourceNotFoundException(
|
|
|
- Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId)));
|
|
|
+ Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelIds)));
|
|
|
return;
|
|
|
}
|
|
|
listener.onFailure(e);
|
|
@@ -370,7 +374,7 @@ public class TrainedModelProvider {
|
|
|
// TODO Change this when we get more than just langIdent stored
|
|
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
|
|
try {
|
|
|
- TrainedModelConfig config = loadModelFromResource(modelId, false).ensureParsedDefinition(xContentRegistry);
|
|
|
+ TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
|
|
|
assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
|
|
|
listener.onResponse(
|
|
|
InferenceDefinition.builder()
|
|
@@ -433,18 +437,50 @@ public class TrainedModelProvider {
|
|
|
));
|
|
|
}
|
|
|
|
|
|
- public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
|
|
+ public void getTrainedModel(final String modelId,
|
|
|
+ final boolean includeDefinition,
|
|
|
+ final boolean includeTotalFeatureImportance,
|
|
|
+ final ActionListener<TrainedModelConfig> finalListener) {
|
|
|
|
|
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
|
|
try {
|
|
|
- listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
|
|
|
+ finalListener.onResponse(loadModelFromResource(modelId, includeDefinition == false).build());
|
|
|
return;
|
|
|
} catch (ElasticsearchException ex) {
|
|
|
- listener.onFailure(ex);
|
|
|
+ finalListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ ActionListener<TrainedModelConfig.Builder> getTrainedModelListener = ActionListener.wrap(
|
|
|
+ modelBuilder -> {
|
|
|
+ if (includeTotalFeatureImportance == false) {
|
|
|
+ finalListener.onResponse(modelBuilder.build());
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ this.getTrainedModelMetadata(Collections.singletonList(modelId), ActionListener.wrap(
|
|
|
+ metadata -> {
|
|
|
+ TrainedModelMetadata modelMetadata = metadata.get(modelId);
|
|
|
+ if (modelMetadata != null) {
|
|
|
+ modelBuilder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
|
|
+ }
|
|
|
+ finalListener.onResponse(modelBuilder.build());
|
|
|
+ },
|
|
|
+ failure -> {
|
|
|
+ // total feature importance is not necessary for a model to be valid
|
|
|
+ // we shouldn't fail if it is not found
|
|
|
+ if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
|
|
|
+ finalListener.onResponse(modelBuilder.build());
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ finalListener.onFailure(failure);
|
|
|
+ }
|
|
|
+ ));
|
|
|
+
|
|
|
+ },
|
|
|
+ finalListener::onFailure
|
|
|
+ );
|
|
|
+
|
|
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
|
|
.idsQuery()
|
|
|
.addIds(modelId));
|
|
@@ -482,11 +518,11 @@ public class TrainedModelProvider {
|
|
|
try {
|
|
|
builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource);
|
|
|
} catch (ResourceNotFoundException ex) {
|
|
|
- listener.onFailure(new ResourceNotFoundException(
|
|
|
+ getTrainedModelListener.onFailure(new ResourceNotFoundException(
|
|
|
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
|
|
|
return;
|
|
|
} catch (Exception ex) {
|
|
|
- listener.onFailure(ex);
|
|
|
+ getTrainedModelListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
|
|
@@ -499,22 +535,22 @@ public class TrainedModelProvider {
|
|
|
String compressedString = getDefinitionFromDocs(docs, modelId);
|
|
|
builder.setDefinitionFromString(compressedString);
|
|
|
} catch (ElasticsearchException elasticsearchException) {
|
|
|
- listener.onFailure(elasticsearchException);
|
|
|
+ getTrainedModelListener.onFailure(elasticsearchException);
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
} catch (ResourceNotFoundException ex) {
|
|
|
- listener.onFailure(new ResourceNotFoundException(
|
|
|
+ getTrainedModelListener.onFailure(new ResourceNotFoundException(
|
|
|
Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
|
|
|
return;
|
|
|
} catch (Exception ex) {
|
|
|
- listener.onFailure(ex);
|
|
|
+ getTrainedModelListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
}
|
|
|
- listener.onResponse(builder.build());
|
|
|
+ getTrainedModelListener.onResponse(builder);
|
|
|
},
|
|
|
- listener::onFailure
|
|
|
+ getTrainedModelListener::onFailure
|
|
|
);
|
|
|
|
|
|
executeAsyncWithOrigin(client,
|
|
@@ -531,7 +567,10 @@ public class TrainedModelProvider {
|
|
|
* This does no expansion on the ids.
|
|
|
* It assumes that there are fewer than 10k.
|
|
|
*/
|
|
|
- public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
|
|
|
+ public void getTrainedModels(Set<String> modelIds,
|
|
|
+ boolean allowNoResources,
|
|
|
+ boolean includeTotalFeatureImportance,
|
|
|
+ final ActionListener<List<TrainedModelConfig>> finalListener) {
|
|
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
|
|
|
|
|
|
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
|
|
@@ -540,23 +579,63 @@ public class TrainedModelProvider {
|
|
|
.setQuery(queryBuilder)
|
|
|
.setSize(modelIds.size())
|
|
|
.request();
|
|
|
- List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
|
|
|
+ List<TrainedModelConfig.Builder> configs = new ArrayList<>(modelIds.size());
|
|
|
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
|
|
|
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
|
|
|
for(String modelId : modelsAsResource) {
|
|
|
try {
|
|
|
configs.add(loadModelFromResource(modelId, true));
|
|
|
} catch (ElasticsearchException ex) {
|
|
|
- listener.onFailure(ex);
|
|
|
+ finalListener.onFailure(ex);
|
|
|
return;
|
|
|
}
|
|
|
}
|
|
|
if (modelsInIndex.isEmpty()) {
|
|
|
- configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
|
|
- listener.onResponse(configs);
|
|
|
+ finalListener.onResponse(configs.stream()
|
|
|
+ .map(TrainedModelConfig.Builder::build)
|
|
|
+ .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
|
|
+ .collect(Collectors.toList()));
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
+ ActionListener<List<TrainedModelConfig.Builder>> getTrainedModelListener = ActionListener.wrap(
|
|
|
+ modelBuilders -> {
|
|
|
+ if (includeTotalFeatureImportance == false) {
|
|
|
+ finalListener.onResponse(modelBuilders.stream()
|
|
|
+ .map(TrainedModelConfig.Builder::build)
|
|
|
+ .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
|
|
+ .collect(Collectors.toList()));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ this.getTrainedModelMetadata(modelIds, ActionListener.wrap(
|
|
|
+ metadata ->
|
|
|
+ finalListener.onResponse(modelBuilders.stream()
|
|
|
+ .map(builder -> {
|
|
|
+ TrainedModelMetadata modelMetadata = metadata.get(builder.getModelId());
|
|
|
+ if (modelMetadata != null) {
|
|
|
+ builder.setFeatureImportance(modelMetadata.getTotalFeatureImportances());
|
|
|
+ }
|
|
|
+ return builder.build();
|
|
|
+ })
|
|
|
+ .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
|
|
+ .collect(Collectors.toList())),
|
|
|
+ failure -> {
|
|
|
+ // total feature importance is not necessary for a model to be valid
|
|
|
+ // we shouldn't fail if it is not found
|
|
|
+ if (ExceptionsHelper.unwrapCause(failure) instanceof ResourceNotFoundException) {
|
|
|
+ finalListener.onResponse(modelBuilders.stream()
|
|
|
+ .map(TrainedModelConfig.Builder::build)
|
|
|
+ .sorted(Comparator.comparing(TrainedModelConfig::getModelId))
|
|
|
+ .collect(Collectors.toList()));
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ finalListener.onFailure(failure);
|
|
|
+ }
|
|
|
+ ));
|
|
|
+ },
|
|
|
+ finalListener::onFailure
|
|
|
+ );
|
|
|
+
|
|
|
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
|
|
searchResponse -> {
|
|
|
Set<String> observedIds = new HashSet<>(
|
|
@@ -567,12 +646,12 @@ public class TrainedModelProvider {
|
|
|
try {
|
|
|
if (observedIds.contains(searchHit.getId()) == false) {
|
|
|
configs.add(
|
|
|
- parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
|
|
|
+ parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId())
|
|
|
);
|
|
|
observedIds.add(searchHit.getId());
|
|
|
}
|
|
|
} catch (IOException ex) {
|
|
|
- listener.onFailure(
|
|
|
+ getTrainedModelListener.onFailure(
|
|
|
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
|
|
|
return;
|
|
|
}
|
|
@@ -582,14 +661,13 @@ public class TrainedModelProvider {
|
|
|
// Otherwise, treat it as if it was never expanded to begin with.
|
|
|
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
|
|
|
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
|
|
|
- listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
|
|
+ getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
|
|
return;
|
|
|
}
|
|
|
// Ensure sorted even with the injection of locally resourced models
|
|
|
- configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
|
|
- listener.onResponse(configs);
|
|
|
+ getTrainedModelListener.onResponse(configs);
|
|
|
},
|
|
|
- listener::onFailure
|
|
|
+ getTrainedModelListener::onFailure
|
|
|
);
|
|
|
|
|
|
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
|
|
@@ -638,7 +716,7 @@ public class TrainedModelProvider {
|
|
|
foundResourceIds = new HashSet<>();
|
|
|
for(String resourceId : matchedResourceIds) {
|
|
|
// Does the model as a resource have all the tags?
|
|
|
- if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
|
|
|
+ if (Sets.newHashSet(loadModelFromResource(resourceId, true).build().getTags()).containsAll(tags)) {
|
|
|
foundResourceIds.add(resourceId);
|
|
|
}
|
|
|
}
|
|
@@ -832,7 +910,7 @@ public class TrainedModelProvider {
|
|
|
return QueryBuilders.constantScoreQuery(boolQueryBuilder);
|
|
|
}
|
|
|
|
|
|
- TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
|
|
+ TrainedModelConfig.Builder loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
|
|
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
|
|
|
if (resource == null) {
|
|
|
logger.error("[{}] presumed stored as a resource but not found", modelId);
|
|
@@ -847,7 +925,7 @@ public class TrainedModelProvider {
|
|
|
if (nullOutDefinition) {
|
|
|
builder.clearDefinition();
|
|
|
}
|
|
|
- return builder.build();
|
|
|
+ return builder;
|
|
|
} catch (IOException ioEx) {
|
|
|
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
|
|
|
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
|