|
@@ -28,7 +28,6 @@ import org.elasticsearch.action.support.IndicesOptions;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.client.Client;
|
|
|
import org.elasticsearch.common.CheckedBiFunction;
|
|
|
-import org.elasticsearch.common.Nullable;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
@@ -73,10 +72,10 @@ import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.Comparator;
|
|
|
import java.util.HashSet;
|
|
|
-import java.util.LinkedHashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
+import java.util.TreeSet;
|
|
|
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
|
@@ -381,19 +380,34 @@ public class TrainedModelProvider {
|
|
|
|
|
|
public void expandIds(String idExpression,
|
|
|
boolean allowNoResources,
|
|
|
- @Nullable PageParams pageParams,
|
|
|
+ PageParams pageParams,
|
|
|
Set<String> tags,
|
|
|
ActionListener<Tuple<Long, Set<String>>> idsListener) {
|
|
|
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
|
|
|
+ Set<String> matchedResourceIds = matchedResourceIds(tokens);
|
|
|
+ Set<String> foundResourceIds;
|
|
|
+ if (tags.isEmpty()) {
|
|
|
+ foundResourceIds = matchedResourceIds;
|
|
|
+ } else {
|
|
|
+ foundResourceIds = new HashSet<>();
|
|
|
+ for(String resourceId : matchedResourceIds) {
|
|
|
+ // Does the model as a resource have all the tags?
|
|
|
+ if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
|
|
|
+ foundResourceIds.add(resourceId);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
|
|
|
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
|
|
|
// If there are no resources, there might be no mapping for the id field.
|
|
|
// This makes sure we don't get an error if that happens.
|
|
|
.unmappedType("long"))
|
|
|
- .query(buildExpandIdsQuery(tokens, tags));
|
|
|
- if (pageParams != null) {
|
|
|
- sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
|
|
|
- }
|
|
|
+ .query(buildExpandIdsQuery(tokens, tags))
|
|
|
+ // We "buffer" the from and size to take into account models stored as resources.
|
|
|
+ // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
|
|
|
+ // a page.
|
|
|
+ .from(Math.max(0, pageParams.getFrom() - foundResourceIds.size()))
|
|
|
+ .size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size()));
|
|
|
sourceBuilder.trackTotalHits(true)
|
|
|
// we only care about the item id's
|
|
|
.fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
|
|
@@ -406,17 +420,6 @@ public class TrainedModelProvider {
|
|
|
indicesOptions.expandWildcardsClosed(),
|
|
|
indicesOptions))
|
|
|
.source(sourceBuilder);
|
|
|
- Set<String> foundResourceIds = new LinkedHashSet<>();
|
|
|
- if (tags.isEmpty()) {
|
|
|
- foundResourceIds.addAll(matchedResourceIds(tokens));
|
|
|
- } else {
|
|
|
- for(String resourceId : matchedResourceIds(tokens)) {
|
|
|
- // Does the model as a resource have all the tags?
|
|
|
- if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
|
|
|
- foundResourceIds.add(resourceId);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
|
|
|
ML_ORIGIN,
|
|
@@ -424,6 +427,7 @@ public class TrainedModelProvider {
|
|
|
ActionListener.<SearchResponse>wrap(
|
|
|
response -> {
|
|
|
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
|
|
|
+ Set<String> foundFromDocs = new HashSet<>();
|
|
|
for (SearchHit hit : response.getHits().getHits()) {
|
|
|
Map<String, Object> docSource = hit.getSourceAsMap();
|
|
|
if (docSource == null) {
|
|
@@ -431,15 +435,17 @@ public class TrainedModelProvider {
|
|
|
}
|
|
|
Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
|
|
|
if (idValue instanceof String) {
|
|
|
- foundResourceIds.add(idValue.toString());
|
|
|
+ foundFromDocs.add(idValue.toString());
|
|
|
}
|
|
|
}
|
|
|
+ Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
|
|
|
ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
|
|
|
- requiredMatches.filterMatchedIds(foundResourceIds);
|
|
|
+ requiredMatches.filterMatchedIds(allFoundIds);
|
|
|
if (requiredMatches.hasUnmatchedIds()) {
|
|
|
idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
|
|
|
} else {
|
|
|
- idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
|
|
|
+
|
|
|
+ idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds));
|
|
|
}
|
|
|
},
|
|
|
idsListener::onFailure
|
|
@@ -447,6 +453,32 @@ public class TrainedModelProvider {
|
|
|
client::search);
|
|
|
}
|
|
|
|
|
|
+ static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
|
|
|
+ // If there are no matching resource models, there was no buffering and the models from the docs
|
|
|
+ // are paginated correctly.
|
|
|
+ if (foundFromResources.isEmpty()) {
|
|
|
+ return foundFromDocs;
|
|
|
+ }
|
|
|
+
|
|
|
+ TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
|
|
|
+ allFoundIds.addAll(foundFromResources);
|
|
|
+
|
|
|
+ if (pageParams.getFrom() > 0) {
|
|
|
+ // not the first page so there will be extra results at the front to remove
|
|
|
+ int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom());
|
|
|
+ for (int i = 0; i < numToTrimFromFront; i++) {
|
|
|
+ allFoundIds.remove(allFoundIds.first());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // trim down to size removing from the rear
|
|
|
+ while (allFoundIds.size() > pageParams.getSize()) {
|
|
|
+ allFoundIds.remove(allFoundIds.last());
|
|
|
+ }
|
|
|
+
|
|
|
+ return allFoundIds;
|
|
|
+ }
|
|
|
+
|
|
|
static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
|
|
|
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
|
|
|
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
|
|
@@ -517,7 +549,7 @@ public class TrainedModelProvider {
|
|
|
|
|
|
private Set<String> matchedResourceIds(String[] tokens) {
|
|
|
if (Strings.isAllOrWildcard(tokens)) {
|
|
|
- return new HashSet<>(MODELS_STORED_AS_RESOURCE);
|
|
|
+ return MODELS_STORED_AS_RESOURCE;
|
|
|
}
|
|
|
|
|
|
Set<String> matchedModels = new HashSet<>();
|
|
@@ -535,7 +567,7 @@ public class TrainedModelProvider {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- return matchedModels;
|
|
|
+ return Collections.unmodifiableSet(matchedModels);
|
|
|
}
|
|
|
|
|
|
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|