Browse Source

[ML][Inference] Fix model pagination with models as resources (#51573)

This adds logic to handle paging problems when the ID pattern + tags reference models stored as resources. 

Most of the complexity comes from the issue where a model stored as a resource could be at the start, or the end of a page or when we are on the last page.
Benjamin Trent 5 years ago
parent
commit
108ebc1baa

+ 55 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

@@ -28,7 +28,6 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.CheckedBiFunction;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.collect.Tuple;
@@ -73,10 +72,10 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashSet;
-import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.TreeSet;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -381,19 +380,34 @@ public class TrainedModelProvider {
 
     public void expandIds(String idExpression,
                           boolean allowNoResources,
-                          @Nullable PageParams pageParams,
+                          PageParams pageParams,
                           Set<String> tags,
                           ActionListener<Tuple<Long, Set<String>>> idsListener) {
         String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
+        Set<String> matchedResourceIds = matchedResourceIds(tokens);
+        Set<String> foundResourceIds;
+        if (tags.isEmpty()) {
+            foundResourceIds = matchedResourceIds;
+        } else {
+            foundResourceIds = new HashSet<>();
+            for(String resourceId : matchedResourceIds) {
+                // Does the model as a resource have all the tags?
+                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
+                    foundResourceIds.add(resourceId);
+                }
+            }
+        }
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
             .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
                 // If there are no resources, there might be no mapping for the id field.
                 // This makes sure we don't get an error if that happens.
                 .unmappedType("long"))
-            .query(buildExpandIdsQuery(tokens, tags));
-        if (pageParams != null) {
-            sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
-        }
+            .query(buildExpandIdsQuery(tokens, tags))
+            // We "buffer" the from and size to take into account models stored as resources.
+            // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
+            // a page.
+            .from(Math.max(0, pageParams.getFrom() - foundResourceIds.size()))
+            .size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size()));
         sourceBuilder.trackTotalHits(true)
             // we only care about the item id's
             .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
@@ -406,17 +420,6 @@ public class TrainedModelProvider {
                 indicesOptions.expandWildcardsClosed(),
                 indicesOptions))
             .source(sourceBuilder);
-        Set<String> foundResourceIds = new LinkedHashSet<>();
-        if (tags.isEmpty()) {
-            foundResourceIds.addAll(matchedResourceIds(tokens));
-        } else {
-            for(String resourceId : matchedResourceIds(tokens)) {
-                // Does the model as a resource have all the tags?
-                if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
-                    foundResourceIds.add(resourceId);
-                }
-            }
-        }
 
         executeAsyncWithOrigin(client.threadPool().getThreadContext(),
             ML_ORIGIN,
@@ -424,6 +427,7 @@ public class TrainedModelProvider {
             ActionListener.<SearchResponse>wrap(
                 response -> {
                     long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
+                    Set<String> foundFromDocs = new HashSet<>();
                     for (SearchHit hit : response.getHits().getHits()) {
                         Map<String, Object> docSource = hit.getSourceAsMap();
                         if (docSource == null) {
@@ -431,15 +435,17 @@ public class TrainedModelProvider {
                         }
                         Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
                         if (idValue instanceof String) {
-                            foundResourceIds.add(idValue.toString());
+                            foundFromDocs.add(idValue.toString());
                         }
                     }
+                    Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
                     ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
-                    requiredMatches.filterMatchedIds(foundResourceIds);
+                    requiredMatches.filterMatchedIds(allFoundIds);
                     if (requiredMatches.hasUnmatchedIds()) {
                         idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
                     } else {
-                        idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
+
+                        idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds));
                     }
                 },
                 idsListener::onFailure
@@ -447,6 +453,32 @@ public class TrainedModelProvider {
             client::search);
     }
 
+    static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
+        // If there are no matching resource models, there was no buffering and the models from the docs
+        // are paginated correctly.
+        if (foundFromResources.isEmpty()) {
+            return foundFromDocs;
+        }
+
+        TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
+        allFoundIds.addAll(foundFromResources);
+
+        if (pageParams.getFrom() > 0) {
+            // not the first page so there will be extra results at the front to remove
+            int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom());
+            for (int i = 0; i < numToTrimFromFront; i++) {
+                allFoundIds.remove(allFoundIds.first());
+            }
+        }
+
+        // trim down to size removing from the rear
+        while (allFoundIds.size() > pageParams.getSize()) {
+            allFoundIds.remove(allFoundIds.last());
+        }
+
+        return allFoundIds;
+    }
+
     static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
         BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
             .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
@@ -517,7 +549,7 @@ public class TrainedModelProvider {
 
     private Set<String> matchedResourceIds(String[] tokens) {
         if (Strings.isAllOrWildcard(tokens)) {
-            return new HashSet<>(MODELS_STORED_AS_RESOURCE);
+            return MODELS_STORED_AS_RESOURCE;
         }
 
         Set<String> matchedModels = new HashSet<>();
@@ -535,7 +567,7 @@ public class TrainedModelProvider {
                 }
             }
         }
