
Improve cpu utilization with dynamic slice size in doc partitioning (#132774)

We have seen CPU underutilization in metrics queries against large 
indices when using either SEGMENT or DOC partitioning:

1. SEGMENT partitioning does not split large segments, so a single
driver may process the entire query if most matching documents are
concentrated in a few segments.

2. DOC partitioning creates a fixed number of slices. If matching 
documents are concentrated in a few slices, a single driver may execute
the entire query.

This PR introduces dynamic-sized partitioning for DOC to address CPU 
underutilization while keeping overhead small:

Partitioning starts from a desired slice size derived from
task_concurrency and caps it at approximately 250K documents,
preventing underutilization when matching documents are concentrated
in one region of the index.
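
For illustration, the sizing arithmetic used by the new DOC strategy further
down works out as follows (a sketch; the document count and concurrency are
made-up numbers, MAX_DOCS_PER_SLICE is the existing ~250K cap):

    // Example: 10,000,000 docs with task_concurrency = 8 gives ceil(10M / 8) = 1,250,000,
    // which the cap clamps down to 250,000 docs per desired slice.
    int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);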

For small and medium segments (those smaller than five times the
desired slice size), a variant of SEGMENT partitioning is used that,
unlike plain SEGMENT partitioning, also splits segments larger than the
desired size.
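
For reference, that path boils down to Lucene's slicing utility, as in the
AdaptivePartitioner added below (a sketch; the variable names follow that
class, maxDocsPerSlice carries 25% slack over the desired size, and the final
flag allows splitting leaves larger than maxDocsPerSlice):

    // Small and medium segments are grouped and, if needed, split by Lucene itself.
    IndexSearcher.LeafSlice[] groups =
        IndexSearcher.slices(smallSegments, maxDocsPerSlice, maxSegmentsPerSlice, true);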

To prevent multiple drivers from working on the same large segment 
unnecessarily, a single driver processes a segment sequentially until
work-stealing occurs. This is accomplished by passing the current slice
when polling for the next, allowing the queue to provide the next
sequential slice from the same segment. New drivers are assigned slices
from segments not currently being processed.
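
In practice a driver drains the queue like this (a sketch; process() is a
hypothetical placeholder for the operator's per-slice work, mirroring the loop
in the new LuceneSliceQueueTests):

    // Passing the slice just finished lets the queue return the next sequential slice of the
    // same segment; otherwise it hands out the head of a fresh segment group and, as a last
    // resort, steals a slice from a segment another driver has already started.
    LuceneSlice slice = null;
    while ((slice = queue.nextSlice(slice)) != null) {
        process(slice);
    }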

Nhat Nguyen, 2 months ago
parent commit f9cdaaf3d4

+ 5 - 0
docs/changelog/132774.yaml

@@ -0,0 +1,5 @@
+pr: 132774
+summary: Improve cpu utilization with dynamic slice size in doc partitioning
+area: ES|QL
+type: enhancement
+issues: []

+ 17 - 4
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/DataPartitioning.java

@@ -9,6 +9,8 @@ package org.elasticsearch.compute.lucene;
 
 import org.elasticsearch.compute.operator.Driver;
 
+import java.util.List;
+
 /**
  * How we partition the data across {@link Driver}s. Each request forks into
  * {@code min(1.5 * cpus, partition_count)} threads on the data node. More partitions
@@ -37,9 +39,20 @@ public enum DataPartitioning {
      */
     SEGMENT,
     /**
-     * Partition each shard into {@code task_concurrency} partitions, splitting
-     * larger segments into slices. This allows bringing the most CPUs to bear on
-     * the problem but adds extra overhead, especially in query preparation.
+     * Partitions into dynamic-sized slices to improve CPU utilization while keeping overhead low.
+     * This approach is more flexible than {@link #SEGMENT} and works as follows:
+     *
+     * <ol>
+     *   <li>The slice size starts from a desired size based on {@code task_concurrency} but is capped
+     *       at around {@link LuceneSliceQueue#MAX_DOCS_PER_SLICE}. This prevents poor CPU usage when
+     *       matching documents are clustered together.</li>
+     *   <li>For small and medium segments (less than five times the desired slice size), it uses a
+     *       slightly different {@link #SEGMENT} strategy, which also splits segments that are larger
+     *       than the desired size. See {@link org.apache.lucene.search.IndexSearcher#slices(List, int, int, boolean)}.</li>
+     *   <li>For very large segments, multiple segments are not combined into a single slice. This allows
+     *       one driver to process an entire large segment until other drivers steal the work after finishing
+     *       their own tasks. See {@link LuceneSliceQueue#nextSlice(LuceneSlice)}.</li>
+     * </ol>
      */
-    DOC,
+    DOC
 }

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

@@ -165,7 +165,7 @@ public abstract class LuceneOperator extends SourceOperator {
         while (currentScorer == null || currentScorer.isDone()) {
             if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
                 sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice();
+                currentSlice = sliceQueue.nextSlice(currentSlice);
                 if (currentSlice == null) {
                     doneCollecting = true;
                     return null;

+ 7 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSlice.java

@@ -14,7 +14,13 @@ import java.util.List;
 /**
  * Holds a list of multiple partial Lucene segments
  */
-public record LuceneSlice(ShardContext shardContext, List<PartialLeafReaderContext> leaves, Weight weight, List<Object> tags) {
+public record LuceneSlice(
+    int slicePosition,
+    ShardContext shardContext,
+    List<PartialLeafReaderContext> leaves,
+    Weight weight,
+    List<Object> tags
+) {
     int numLeaves() {
         return leaves.size();
     }

+ 150 - 57
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java

@@ -16,6 +16,7 @@ import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.core.Nullable;
 
 import java.io.IOException;
@@ -23,11 +24,13 @@ import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Function;
 
 /**
@@ -77,18 +80,78 @@ public final class LuceneSliceQueue {
     public static final int MAX_SEGMENTS_PER_SLICE = 5; // copied from IndexSearcher
 
     private final int totalSlices;
-    private final Queue<LuceneSlice> slices;
     private final Map<String, PartitioningStrategy> partitioningStrategies;
 
-    private LuceneSliceQueue(List<LuceneSlice> slices, Map<String, PartitioningStrategy> partitioningStrategies) {
-        this.totalSlices = slices.size();
-        this.slices = new ConcurrentLinkedQueue<>(slices);
+    private final AtomicReferenceArray<LuceneSlice> slices;
+    /**
+     * Queue of slice IDs that are the primary entry point for a new group of segments.
+     * A driver should prioritize polling from this queue after failing to get a sequential
+     * slice (the segment affinity). This ensures that threads start work on fresh,
+     * independent segment groups before resorting to work stealing.
+     */
+    private final Queue<Integer> sliceHeads;
+
+    /**
+     * Queue of slice IDs that are not the primary entry point for a segment group.
+     * This queue serves as a fallback pool for work stealing. When a thread has no more independent work,
+     * it will "steal" a slice from this queue to keep itself utilized. A driver should pull tasks from
+     * this queue only when {@code sliceHeads} has been exhausted.
+     */
+    private final Queue<Integer> stealableSlices;
+
+    LuceneSliceQueue(List<LuceneSlice> sliceList, Map<String, PartitioningStrategy> partitioningStrategies) {
+        this.totalSlices = sliceList.size();
+        this.slices = new AtomicReferenceArray<>(sliceList.size());
+        for (int i = 0; i < sliceList.size(); i++) {
+            slices.set(i, sliceList.get(i));
+        }
         this.partitioningStrategies = partitioningStrategies;
+        this.sliceHeads = ConcurrentCollections.newQueue();
+        this.stealableSlices = ConcurrentCollections.newQueue();
+        for (LuceneSlice slice : sliceList) {
+            if (slice.getLeaf(0).minDoc() == 0) {
+                sliceHeads.add(slice.slicePosition());
+            } else {
+                stealableSlices.add(slice.slicePosition());
+            }
+        }
     }
 
+    /**
+     * Retrieves the next available {@link LuceneSlice} for processing.
+     * <p>
+     * This method implements a three-tiered strategy to minimize the overhead of switching between segments:
+     * 1. If a previous slice is provided, it first attempts to return the next sequential slice.
+     * This keeps a thread working on the same segments, minimizing the overhead of segment switching.
+     * 2. If affinity fails, it returns a slice from the {@link #sliceHeads} queue, which is an entry point for
+     * a new, independent group of segments, allowing the calling Driver to work on a fresh set of segments.
+     * 3. If the {@link #sliceHeads} queue is exhausted, it "steals" a slice
+     * from the {@link #stealableSlices} queue. This fallback ensures all threads remain utilized.
+     *
+     * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
+     * @return the next available {@link LuceneSlice}, or {@code null} if exhausted
+     */
     @Nullable
-    public LuceneSlice nextSlice() {
-        return slices.poll();
+    public LuceneSlice nextSlice(LuceneSlice prev) {
+        if (prev != null) {
+            final int nextId = prev.slicePosition() + 1;
+            if (nextId < totalSlices) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        for (var ids : List.of(sliceHeads, stealableSlices)) {
+            Integer nextId;
+            while ((nextId = ids.poll()) != null) {
+                var slice = slices.getAndSet(nextId, null);
+                if (slice != null) {
+                    return slice;
+                }
+            }
+        }
+        return null;
     }
 
     public int totalSlices() {
@@ -103,7 +166,14 @@ public final class LuceneSliceQueue {
     }
 
     public Collection<String> remainingShardsIdentifiers() {
-        return slices.stream().map(slice -> slice.shardContext().shardIdentifier()).toList();
+        List<String> remaining = new ArrayList<>(slices.length());
+        for (int i = 0; i < slices.length(); i++) {
+            LuceneSlice slice = slices.get(i);
+            if (slice != null) {
+                remaining.add(slice.shardContext().shardIdentifier());
+            }
+        }
+        return remaining;
     }
 
     public static LuceneSliceQueue create(
@@ -117,6 +187,7 @@ public final class LuceneSliceQueue {
         List<LuceneSlice> slices = new ArrayList<>();
         Map<String, PartitioningStrategy> partitioningStrategies = new HashMap<>(contexts.size());
 
+        int nextSliceId = 0;
         for (ShardContext ctx : contexts) {
             for (QueryAndTags queryAndExtra : queryFunction.apply(ctx)) {
                 var scoreMode = scoreModeFunction.apply(ctx);
@@ -140,7 +211,7 @@ public final class LuceneSliceQueue {
                 Weight weight = weight(ctx, query, scoreMode);
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
                     }
                 }
             }
@@ -158,7 +229,7 @@ public final class LuceneSliceQueue {
          */
         SHARD(0) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 return List.of(searcher.getLeafContexts().stream().map(PartialLeafReaderContext::new).toList());
             }
         },
@@ -167,7 +238,7 @@ public final class LuceneSliceQueue {
          */
         SEGMENT(1) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 IndexSearcher.LeafSlice[] gs = IndexSearcher.slices(
                     searcher.getLeafContexts(),
                     MAX_DOCS_PER_SLICE,
@@ -182,52 +253,11 @@ public final class LuceneSliceQueue {
          */
         DOC(2) {
             @Override
-            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices) {
+            List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency) {
                 final int totalDocCount = searcher.getIndexReader().maxDoc();
-                final int normalMaxDocsPerSlice = totalDocCount / requestedNumSlices;
-                final int extraDocsInFirstSlice = totalDocCount % requestedNumSlices;
-                final List<List<PartialLeafReaderContext>> slices = new ArrayList<>();
-                int docsAllocatedInCurrentSlice = 0;
-                List<PartialLeafReaderContext> currentSlice = null;
-                int maxDocsPerSlice = normalMaxDocsPerSlice + extraDocsInFirstSlice;
-                for (LeafReaderContext ctx : searcher.getLeafContexts()) {
-                    final int numDocsInLeaf = ctx.reader().maxDoc();
-                    int minDoc = 0;
-                    while (minDoc < numDocsInLeaf) {
-                        int numDocsToUse = Math.min(maxDocsPerSlice - docsAllocatedInCurrentSlice, numDocsInLeaf - minDoc);
-                        if (numDocsToUse <= 0) {
-                            break;
-                        }
-                        if (currentSlice == null) {
-                            currentSlice = new ArrayList<>();
-                        }
-                        currentSlice.add(new PartialLeafReaderContext(ctx, minDoc, minDoc + numDocsToUse));
-                        minDoc += numDocsToUse;
-                        docsAllocatedInCurrentSlice += numDocsToUse;
-                        if (docsAllocatedInCurrentSlice == maxDocsPerSlice) {
-                            slices.add(currentSlice);
-                            // once the first slice with the extra docs is added, no need for extra docs
-                            maxDocsPerSlice = normalMaxDocsPerSlice;
-                            currentSlice = null;
-                            docsAllocatedInCurrentSlice = 0;
-                        }
-                    }
-                }
-                if (currentSlice != null) {
-                    slices.add(currentSlice);
-                }
-                if (requestedNumSlices < totalDocCount && slices.size() != requestedNumSlices) {
-                    throw new IllegalStateException("wrong number of slices, expected " + requestedNumSlices + " but got " + slices.size());
-                }
-                if (slices.stream()
-                    .flatMapToInt(
-                        l -> l.stream()
-                            .mapToInt(partialLeafReaderContext -> partialLeafReaderContext.maxDoc() - partialLeafReaderContext.minDoc())
-                    )
-                    .sum() != totalDocCount) {
-                    throw new IllegalStateException("wrong doc count");
-                }
-                return slices;
+                // Cap the desired slice to prevent CPU underutilization when matching documents are concentrated in one segment region.
+                int desiredSliceSize = Math.clamp(Math.ceilDiv(totalDocCount, taskConcurrency), 1, MAX_DOCS_PER_SLICE);
+                return new AdaptivePartitioner(Math.max(1, desiredSliceSize), MAX_SEGMENTS_PER_SLICE).partition(searcher.getLeafContexts());
             }
         };
 
@@ -252,7 +282,7 @@ public final class LuceneSliceQueue {
             out.writeByte(id);
         }
 
-        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int requestedNumSlices);
+        abstract List<List<PartialLeafReaderContext>> groups(IndexSearcher searcher, int taskConcurrency);
 
         private static PartitioningStrategy pick(
             DataPartitioning dataPartitioning,
@@ -291,4 +321,67 @@ public final class LuceneSliceQueue {
             throw new UncheckedIOException(e);
         }
     }
+
+    static final class AdaptivePartitioner {
+        final int desiredDocsPerSlice;
+        final int maxDocsPerSlice;
+        final int maxSegmentsPerSlice;
+
+        AdaptivePartitioner(int desiredDocsPerSlice, int maxSegmentsPerSlice) {
+            this.desiredDocsPerSlice = desiredDocsPerSlice;
+            this.maxDocsPerSlice = desiredDocsPerSlice * 5 / 4;
+            this.maxSegmentsPerSlice = maxSegmentsPerSlice;
+        }
+
+        List<List<PartialLeafReaderContext>> partition(List<LeafReaderContext> leaves) {
+            List<LeafReaderContext> smallSegments = new ArrayList<>();
+            List<LeafReaderContext> largeSegments = new ArrayList<>();
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            for (LeafReaderContext leaf : leaves) {
+                if (leaf.reader().maxDoc() >= 5 * desiredDocsPerSlice) {
+                    largeSegments.add(leaf);
+                } else {
+                    smallSegments.add(leaf);
+                }
+            }
+            largeSegments.sort(Collections.reverseOrder(Comparator.comparingInt(l -> l.reader().maxDoc())));
+            for (LeafReaderContext segment : largeSegments) {
+                results.addAll(partitionOneLargeSegment(segment));
+            }
+            results.addAll(partitionSmallSegments(smallSegments));
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionOneLargeSegment(LeafReaderContext leaf) {
+            int numDocsInLeaf = leaf.reader().maxDoc();
+            int numSlices = Math.max(1, numDocsInLeaf / desiredDocsPerSlice);
+            while (Math.ceilDiv(numDocsInLeaf, numSlices) > maxDocsPerSlice) {
+                numSlices++;
+            }
+            int docPerSlice = numDocsInLeaf / numSlices;
+            int leftoverDocs = numDocsInLeaf % numSlices;
+            int minDoc = 0;
+            List<List<PartialLeafReaderContext>> results = new ArrayList<>();
+            while (minDoc < numDocsInLeaf) {
+                int docsToUse = docPerSlice;
+                if (leftoverDocs > 0) {
+                    --leftoverDocs;
+                    docsToUse++;
+                }
+                int maxDoc = Math.min(minDoc + docsToUse, numDocsInLeaf);
+                results.add(List.of(new PartialLeafReaderContext(leaf, minDoc, maxDoc)));
+                minDoc = maxDoc;
+            }
+            assert leftoverDocs == 0 : leftoverDocs;
+            assert results.stream().allMatch(s -> s.size() == 1) : "must have one partial leaf per slice";
+            assert results.stream().flatMapToInt(ss -> ss.stream().mapToInt(s -> s.maxDoc() - s.minDoc())).sum() == numDocsInLeaf;
+            return results;
+        }
+
+        List<List<PartialLeafReaderContext>> partitionSmallSegments(List<LeafReaderContext> leaves) {
+            var slices = IndexSearcher.slices(leaves, maxDocsPerSlice, maxSegmentsPerSlice, true);
+            return Arrays.stream(slices).map(g -> Arrays.stream(g.partitions).map(PartialLeafReaderContext::new).toList()).toList();
+        }
+    }
+
 }

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java

@@ -97,7 +97,7 @@ public final class TimeSeriesSourceOperator extends LuceneOperator {
         long startInNanos = System.nanoTime();
         try {
             if (iterator == null) {
-                var slice = sliceQueue.nextSlice();
+                var slice = sliceQueue.nextSlice(null);
                 if (slice == null) {
                     doneCollecting = true;
                     return null;

+ 348 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSliceQueueTests.java

@@ -0,0 +1,348 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.lucene;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.DocValuesSkipper;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.LeafMetaData;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.index.SortedDocValues;
+import org.apache.lucene.index.SortedNumericDocValues;
+import org.apache.lucene.index.SortedSetDocValues;
+import org.apache.lucene.index.StoredFields;
+import org.apache.lucene.index.TermVectors;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.util.Bits;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.test.ESTestCase;
+import org.hamcrest.Matchers;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.mockito.Mockito.mock;
+
+public class LuceneSliceQueueTests extends ESTestCase {
+
+    public void testBasics() {
+
+        LeafReaderContext leaf1 = new MockLeafReader(1000).getContext();
+        LeafReaderContext leaf2 = new MockLeafReader(1000).getContext();
+        LeafReaderContext leaf3 = new MockLeafReader(1000).getContext();
+        LeafReaderContext leaf4 = new MockLeafReader(1000).getContext();
+        var slice1 = new LuceneSlice(0, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, null);
+
+        var slice2 = new LuceneSlice(1, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, null);
+        var slice3 = new LuceneSlice(2, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, null);
+
+        var slice4 = new LuceneSlice(3, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, null);
+        var slice5 = new LuceneSlice(4, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, null);
+        var slice6 = new LuceneSlice(5, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, null);
+
+        var slice7 = new LuceneSlice(6, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, null);
+        var slice8 = new LuceneSlice(7, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, null);
+        List<LuceneSlice> sliceList = List.of(slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8);
+        // single driver
+        {
+            LuceneSliceQueue queue = new LuceneSliceQueue(sliceList, Map.of());
+            LuceneSlice last = null;
+            for (LuceneSlice slice : sliceList) {
+                last = queue.nextSlice(last);
+                assertEquals(slice, last);
+            }
+            assertNull(queue.nextSlice(randomBoolean() ? last : null));
+        }
+        // two drivers
+        {
+            LuceneSliceQueue queue = new LuceneSliceQueue(sliceList, Map.of());
+            LuceneSlice first = null;
+            LuceneSlice second = null;
+            first = queue.nextSlice(first);
+            assertEquals(slice1, first);
+            first = queue.nextSlice(first);
+            assertEquals(slice2, first);
+
+            second = queue.nextSlice(second);
+            assertEquals(slice4, second);
+            second = queue.nextSlice(second);
+            assertEquals(slice5, second);
+
+            first = queue.nextSlice(first);
+            assertEquals(slice3, first);
+            second = queue.nextSlice(second);
+            assertEquals(slice6, second);
+            first = queue.nextSlice(first);
+            assertEquals(slice7, first);
+
+            assertEquals(slice8, queue.nextSlice(randomFrom(first, second)));
+
+            assertNull(queue.nextSlice(first));
+            assertNull(queue.nextSlice(second));
+        }
+    }
+
+    public void testRandom() throws Exception {
+        List<LuceneSlice> sliceList = new ArrayList<>();
+        int numShards = randomIntBetween(1, 10);
+        int slicePosition = 0;
+        for (int shard = 0; shard < numShards; shard++) {
+            int numSegments = randomIntBetween(1, 10);
+            for (int segment = 0; segment < numSegments; segment++) {
+                int numSlices = randomBoolean() ? 1 : between(2, 5);
+                LeafReaderContext leafContext = new MockLeafReader(randomIntBetween(1000, 2000)).getContext();
+                for (int i = 0; i < numSlices; i++) {
+                    final int minDoc = i * 10;
+                    final int maxDoc = minDoc + 10;
+                    LuceneSlice slice = new LuceneSlice(
+                        slicePosition++,
+                        mock(ShardContext.class),
+                        List.of(new PartialLeafReaderContext(leafContext, minDoc, maxDoc)),
+                        null,
+                        null
+                    );
+                    sliceList.add(slice);
+                }
+            }
+        }
+        LuceneSliceQueue queue = new LuceneSliceQueue(sliceList, Map.of());
+        Queue<LuceneSlice> allProcessedSlices = ConcurrentCollections.newQueue();
+        int numDrivers = randomIntBetween(1, 5);
+        CyclicBarrier barrier = new CyclicBarrier(numDrivers + 1);
+        List<Thread> drivers = new ArrayList<>();
+        for (int d = 0; d < numDrivers; d++) {
+            Thread driver = new Thread(() -> {
+                try {
+                    barrier.await(1, TimeUnit.SECONDS);
+                } catch (Exception e) {
+                    throw new AssertionError(e);
+                }
+                LuceneSlice nextSlice = null;
+                List<LuceneSlice> processedSlices = new ArrayList<>();
+                while ((nextSlice = queue.nextSlice(nextSlice)) != null) {
+                    processedSlices.add(nextSlice);
+                }
+                allProcessedSlices.addAll(processedSlices);
+                // slices from a single driver are forward-only
+                for (int i = 1; i < processedSlices.size(); i++) {
+                    var currentLeaf = processedSlices.get(i).getLeaf(0);
+                    for (int p = 0; p < i; p++) {
+                        PartialLeafReaderContext prevLeaf = processedSlices.get(p).getLeaf(0);
+                        if (prevLeaf == currentLeaf) {
+                            assertThat(prevLeaf.minDoc(), Matchers.lessThanOrEqualTo(processedSlices.get(i).leaves().getFirst().maxDoc()));
+                        }
+                    }
+                }
+            });
+            drivers.add(driver);
+            driver.start();
+        }
+        barrier.await();
+        for (Thread driver : drivers) {
+            driver.join();
+        }
+        assertThat(allProcessedSlices, Matchers.hasSize(sliceList.size()));
+        assertThat(Set.copyOf(allProcessedSlices), equalTo(Set.copyOf(sliceList)));
+    }
+
+    public void testDocPartitioningBigSegments() {
+        LeafReaderContext leaf1 = new MockLeafReader(250).getContext();
+        LeafReaderContext leaf2 = new MockLeafReader(400).getContext();
+        LeafReaderContext leaf3 = new MockLeafReader(1_400_990).getContext();
+        LeafReaderContext leaf4 = new MockLeafReader(2_100_061).getContext();
+        LeafReaderContext leaf5 = new MockLeafReader(1_000_607).getContext();
+        var adaptivePartitioner = new LuceneSliceQueue.AdaptivePartitioner(250_000, 5);
+        List<List<PartialLeafReaderContext>> slices = adaptivePartitioner.partition(List.of(leaf1, leaf2, leaf3, leaf4, leaf5));
+        // leaf4: 2_100_061
+        int sliceOffset = 0;
+        {
+            List<Integer> sliceSizes = List.of(262508, 262508, 262508, 262508, 262508, 262507, 262507, 262507);
+            for (Integer sliceSize : sliceSizes) {
+                List<PartialLeafReaderContext> slice = slices.get(sliceOffset++);
+                assertThat(slice, hasSize(1));
+                assertThat(slice.getFirst().leafReaderContext(), equalTo(leaf4));
+                assertThat(slice.getFirst().maxDoc() - slice.getFirst().minDoc(), equalTo(sliceSize));
+            }
+        }
+        // leaf3: 1_400_990
+        {
+            List<Integer> sliceSizes = List.of(280198, 280198, 280198, 280198, 280198);
+            for (Integer sliceSize : sliceSizes) {
+                List<PartialLeafReaderContext> slice = slices.get(sliceOffset++);
+                assertThat(slice, hasSize(1));
+                assertThat(slice.getFirst().leafReaderContext(), equalTo(leaf3));
+                assertThat(slice.getFirst().maxDoc() - slice.getFirst().minDoc(), equalTo(sliceSize));
+            }
+        }
+        // leaf5: 1_000_607
+        {
+            List<Integer> sliceSizes = List.of(250151, 250151, 250151, 250154);
+            for (Integer sliceSize : sliceSizes) {
+                List<PartialLeafReaderContext> slice = slices.get(sliceOffset++);
+                assertThat(slice, hasSize(1));
+                var partialLeaf = slice.getFirst();
+                assertThat(partialLeaf.leafReaderContext(), equalTo(leaf5));
+                assertThat(partialLeaf.toString(), partialLeaf.maxDoc() - partialLeaf.minDoc(), equalTo(sliceSize));
+            }
+        }
+        // leaf2 and leaf1
+        {
+            List<PartialLeafReaderContext> slice = slices.get(sliceOffset++);
+            assertThat(slice, hasSize(2));
+            assertThat(slice.getFirst().leafReaderContext(), equalTo(leaf2));
+            assertThat(slice.getFirst().minDoc(), equalTo(0));
+            assertThat(slice.getFirst().maxDoc(), equalTo(Integer.MAX_VALUE));
+            assertThat(slice.getLast().leafReaderContext(), equalTo(leaf1));
+            assertThat(slice.getLast().minDoc(), equalTo(0));
+            assertThat(slice.getLast().maxDoc(), equalTo(Integer.MAX_VALUE));
+        }
+        assertThat(slices, hasSize(sliceOffset));
+    }
+
+    static class MockLeafReader extends LeafReader {
+        private final int maxDoc;
+
+        MockLeafReader(int maxDoc) {
+            this.maxDoc = maxDoc;
+        }
+
+        @Override
+        public CacheHelper getCoreCacheHelper() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Terms terms(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public NumericDocValues getNumericDocValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public BinaryDocValues getBinaryDocValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public SortedDocValues getSortedDocValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public SortedNumericDocValues getSortedNumericDocValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public SortedSetDocValues getSortedSetDocValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public NumericDocValues getNormValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public DocValuesSkipper getDocValuesSkipper(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public ByteVectorValues getByteVectorValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public void searchNearestVectors(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public void searchNearestVectors(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public FieldInfos getFieldInfos() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Bits getLiveDocs() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public PointValues getPointValues(String field) throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public void checkIntegrity() throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public LeafMetaData getMetaData() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public TermVectors termVectors() throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public int numDocs() {
+            return maxDoc;
+        }
+
+        @Override
+        public int maxDoc() {
+            return maxDoc;
+        }
+
+        @Override
+        public StoredFields storedFields() throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        protected void doClose() throws IOException {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public CacheHelper getReaderCacheHelper() {
+            throw new UnsupportedOperationException();
+        }
+    }
+}