|
@@ -69,6 +69,7 @@ import java.io.IOException;
|
|
|
import java.io.InputStream;
|
|
|
import java.net.URL;
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.Comparator;
|
|
|
import java.util.HashSet;
|
|
@@ -381,6 +382,7 @@ public class TrainedModelProvider {
|
|
|
public void expandIds(String idExpression,
|
|
|
boolean allowNoResources,
|
|
|
@Nullable PageParams pageParams,
|
|
|
+ Set<String> tags,
|
|
|
ActionListener<Tuple<Long, Set<String>>> idsListener) {
|
|
|
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
|
|
|
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
|
|
@@ -388,7 +390,7 @@ public class TrainedModelProvider {
|
|
|
// If there are no resources, there might be no mapping for the id field.
|
|
|
// This makes sure we don't get an error if that happens.
|
|
|
.unmappedType("long"))
|
|
|
- .query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
|
|
|
+ .query(buildExpandIdsQuery(tokens, tags));
|
|
|
if (pageParams != null) {
|
|
|
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
|
|
|
}
|
|
@@ -404,13 +406,23 @@ public class TrainedModelProvider {
|
|
|
indicesOptions.expandWildcardsClosed(),
|
|
|
indicesOptions))
|
|
|
.source(sourceBuilder);
|
|
|
+ Set<String> foundResourceIds = new LinkedHashSet<>();
|
|
|
+ if (tags.isEmpty()) {
|
|
|
+ foundResourceIds.addAll(matchedResourceIds(tokens));
|
|
|
+ } else {
|
|
|
+ for(String resourceId : matchedResourceIds(tokens)) {
|
|
|
+ // Does the model as a resource have all the tags?
|
|
|
+ if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
|
|
|
+ foundResourceIds.add(resourceId);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
|
|
|
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
|
|
ML_ORIGIN,
|
|
|
searchRequest,
|
|
|
ActionListener.<SearchResponse>wrap(
|
|
|
response -> {
|
|
|
- Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
|
|
|
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
|
|
|
for (SearchHit hit : response.getHits().getHits()) {
|
|
|
Map<String, Object> docSource = hit.getSourceAsMap();
|
|
@@ -433,7 +445,15 @@ public class TrainedModelProvider {
|
|
|
idsListener::onFailure
|
|
|
),
|
|
|
client::search);
|
|
|
+ }
|
|
|
|
|
|
+ static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
|
|
|
+ BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
|
|
|
+ .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
|
|
|
+ for(String tag : tags) {
|
|
|
+ boolQueryBuilder.filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), tag));
|
|
|
+ }
|
|
|
+ return QueryBuilders.constantScoreQuery(boolQueryBuilder);
|
|
|
}
|
|
|
|
|
|
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
|
@@ -467,7 +487,7 @@ public class TrainedModelProvider {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
|
|
+ private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
|
|
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
|
|
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
|
|
|
|