-        return matchedModels;
+        return Collections.unmodifiableSet(matchedModels);
     }
 
     private static <T> T handleSearchItem(MultiSearchResponse.Item item,

+ 48 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

@@ -14,12 +14,16 @@ import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.TreeSet;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
@@ -86,6 +90,50 @@ public class TrainedModelProviderTests extends ESTestCase {
         });
     }
 
+    public void testExpandIdsPagination() {
+        // NOTE: these tests assume that the query pagination results are "buffered"
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
+            Collections.emptySet(),
+            new HashSet<>(Arrays.asList("a", "b", "c"))),
+            equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
+            Collections.singleton("a"),
+            new HashSet<>(Arrays.asList("b", "c", "d"))),
+            equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
+            Collections.singleton("a"),
+            new HashSet<>(Arrays.asList("b", "c", "d"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
+            Collections.singleton("c"),
+            new HashSet<>(Arrays.asList("a", "b"))),
+            equalTo(new TreeSet<>(Arrays.asList("b"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
+            Collections.singleton("b"),
+            new HashSet<>(Arrays.asList("a", "c"))),
+            equalTo(new TreeSet<>(Arrays.asList("b"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
+
+        assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3),
+            new HashSet<>(Arrays.asList("a", "b")),
+            new HashSet<>(Arrays.asList("c", "d", "e"))),
+            equalTo(new TreeSet<>(Arrays.asList("c", "d", "e"))));
+    }
+
     public void testGetModelThatExistsAsResourceButIsMissing() {
         TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
         ElasticsearchException ex = expectThrows(ElasticsearchException.class,

+ 129 - 3
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

@@ -72,6 +72,56 @@ setup:
                }
             }
           }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: yyy-classification-model
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "tags": ["classification", "tag3"],
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "tree": {
+                     "feature_names": ["field1", "field2"],
+                     "tree_structure": [
+                        {"node_index": 0, "leaf_value": 1}
+                     ],
+                     "target_type": "classification",
+                     "classification_labels": ["no", "yes"]
+                  }
+               }
+            }
+          }
+
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+      ml.put_trained_model:
+        model_id: zzz-classification-model
+        body: >
+          {
+            "description": "empty model for tests",
+            "input": {"field_names": ["field1", "field2"]},
+            "tags": ["classification", "tag3"],
+            "definition": {
+               "preprocessors": [],
+               "trained_model": {
+                  "tree": {
+                     "feature_names": ["field1", "field2"],
+                     "tree_structure": [
+                        {"node_index": 0, "leaf_value": 1}
+                     ],
+                     "target_type": "classification",
+                     "classification_labels": ["no", "yes"]
+                  }
+               }
+            }
+          }
 ---
 "Test get given missing trained model":
 
@@ -102,15 +152,20 @@ setup:
   - do:
       ml.get_trained_models:
         model_id: "*"
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 6 }
   - match: { trained_model_configs.0.model_id: "a-classification-model" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
   - match: { trained_model_configs.2.model_id: "a-regression-model-1" }
+  - match: { trained_model_configs.3.model_id: "lang_ident_model_1" }
+  - match: { trained_model_configs.4.model_id: "yyy-classification-model" }
+  - match: { trained_model_configs.5.model_id: "zzz-classification-model" }
 
   - do:
       ml.get_trained_models:
         model_id: "a-regression*"
   - match: { count: 2 }
+  - length: { trained_model_configs: 2 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-1" }
 
@@ -119,7 +174,8 @@ setup:
         model_id: "*"
         from: 0
         size: 2
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
   - match: { trained_model_configs.0.model_id: "a-classification-model" }
   - match: { trained_model_configs.1.model_id: "a-regression-model-0" }
 
@@ -128,8 +184,78 @@ setup:
         model_id: "*"
         from: 1
         size: 1
-  - match: { count: 4 }
+  - match: { count: 6 }
+  - length: { trained_model_configs: 1 }
   - match: { trained_model_configs.0.model_id: "a-regression-model-0" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 2
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "a-regression-model-1" }
+  - match: { trained_model_configs.1.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 3
+        size: 1
+  - match: { count: 6 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 3
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+  - match: { trained_model_configs.1.model_id: "yyy-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "*"
+        from: 4
+        size: 2
+  - match: { count: 6 }
+  - length: { trained_model_configs: 2 }
+  - match: { trained_model_configs.0.model_id: "yyy-classification-model" }
+  - match: { trained_model_configs.1.model_id: "zzz-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        allow_no_match: true
+        from: 3
+        size: 1
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        allow_no_match: true
+        from: 4
+        size: 1
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "zzz-classification-model" }
+
+  - do:
+      ml.get_trained_models:
+        model_id: "a-*,lang*,zzz*"
+        from: 4
+        size: 100
+  - match: { count: 5 }
+  - length: { trained_model_configs: 1 }
+  - match: { trained_model_configs.0.model_id: "zzz-classification-model" }
+
 ---
 "Test get models with tags":
   - do: