
Add query heads priority to SliceQueue (#133245)

With queries and tags (see #132512), the SliceQueue will contain more slices.
This change introduces an additional priority for query heads, allowing
Drivers to pull slices from the same query and segment first, minimizing
the overhead of switching between queries and segments.

Relates #132774
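
For reference, a minimal sketch (illustrative names, not the actual `LuceneSliceQueue` API) of the priority order a Driver now follows when it cannot continue sequentially on its current query/segment: drain query heads first, then segment heads, then steal the remaining slices.

```java
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

// Illustrative sketch only: shows the three-queue fallback order added in this change.
// The real LuceneSliceQueue additionally tries the next sequential slice of the current
// query/segment first and claims slices atomically from an AtomicReferenceArray.
final class SlicePrioritySketch {
    record Slice(int position, boolean queryHead, boolean segmentHead) {}

    private final Queue<Integer> queryHeads = new ConcurrentLinkedQueue<>();   // first slice of each query
    private final Queue<Integer> segmentHeads = new ConcurrentLinkedQueue<>(); // first slice of a segment group
    private final Queue<Integer> stealable = new ConcurrentLinkedQueue<>();    // everything else

    SlicePrioritySketch(List<Slice> slices) {
        for (Slice s : slices) {
            if (s.queryHead()) {
                queryHeads.add(s.position());
            } else if (s.segmentHead()) {
                segmentHeads.add(s.position());
            } else {
                stealable.add(s.position());
            }
        }
    }

    /** Poll order: a fresh query first, then a fresh segment group, then work stealing. */
    Integer next() {
        for (Queue<Integer> q : List.of(queryHeads, segmentHeads, stealable)) {
            Integer id = q.poll();
            if (id != null) {
                return id;
            }
        }
        return null;
    }
}
```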
Nhat Nguyen, 1 month ago
commit 7d678a3a76

+ 5 - 0
docs/changelog/133245.yaml

@@ -0,0 +1,5 @@
+pr: 133245
+summary: Add query heads priority to `SliceQueue`
+area: ES|QL
+type: enhancement
+issues: []

+ 1 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSlice.java

@@ -16,6 +16,7 @@ import java.util.List;
  */
 public record LuceneSlice(
     int slicePosition,
+    boolean queryHead,
     ShardContext shardContext,
     List<PartialLeafReaderContext> leaves,
     Weight weight,

+ 27 - 11
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java

@@ -83,13 +83,21 @@ public final class LuceneSliceQueue {
     private final Map<String, PartitioningStrategy> partitioningStrategies;
 
     private final AtomicReferenceArray<LuceneSlice> slices;
+    /**
+     * Queue of slice IDs that are the primary entry point for a new query.
+     * A driver should prioritize polling from this queue after failing to get a sequential
+     * slice (the query/segment affinity). This ensures that threads start work on fresh,
+     * independent queries before stealing segments from other queries.
+     */
+    private final Queue<Integer> queryHeads;
+
     /**
      * Queue of slice IDs that are the primary entry point for a new group of segments.
      * A driver should prioritize polling from this queue after failing to get a sequential
      * slice (the segment affinity). This ensures that threads start work on fresh,
      * independent segment groups before resorting to work stealing.
      */
-    private final Queue<Integer> sliceHeads;
+    private final Queue<Integer> segmentHeads;
 
     /**
      * Queue of slice IDs that are not the primary entry point for a segment group.
@@ -106,11 +114,14 @@ public final class LuceneSliceQueue {
             slices.set(i, sliceList.get(i));
         }
         this.partitioningStrategies = partitioningStrategies;
-        this.sliceHeads = ConcurrentCollections.newQueue();
+        this.queryHeads = ConcurrentCollections.newQueue();
+        this.segmentHeads = ConcurrentCollections.newQueue();
         this.stealableSlices = ConcurrentCollections.newQueue();
         for (LuceneSlice slice : sliceList) {
-            if (slice.getLeaf(0).minDoc() == 0) {
-                sliceHeads.add(slice.slicePosition());
+            if (slice.queryHead()) {
+                queryHeads.add(slice.slicePosition());
+            } else if (slice.getLeaf(0).minDoc() == 0) {
+                segmentHeads.add(slice.slicePosition());
             } else {
                 stealableSlices.add(slice.slicePosition());
             }
@@ -120,12 +131,14 @@ public final class LuceneSliceQueue {
     /**
      * Retrieves the next available {@link LuceneSlice} for processing.
      * <p>
-     * This method implements a three-tiered strategy to minimize the overhead of switching between segments:
+     * This method implements a four-tiered strategy to minimize the overhead of switching between queries/segments:
      * 1. If a previous slice is provided, it first attempts to return the next sequential slice.
-     * This keeps a thread working on the same segments, minimizing the overhead of segment switching.
-     * 2. If affinity fails, it returns a slice from the {@link #sliceHeads} queue, which is an entry point for
-     * a new, independent group of segments, allowing the calling Driver to work on a fresh set of segments.
-     * 3. If the {@link #sliceHeads} queue is exhausted, it "steals" a slice
+     * This keeps a thread working on the same query and same segment, minimizing the overhead of query/segment switching.
+     * 2. If affinity fails, it returns a slice from the {@link #queryHeads} queue, which is an entry point for
+     * a new query, allowing the calling Driver to work on a fresh query with a new set of segments.
+     * 3. If the {@link #queryHeads} queue is exhausted, it returns a slice from the {@link #segmentHeads} queue of other queries,
+     * which is an entry point for a new, independent group of segments, allowing the calling Driver to work on a fresh set of segments.
+     * 4. If the {@link #segmentHeads} queue is exhausted, it "steals" a slice
      * from the {@link #stealableSlices} queue. This fallback ensures all threads remain utilized.
      *
      * @param prev the previously returned {@link LuceneSlice}, or {@code null} if starting
@@ -142,7 +155,7 @@ public final class LuceneSliceQueue {
                 }
             }
         }
-        for (var ids : List.of(sliceHeads, stealableSlices)) {
+        for (var ids : List.of(queryHeads, segmentHeads, stealableSlices)) {
             Integer nextId;
             while ((nextId = ids.poll()) != null) {
                 var slice = slices.getAndSet(nextId, null);
@@ -209,9 +222,12 @@ public final class LuceneSliceQueue {
                 partitioningStrategies.put(ctx.shardIdentifier(), partitioning);
                 List<List<PartialLeafReaderContext>> groups = partitioning.groups(ctx.searcher(), taskConcurrency);
                 Weight weight = weight(ctx, query, scoreMode);
+                boolean queryHead = true;
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
-                        slices.add(new LuceneSlice(nextSliceId++, ctx, group, weight, queryAndExtra.tags));
+                        final int slicePosition = nextSliceId++;
+                        slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, weight, queryAndExtra.tags));
+                        queryHead = false;
                     }
                 }
             }
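
As a small, hypothetical illustration of the `createSlices` change above (simplified types, not the real signatures): only the first non-empty group produced for a query is flagged as the query head, so each query contributes exactly one entry to the `queryHeads` queue.

```java
import java.util.List;

// Hypothetical, simplified illustration of the queryHead flag handling in createSlices:
// only the first non-empty group of a query becomes the query head.
public class QueryHeadFlagSketch {
    public static void main(String[] args) {
        List<List<String>> groups = List.of(List.of("seg1a", "seg1b"), List.of(), List.of("seg2"));
        boolean queryHead = true;
        for (List<String> group : groups) {
            if (group.isEmpty() == false) {
                System.out.println((queryHead ? "query head: " : "regular:    ") + group);
                queryHead = false; // later non-empty groups of the same query are not query heads
            }
        }
        // prints:
        // query head: [seg1a, seg1b]
        // regular:    [seg2]
    }
}
```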

+ 50 - 28
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSliceQueueTests.java

@@ -50,18 +50,28 @@ public class LuceneSliceQueueTests extends ESTestCase {
         LeafReaderContext leaf2 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf3 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf4 = new MockLeafReader(1000).getContext();
-        var slice1 = new LuceneSlice(0, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, null);
-
-        var slice2 = new LuceneSlice(1, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, null);
-        var slice3 = new LuceneSlice(2, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, null);
-
-        var slice4 = new LuceneSlice(3, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, null);
-        var slice5 = new LuceneSlice(4, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, null);
-        var slice6 = new LuceneSlice(5, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, null);
-
-        var slice7 = new LuceneSlice(6, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, null);
-        var slice8 = new LuceneSlice(7, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, null);
-        List<LuceneSlice> sliceList = List.of(slice1, slice2, slice3, slice4, slice5, slice6, slice7, slice8);
+        List<Object> query1 = List.of("1");
+        List<Object> query2 = List.of("q2");
+        List<LuceneSlice> sliceList = List.of(
+            // query1: new segment
+            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, query1),
+            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query1),
+            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query1),
+            // query1: new segment
+            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query1),
+            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query1),
+            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query1),
+            // query1: new segment
+            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, query1),
+            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, query1),
+            // query2: new segment
+            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query2),
+            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query2),
+            // query2: new segment
+            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query2),
+            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query2),
+            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query2)
+        );
         // single driver
         {
             LuceneSliceQueue queue = new LuceneSliceQueue(sliceList, Map.of());
@@ -72,32 +82,43 @@ public class LuceneSliceQueueTests extends ESTestCase {
             }
             assertNull(queue.nextSlice(randomBoolean() ? last : null));
         }
-        // two drivers
+        // three drivers
         {
             LuceneSliceQueue queue = new LuceneSliceQueue(sliceList, Map.of());
+
             LuceneSlice first = null;
             LuceneSlice second = null;
+            LuceneSlice third = null;
             first = queue.nextSlice(first);
-            assertEquals(slice1, first);
+            assertEquals(sliceList.get(0), first);
             first = queue.nextSlice(first);
-            assertEquals(slice2, first);
+            assertEquals(sliceList.get(1), first);
 
             second = queue.nextSlice(second);
-            assertEquals(slice4, second);
+            assertEquals(sliceList.get(8), second);
             second = queue.nextSlice(second);
-            assertEquals(slice5, second);
+            assertEquals(sliceList.get(9), second);
 
             first = queue.nextSlice(first);
-            assertEquals(slice3, first);
-            second = queue.nextSlice(second);
-            assertEquals(slice6, second);
+            assertEquals(sliceList.get(2), first);
+            third = queue.nextSlice(third);
+            assertEquals(sliceList.get(3), third);
             first = queue.nextSlice(first);
-            assertEquals(slice7, first);
-
-            assertEquals(slice8, queue.nextSlice(randomFrom(first, second)));
+            assertEquals(sliceList.get(6), first);
 
-            assertNull(queue.nextSlice(first));
-            assertNull(queue.nextSlice(second));
+            first = queue.nextSlice(first);
+            assertEquals(sliceList.get(7), first);
+            third = queue.nextSlice(third);
+            assertEquals(sliceList.get(4), third);
+            first = queue.nextSlice(first);
+            assertEquals(sliceList.get(10), first);
+            first = queue.nextSlice(first);
+            assertEquals(sliceList.get(11), first);
+            second = queue.nextSlice(second);
+            assertEquals(sliceList.get(5), second);
+            second = queue.nextSlice(second);
+            assertEquals(sliceList.get(12), second);
+            assertNull(queue.nextSlice(randomFrom(sliceList)));
         }
     }
 
@@ -108,13 +129,14 @@ public class LuceneSliceQueueTests extends ESTestCase {
         for (int shard = 0; shard < numShards; shard++) {
             int numSegments = randomIntBetween(1, 10);
             for (int segment = 0; segment < numSegments; segment++) {
-                int numSlices = randomBoolean() ? 1 : between(2, 5);
+                int numSlices = between(10, 50);
                 LeafReaderContext leafContext = new MockLeafReader(randomIntBetween(1000, 2000)).getContext();
                 for (int i = 0; i < numSlices; i++) {
                     final int minDoc = i * 10;
                     final int maxDoc = minDoc + 10;
                     LuceneSlice slice = new LuceneSlice(
                         slicePosition++,
+                        false,
                         mock(ShardContext.class),
                         List.of(new PartialLeafReaderContext(leafContext, minDoc, maxDoc)),
                         null,
@@ -147,8 +169,8 @@ public class LuceneSliceQueueTests extends ESTestCase {
                     var currentLeaf = processedSlices.get(i).getLeaf(0);
                     for (int p = 0; p < i; p++) {
                         PartialLeafReaderContext prevLeaf = processedSlices.get(p).getLeaf(0);
-                        if (prevLeaf == currentLeaf) {
-                            assertThat(prevLeaf.minDoc(), Matchers.lessThanOrEqualTo(processedSlices.get(i).leaves().getFirst().maxDoc()));
+                        if (prevLeaf.leafReaderContext() == currentLeaf.leafReaderContext()) {
+                            assertThat(prevLeaf.maxDoc(), Matchers.lessThanOrEqualTo(currentLeaf.minDoc()));
                         }
                     }
                 }