Browse Source

Add early termination support to BucketCollector (#33279)

This commit adds the support to early terminate the collection of a leaf
in the aggregation framework. This change introduces a MultiBucketCollector which
handles CollectionTerminatedException exactly like the Lucene MultiCollector.
Any aggregator can now throw a CollectionTerminatedException without stopping
the collection of a sibling aggregator. This is useful for aggregators that
can infer their result without visiting all documents (e.g.: a min/max aggregation on a match_all query).
Jim Ferenczi 7 years ago
parent
commit
713c07e14d

+ 1 - 1
docs/reference/search/profile.asciidoc

@@ -596,7 +596,7 @@ And the response:
                                   ]
                                 },
                                 {
-                                  "name": "BucketCollector: [[my_scoped_agg, my_global_agg]]",
+                                  "name": "MultiBucketCollector: [[my_scoped_agg, my_global_agg]]",
                                   "reason": "aggregation",
                                   "time_in_nanos": 8273
                                 }

+ 2 - 2
server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java

@@ -60,7 +60,7 @@ public class AggregationPhase implements SearchPhase {
                 }
                 context.aggregations().aggregators(aggregators);
                 if (!collectors.isEmpty()) {
-                    Collector collector = BucketCollector.wrap(collectors);
+                    Collector collector = MultiBucketCollector.wrap(collectors);
                     ((BucketCollector)collector).preCollection();
                     if (context.getProfilers() != null) {
                         collector = new InternalProfileCollector(collector, CollectorResult.REASON_AGGREGATION,
@@ -97,7 +97,7 @@ public class AggregationPhase implements SearchPhase {
 
         // optimize the global collector based execution
         if (!globals.isEmpty()) {
-            BucketCollector globalsCollector = BucketCollector.wrap(globals);
+            BucketCollector globalsCollector = MultiBucketCollector.wrap(globals);
             Query query = context.buildFilteredQuery(Queries.newMatchAllQuery());
 
             try {

+ 1 - 1
server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java

@@ -183,7 +183,7 @@ public abstract class AggregatorBase extends Aggregator {
     @Override
     public final void preCollection() throws IOException {
         List<BucketCollector> collectors = Arrays.asList(subAggregators);
-        collectableSubAggregators = BucketCollector.wrap(collectors);
+        collectableSubAggregators = MultiBucketCollector.wrap(collectors);
         doPreCollection();
         collectableSubAggregators.preCollection();
     }

+ 0 - 59
server/src/main/java/org/elasticsearch/search/aggregations/BucketCollector.java

@@ -24,10 +24,6 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Collector;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.StreamSupport;
 
 /**
  * A Collector that can collect data in separate buckets.
@@ -54,61 +50,6 @@ public abstract class BucketCollector implements Collector {
         }
     };
 
-    /**
-     * Wrap the given collectors into a single instance.
-     */
-    public static BucketCollector wrap(Iterable<? extends BucketCollector> collectorList) {
-        final BucketCollector[] collectors =
-                StreamSupport.stream(collectorList.spliterator(), false).toArray(size -> new BucketCollector[size]);
-        switch (collectors.length) {
-            case 0:
-                return NO_OP_COLLECTOR;
-            case 1:
-                return collectors[0];
-            default:
-                return new BucketCollector() {
-
-                    @Override
-                    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
-                        List<LeafBucketCollector> leafCollectors = new ArrayList<>(collectors.length);
-                        for (BucketCollector c : collectors) {
-                            leafCollectors.add(c.getLeafCollector(ctx));
-                        }
-                        return LeafBucketCollector.wrap(leafCollectors);
-                    }
-
-                    @Override
-                    public void preCollection() throws IOException {
-                        for (BucketCollector collector : collectors) {
-                            collector.preCollection();
-                        }
-                    }
-
-                    @Override
-                    public void postCollection() throws IOException {
-                        for (BucketCollector collector : collectors) {
-                            collector.postCollection();
-                        }
-                    }
-
-                    @Override
-                    public boolean needsScores() {
-                        for (BucketCollector collector : collectors) {
-                            if (collector.needsScores()) {
-                                return true;
-                            }
-                        }
-                        return false;
-                    }
-
-                    @Override
-                    public String toString() {
-                        return Arrays.toString(collectors);
-                    }
-                };
-        }
-    }
-
     @Override
     public abstract LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException;
 

+ 207 - 0
server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketCollector.java

@@ -0,0 +1,207 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.MultiCollector;
+import org.apache.lucene.search.ScoreCachingWrappingScorer;
+import org.apache.lucene.search.Scorer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * A {@link BucketCollector} which allows running a bucket collection with several
+ * {@link BucketCollector}s. It is similar to the {@link MultiCollector} except that the
+ * {@link #wrap} method filters out the {@link BucketCollector#NO_OP_COLLECTOR}s and not
+ * the null ones.
+ */
+public class MultiBucketCollector extends BucketCollector {
+
+    /** See {@link #wrap(Iterable)}. */
+    public static BucketCollector wrap(BucketCollector... collectors) {
+        return wrap(Arrays.asList(collectors));
+    }
+
+    /**
+     * Wraps a list of {@link BucketCollector}s with a {@link MultiBucketCollector}. This
+     * method works as follows:
+     * <ul>
+     * <li>Filters out the {@link BucketCollector#NO_OP_COLLECTOR}s collectors, so they are not used
+     * during search time.
+     * <li>If the input contains 1 real collector, it is returned.
+     * <li>Otherwise the method returns a {@link MultiBucketCollector} which wraps the
+     * non-{@link BucketCollector#NO_OP_COLLECTOR} collectors.
+     * </ul>
+     */
+    public static BucketCollector wrap(Iterable<? extends BucketCollector> collectors) {
+        // For the user's convenience, we allow NO_OP collectors to be passed.
+        // However, to improve performance, these null collectors are found
+        // and dropped from the array we save for actual collection time.
+        int n = 0;
+        for (BucketCollector c : collectors) {
+            if (c != NO_OP_COLLECTOR) {
+                n++;
+            }
+        }
+
+        if (n == 0) {
+            return NO_OP_COLLECTOR;
+        } else if (n == 1) {
+            // only 1 Collector - return it.
+            BucketCollector col = null;
+            for (BucketCollector c : collectors) {
+                if (c != null) {
+                    col = c;
+                    break;
+                }
+            }
+            return col;
+        } else {
+            BucketCollector[] colls = new BucketCollector[n];
+            n = 0;
+            for (BucketCollector c : collectors) {
+                if (c != null) {
+                    colls[n++] = c;
+                }
+            }
+            return new MultiBucketCollector(colls);
+        }
+    }
+
+    private final boolean cacheScores;
+    private final BucketCollector[] collectors;
+
+    private MultiBucketCollector(BucketCollector... collectors) {
+        this.collectors = collectors;
+        int numNeedsScores = 0;
+        for (BucketCollector collector : collectors) {
+            if (collector.needsScores()) {
+                numNeedsScores += 1;
+            }
+        }
+        this.cacheScores = numNeedsScores >= 2;
+    }
+
+    @Override
+    public void preCollection() throws IOException {
+        for (BucketCollector collector : collectors) {
+            collector.preCollection();
+        }
+    }
+
+    @Override
+    public void postCollection() throws IOException {
+        for (BucketCollector collector : collectors) {
+            collector.postCollection();
+        }
+    }
+
+    @Override
+    public boolean needsScores() {
+        for (BucketCollector collector : collectors) {
+            if (collector.needsScores()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    @Override
+    public String toString() {
+        return Arrays.toString(collectors);
+    }
+
+    @Override
+    public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
+        final List<LeafBucketCollector> leafCollectors = new ArrayList<>();
+        for (BucketCollector collector : collectors) {
+            final LeafBucketCollector leafCollector;
+            try {
+                leafCollector = collector.getLeafCollector(context);
+            } catch (CollectionTerminatedException e) {
+                // this leaf collector does not need this segment
+                continue;
+            }
+            leafCollectors.add(leafCollector);
+        }
+        switch (leafCollectors.size()) {
+            case 0:
+                throw new CollectionTerminatedException();
+            case 1:
+                return leafCollectors.get(0);
+            default:
+                return new MultiLeafBucketCollector(leafCollectors, cacheScores);
+        }
+    }
+
+    private static class MultiLeafBucketCollector extends LeafBucketCollector {
+
+        private final boolean cacheScores;
+        private final LeafBucketCollector[] collectors;
+        private int numCollectors;
+
+        private MultiLeafBucketCollector(List<LeafBucketCollector> collectors, boolean cacheScores) {
+            this.collectors = collectors.toArray(new LeafBucketCollector[collectors.size()]);
+            this.cacheScores = cacheScores;
+            this.numCollectors = this.collectors.length;
+        }
+
+        @Override
+        public void setScorer(Scorer scorer) throws IOException {
+            if (cacheScores) {
+                scorer = new ScoreCachingWrappingScorer(scorer);
+            }
+            for (int i = 0; i < numCollectors; ++i) {
+                final LeafCollector c = collectors[i];
+                c.setScorer(scorer);
+            }
+        }
+
+        private void removeCollector(int i) {
+            System.arraycopy(collectors, i + 1, collectors, i, numCollectors - i - 1);
+            --numCollectors;
+            collectors[numCollectors] = null;
+        }
+
+        @Override
+        public void collect(int doc, long bucket) throws IOException {
+            final LeafBucketCollector[] collectors = this.collectors;
+            int numCollectors = this.numCollectors;
+            for (int i = 0; i < numCollectors; ) {
+                final LeafBucketCollector collector = collectors[i];
+                try {
+                    collector.collect(doc, bucket);
+                    ++i;
+                } catch (CollectionTerminatedException e) {
+                    removeCollector(i);
+                    numCollectors = this.numCollectors;
+                    if (numCollectors == 0) {
+                        throw new CollectionTerminatedException();
+                    }
+                }
+            }
+        }
+    }
+}

+ 2 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/BestBucketsDeferringCollector.java

@@ -33,6 +33,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.BucketCollector;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
+import org.elasticsearch.search.aggregations.MultiBucketCollector;
 import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -90,7 +91,7 @@ public class BestBucketsDeferringCollector extends DeferringBucketCollector {
     /** Set the deferred collectors. */
     @Override
     public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
-        this.collector = BucketCollector.wrap(deferredCollectors);
+        this.collector = MultiBucketCollector.wrap(deferredCollectors);
     }
 
     private void finishLeaf() {

+ 2 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/DeferableBucketAggregator.java

@@ -22,6 +22,7 @@ package org.elasticsearch.search.aggregations.bucket;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.BucketCollector;
+import org.elasticsearch.search.aggregations.MultiBucketCollector;
 import org.elasticsearch.search.aggregations.bucket.global.GlobalAggregator;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.search.internal.SearchContext;
@@ -59,7 +60,7 @@ public abstract class DeferableBucketAggregator extends BucketsAggregator {
             recordingWrapper.setDeferredCollector(deferredCollectors);
             collectors.add(recordingWrapper);
         }
-        collectableSubAggregators = BucketCollector.wrap(collectors);
+        collectableSubAggregators = MultiBucketCollector.wrap(collectors);
     }
 
     public static boolean descendsFromGlobalAggregator(Aggregator parent) {

+ 2 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/MergingBucketsDeferringCollector.java

@@ -31,6 +31,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.BucketCollector;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
+import org.elasticsearch.search.aggregations.MultiBucketCollector;
 import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -61,7 +62,7 @@ public class MergingBucketsDeferringCollector extends DeferringBucketCollector {
 
     @Override
     public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
-        this.collector = BucketCollector.wrap(deferredCollectors);
+        this.collector = MultiBucketCollector.wrap(deferredCollectors);
     }
 
     @Override

+ 2 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/composite/CompositeAggregator.java

@@ -38,6 +38,7 @@ import org.elasticsearch.search.aggregations.BucketCollector;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
+import org.elasticsearch.search.aggregations.MultiBucketCollector;
 import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
 import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
@@ -93,7 +94,7 @@ final class CompositeAggregator extends BucketsAggregator {
     @Override
     protected void doPreCollection() throws IOException {
         List<BucketCollector> collectors = Arrays.asList(subAggregators);
-        deferredCollectors = BucketCollector.wrap(collectors);
+        deferredCollectors = MultiBucketCollector.wrap(collectors);
         collectableSubAggregators = BucketCollector.NO_OP_COLLECTOR;
     }
 

+ 2 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/BestDocsDeferringCollector.java

@@ -33,6 +33,7 @@ import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.ObjectArray;
 import org.elasticsearch.search.aggregations.BucketCollector;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
+import org.elasticsearch.search.aggregations.MultiBucketCollector;
 import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
 
 import java.io.IOException;
@@ -76,7 +77,7 @@ public class BestDocsDeferringCollector extends DeferringBucketCollector impleme
     /** Set the deferred collectors. */
     @Override
     public void setDeferredCollector(Iterable<BucketCollector> deferredCollectors) {
-        this.deferred = BucketCollector.wrap(deferredCollectors);
+        this.deferred = MultiBucketCollector.wrap(deferredCollectors);
     }
 
     @Override

+ 262 - 0
server/src/test/java/org/elasticsearch/search/aggregations/MultiBucketCollectorTests.java

@@ -0,0 +1,262 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.search.CollectionTerminatedException;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class MultiBucketCollectorTests  extends ESTestCase {
+    private static class FakeScorer extends Scorer {
+        float score;
+        int doc = -1;
+
+        FakeScorer() {
+            super(null);
+        }
+
+        @Override
+        public int docID() {
+            return doc;
+        }
+
+        @Override
+        public float score() {
+            return score;
+        }
+
+        @Override
+        public DocIdSetIterator iterator() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Weight getWeight() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Collection<ChildScorer> getChildren() {
+            throw new UnsupportedOperationException();
+        }
+    }
+
+    private static class TerminateAfterBucketCollector extends BucketCollector {
+
+        private int count = 0;
+        private final int terminateAfter;
+        private final BucketCollector in;
+
+        TerminateAfterBucketCollector(BucketCollector in, int terminateAfter) {
+            this.in = in;
+            this.terminateAfter = terminateAfter;
+        }
+
+        @Override
+        public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
+            if (count >= terminateAfter) {
+                throw new CollectionTerminatedException();
+            }
+            final LeafBucketCollector leafCollector = in.getLeafCollector(context);
+            return new LeafBucketCollectorBase(leafCollector, null) {
+                @Override
+                public void collect(int doc, long bucket) throws IOException {
+                    if (count >= terminateAfter) {
+                        throw new CollectionTerminatedException();
+                    }
+                    super.collect(doc, bucket);
+                    count++;
+                }
+            };
+        }
+
+        @Override
+        public boolean needsScores() {
+            return false;
+        }
+
+        @Override
+        public void preCollection() {}
+
+        @Override
+        public void postCollection() {}
+    }
+
+    private static class TotalHitCountBucketCollector extends BucketCollector {
+
+        private int count = 0;
+
+        TotalHitCountBucketCollector() {
+        }
+
+        @Override
+        public LeafBucketCollector getLeafCollector(LeafReaderContext context) {
+            return new LeafBucketCollector() {
+                @Override
+                public void collect(int doc, long bucket) throws IOException {
+                    count++;
+                }
+            };
+        }
+
+        @Override
+        public boolean needsScores() {
+            return false;
+        }
+
+        @Override
+        public void preCollection() {}
+
+        @Override
+        public void postCollection() {}
+
+        int getTotalHits() {
+            return count;
+        }
+    }
+
+    private static class SetScorerBucketCollector extends BucketCollector {
+        private final BucketCollector in;
+        private final AtomicBoolean setScorerCalled;
+
+        SetScorerBucketCollector(BucketCollector in, AtomicBoolean setScorerCalled) {
+            this.in = in;
+            this.setScorerCalled = setScorerCalled;
+        }
+
+        @Override
+        public LeafBucketCollector getLeafCollector(LeafReaderContext context) throws IOException {
+            final LeafBucketCollector leafCollector = in.getLeafCollector(context);
+            return new LeafBucketCollectorBase(leafCollector, null) {
+                @Override
+                public void setScorer(Scorer scorer) throws IOException {
+                    super.setScorer(scorer);
+                    setScorerCalled.set(true);
+                }
+            };
+        }
+
+        @Override
+        public boolean needsScores() {
+            return false;
+        }
+
+        @Override
+        public void preCollection() {}
+
+        @Override
+        public void postCollection() {}
+    }
+
+    public void testCollectionTerminatedExceptionHandling() throws IOException {
+        final int iters = atLeast(3);
+        for (int iter = 0; iter < iters; ++iter) {
+            Directory dir = newDirectory();
+            RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+            final int numDocs = randomIntBetween(100, 1000);
+            final Document doc = new Document();
+            for (int i = 0; i < numDocs; ++i) {
+                w.addDocument(doc);
+            }
+            final IndexReader reader = w.getReader();
+            w.close();
+            final IndexSearcher searcher = newSearcher(reader);
+            Map<TotalHitCountBucketCollector, Integer> expectedCounts = new HashMap<>();
+            List<BucketCollector> collectors = new ArrayList<>();
+            final int numCollectors = randomIntBetween(1, 5);
+            for (int i = 0; i < numCollectors; ++i) {
+                final int terminateAfter = random().nextInt(numDocs + 10);
+                final int expectedCount = terminateAfter > numDocs ? numDocs : terminateAfter;
+                TotalHitCountBucketCollector collector = new TotalHitCountBucketCollector();
+                expectedCounts.put(collector, expectedCount);
+                collectors.add(new TerminateAfterBucketCollector(collector, terminateAfter));
+            }
+            searcher.search(new MatchAllDocsQuery(), MultiBucketCollector.wrap(collectors));
+            for (Map.Entry<TotalHitCountBucketCollector, Integer> expectedCount : expectedCounts.entrySet()) {
+                assertEquals(expectedCount.getValue().intValue(), expectedCount.getKey().getTotalHits());
+            }
+            reader.close();
+            dir.close();
+        }
+    }
+
+    public void testSetScorerAfterCollectionTerminated() throws IOException {
+        BucketCollector collector1 = new TotalHitCountBucketCollector();
+        BucketCollector collector2 = new TotalHitCountBucketCollector();
+
+        AtomicBoolean setScorerCalled1 = new AtomicBoolean();
+        collector1 = new SetScorerBucketCollector(collector1, setScorerCalled1);
+
+        AtomicBoolean setScorerCalled2 = new AtomicBoolean();
+        collector2 = new SetScorerBucketCollector(collector2, setScorerCalled2);
+
+        collector1 = new TerminateAfterBucketCollector(collector1, 1);
+        collector2 = new TerminateAfterBucketCollector(collector2, 2);
+
+        Scorer scorer = new FakeScorer();
+
+        List<BucketCollector> collectors = Arrays.asList(collector1, collector2);
+        Collections.shuffle(collectors, random());
+        BucketCollector collector = MultiBucketCollector.wrap(collectors);
+
+        LeafBucketCollector leafCollector = collector.getLeafCollector(null);
+        leafCollector.setScorer(scorer);
+        assertTrue(setScorerCalled1.get());
+        assertTrue(setScorerCalled2.get());
+
+        leafCollector.collect(0);
+        leafCollector.collect(1);
+
+        setScorerCalled1.set(false);
+        setScorerCalled2.set(false);
+        leafCollector.setScorer(scorer);
+        assertFalse(setScorerCalled1.get());
+        assertTrue(setScorerCalled2.get());
+
+        expectThrows(CollectionTerminatedException.class, () -> {
+            leafCollector.collect(1);
+        });
+
+        setScorerCalled1.set(false);
+        setScorerCalled2.set(false);
+        leafCollector.setScorer(scorer);
+        assertFalse(setScorerCalled1.get());
+        assertFalse(setScorerCalled2.get());
+    }
+}