Selaa lähdekoodia

Fix NPE on composite aggregation with sub-aggregations that need scores (#28129)

The composite aggregation defers the collection of sub-aggregations to a second pass that visits documents only if they
appear in the top buckets. Though the scorer for sub-aggregations is not set on this second pass and generates an NPE if any sub-aggregation
tries to access the score. This change creates a scorer for the second pass and makes sure that sub-aggs can use it safely to check the score of
the collected documents.
Jim Ferenczi 7 vuotta sitten
vanhempi
commit
bd11e6c441

+ 23 - 0
server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java

@@ -23,6 +23,9 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DocIdSet;
 import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.RoaringDocIdSet;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
@@ -87,6 +90,12 @@ final class CompositeAggregator extends BucketsAggregator {
 
         // Replay all documents that contain at least one top bucket (collected during the first pass).
         grow(keys.size()+1);
+        final boolean needsScores = needsScores();
+        Weight weight = null;
+        if (needsScores) {
+            Query query = context.query();
+            weight = context.searcher().createNormalizedWeight(query, true);
+        }
         for (LeafContext context : contexts) {
             DocIdSetIterator docIdSetIterator = context.docIdSet.iterator();
             if (docIdSetIterator == null) {
@@ -95,7 +104,21 @@ final class CompositeAggregator extends BucketsAggregator {
             final CompositeValuesSource.Collector collector =
                 array.getLeafCollector(context.ctx, getSecondPassCollector(context.subCollector));
             int docID;
+            DocIdSetIterator scorerIt = null;
+            if (needsScores) {
+                Scorer scorer = weight.scorer(context.ctx);
+                // We don't need to check if the scorer is null
+                // since we are sure that there are documents to replay (docIdSetIterator it not empty).
+                scorerIt = scorer.iterator();
+                context.subCollector.setScorer(scorer);
+            }
             while ((docID = docIdSetIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+                if (needsScores) {
+                    assert scorerIt.docID() < docID;
+                    scorerIt.advance(docID);
+                    // aggregations should only be replayed on matching documents
+                    assert scorerIt.docID() == docID;
+                }
                 collector.collect(docID);
             }
         }

+ 70 - 3
server/src/test/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregatorTests.java

@@ -50,6 +50,8 @@ import org.elasticsearch.index.mapper.Mapper;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval;
+import org.elasticsearch.search.aggregations.metrics.tophits.TopHits;
+import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.test.IndexSettingsModule;
 import org.joda.time.DateTimeZone;
@@ -1065,8 +1067,73 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
         );
     }
 
-    private void testSearchCase(Query query,
-                                Sort sort,
+    public void testWithKeywordAndTopHits() throws Exception {
+        final List<Map<String, List<Object>>> dataset = new ArrayList<>();
+        dataset.addAll(
+            Arrays.asList(
+                createDocument("keyword", "a"),
+                createDocument("keyword", "c"),
+                createDocument("keyword", "a"),
+                createDocument("keyword", "d"),
+                createDocument("keyword", "c")
+            )
+        );
+        final Sort sort = new Sort(new SortedSetSortField("keyword", false));
+        testSearchCase(new MatchAllDocsQuery(), sort, dataset,
+            () -> {
+                TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
+                    .field("keyword");
+                return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
+                    .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
+            }, (result) -> {
+                assertEquals(3, result.getBuckets().size());
+                assertEquals("{keyword=a}", result.getBuckets().get(0).getKeyAsString());
+                assertEquals(2L, result.getBuckets().get(0).getDocCount());
+                TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
+                assertNotNull(topHits);
+                assertEquals(topHits.getHits().getHits().length, 2);
+                assertEquals(topHits.getHits().getTotalHits(), 2L);
+                assertEquals("{keyword=c}", result.getBuckets().get(1).getKeyAsString());
+                assertEquals(2L, result.getBuckets().get(1).getDocCount());
+                topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
+                assertNotNull(topHits);
+                assertEquals(topHits.getHits().getHits().length, 2);
+                assertEquals(topHits.getHits().getTotalHits(), 2L);
+                assertEquals("{keyword=d}", result.getBuckets().get(2).getKeyAsString());
+                assertEquals(1L, result.getBuckets().get(2).getDocCount());
+                topHits = result.getBuckets().get(2).getAggregations().get("top_hits");
+                assertNotNull(topHits);
+                assertEquals(topHits.getHits().getHits().length, 1);
+                assertEquals(topHits.getHits().getTotalHits(), 1L);;
+            }
+        );
+
+        testSearchCase(new MatchAllDocsQuery(), sort, dataset,
+            () -> {
+                TermsValuesSourceBuilder terms = new TermsValuesSourceBuilder("keyword")
+                    .field("keyword");
+                return new CompositeAggregationBuilder("name", Collections.singletonList(terms))
+                    .aggregateAfter(Collections.singletonMap("keyword", "a"))
+                    .subAggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
+            }, (result) -> {
+                assertEquals(2, result.getBuckets().size());
+                assertEquals("{keyword=c}", result.getBuckets().get(0).getKeyAsString());
+                assertEquals(2L, result.getBuckets().get(0).getDocCount());
+                TopHits topHits = result.getBuckets().get(0).getAggregations().get("top_hits");
+                assertNotNull(topHits);
+                assertEquals(topHits.getHits().getHits().length, 2);
+                assertEquals(topHits.getHits().getTotalHits(), 2L);
+                assertEquals("{keyword=d}", result.getBuckets().get(1).getKeyAsString());
+                assertEquals(1L, result.getBuckets().get(1).getDocCount());
+                topHits = result.getBuckets().get(1).getAggregations().get("top_hits");
+                assertNotNull(topHits);
+                assertEquals(topHits.getHits().getHits().length, 1);
+                assertEquals(topHits.getHits().getTotalHits(), 1L);
+            }
+        );
+    }
+
+    private void testSearchCase(Query query, Sort sort,
                                 List<Map<String, List<Object>>> dataset,
                                 Supplier<CompositeAggregationBuilder> create,
                                 Consumer<InternalComposite> verify) throws IOException {
@@ -1107,7 +1174,7 @@ public class CompositeAggregatorTests extends AggregatorTestCase {
                 IndexSearcher indexSearcher = newSearcher(indexReader, sort == null, sort == null);
                 CompositeAggregationBuilder aggregationBuilder = create.get();
                 if (sort != null) {
-                    CompositeAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
+                    CompositeAggregator aggregator = createAggregator(query, aggregationBuilder, indexSearcher, indexSettings, FIELD_TYPES);
                     assertTrue(aggregator.canEarlyTerminate());
                 }
                 final InternalComposite composite;

+ 30 - 9
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -103,16 +103,27 @@ public abstract class AggregatorTestCase extends ESTestCase {
             new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
     }
 
-    /** Create a factory for the given aggregation builder. */
+
     protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggregationBuilder,
                                                            IndexSearcher indexSearcher,
                                                            IndexSettings indexSettings,
                                                            MultiBucketConsumer bucketConsumer,
                                                            MappedFieldType... fieldTypes) throws IOException {
+        return createAggregatorFactory(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
+    }
+
+    /** Create a factory for the given aggregation builder. */
+    protected AggregatorFactory<?> createAggregatorFactory(Query query,
+                                                           AggregationBuilder aggregationBuilder,
+                                                           IndexSearcher indexSearcher,
+                                                           IndexSettings indexSettings,
+                                                           MultiBucketConsumer bucketConsumer,
+                                                           MappedFieldType... fieldTypes) throws IOException {
         SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
         CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
         when(searchContext.aggregations())
             .thenReturn(new SearchContextAggregations(AggregatorFactories.EMPTY, bucketConsumer));
+        when(searchContext.query()).thenReturn(query);
         when(searchContext.bigArrays()).thenReturn(new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), circuitBreakerService));
         // TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
         MapperService mapperService = mapperServiceMock();
@@ -146,19 +157,20 @@ public abstract class AggregatorTestCase extends ESTestCase {
             new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
     }
 
-    protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
+    protected <A extends Aggregator> A createAggregator(Query query,
+                                                        AggregationBuilder aggregationBuilder,
                                                         IndexSearcher indexSearcher,
                                                         IndexSettings indexSettings,
                                                         MappedFieldType... fieldTypes) throws IOException {
-        return createAggregator(aggregationBuilder, indexSearcher, indexSettings,
+        return createAggregator(query, aggregationBuilder, indexSearcher, indexSettings,
             new MultiBucketConsumer(DEFAULT_MAX_BUCKETS), fieldTypes);
     }
 
-    protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
+    protected <A extends Aggregator> A createAggregator(Query query, AggregationBuilder aggregationBuilder,
                                                         IndexSearcher indexSearcher,
                                                         MultiBucketConsumer bucketConsumer,
                                                         MappedFieldType... fieldTypes) throws IOException {
-        return createAggregator(aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
+        return createAggregator(query, aggregationBuilder, indexSearcher, createIndexSettings(), bucketConsumer, fieldTypes);
     }
 
     protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
@@ -166,8 +178,17 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                         IndexSettings indexSettings,
                                                         MultiBucketConsumer bucketConsumer,
                                                         MappedFieldType... fieldTypes) throws IOException {
+        return createAggregator(null, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes);
+    }
+
+    protected <A extends Aggregator> A createAggregator(Query query,
+                                                        AggregationBuilder aggregationBuilder,
+                                                        IndexSearcher indexSearcher,
+                                                        IndexSettings indexSettings,
+                                                        MultiBucketConsumer bucketConsumer,
+                                                        MappedFieldType... fieldTypes) throws IOException {
         @SuppressWarnings("unchecked")
-        A aggregator = (A) createAggregatorFactory(aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
+        A aggregator = (A) createAggregatorFactory(query, aggregationBuilder, indexSearcher, indexSettings, bucketConsumer, fieldTypes)
             .create(null, true);
         return aggregator;
     }
@@ -262,7 +283,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                                              int maxBucket,
                                                                              MappedFieldType... fieldTypes) throws IOException {
         MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
-        C a = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
+        C a = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
         a.preCollection();
         searcher.search(query, a);
         a.postCollection();
@@ -310,11 +331,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
         Query rewritten = searcher.rewrite(query);
         Weight weight = searcher.createWeight(rewritten, true, 1f);
         MultiBucketConsumer bucketConsumer = new MultiBucketConsumer(maxBucket);
-        C root = createAggregator(builder, searcher, bucketConsumer, fieldTypes);
+        C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes);
 
         for (ShardSearcher subSearcher : subSearchers) {
             MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer(maxBucket);
-            C a = createAggregator(builder, subSearcher, shardBucketConsumer, fieldTypes);
+            C a = createAggregator(query, builder, subSearcher, shardBucketConsumer, fieldTypes);
             a.preCollection();
             subSearcher.search(weight, a);
             a.postCollection();