
Do not share Weight between Drivers (#133446)

We have encountered the following error in serverless:

```
java.lang.NullPointerException: Cannot invoke "org.apache.lucene.search.BulkScorer.score(org.apache.lucene.search.LeafCollector, org.apache.lucene.util.Bits, int, int)" because "this.bulkScorer" is null
at org.elasticsearch.compute.lucene.LuceneOperator$LuceneScorer.scoreNextRange(LuceneOperator.java:233)
at org.elasticsearch.compute.lucene.LuceneSourceOperator.getCheckedOutput(LuceneSourceOperator.java:307)
at org.elasticsearch.compute.lucene.LuceneOperator.getOutput(LuceneOperator.java:143)
at org.elasticsearch.compute.operator.Driver.runSingleLoopIteration(Driver.java:272)
at org.elasticsearch.compute.operator.Driver.run(Driver.java:186)
at org.elasticsearch.compute.operator.Driver$1.doRun(Driver.java:420)
```

I spent considerable time trying to reproduce this issue but was
unsuccessful, although I understand how it could occur. Weight should
not be shared between threads. Most Weight implementations are safe to
share, but those for term queries (e.g., TermQuery, multi-term queries)
are not, as they contain mutable state.

This change stops sharing Weight between Drivers: each Driver now
creates its own Weight from the slice's Query and ScoreMode.
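
A minimal sketch of the pattern, using plain Lucene APIs (the record and
names here are illustrative, not the actual ES|QL classes; the query is
assumed to be rewritten already):

```java
import java.io.IOException;
import java.io.UncheckedIOException;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

// Illustrative only: a slice that carries the Query instead of a Weight.
record SliceSketch(IndexSearcher searcher, Query query, ScoreMode scoreMode) {
    // Each Driver thread calls this when it picks up the slice, so no
    // Weight instance (and its mutable cached state) crosses threads.
    Weight createWeight() {
        try {
            return searcher.createWeight(query, scoreMode, 1f);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```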
Nhat Nguyen, 1 month ago
commit 2f100656d1

+ 5 - 0
docs/changelog/133446.yaml

@@ -0,0 +1,5 @@
+pr: 133446
+summary: Do not share Weight between Drivers
+area: ES|QL
+type: bug
+issues: []

+ 62 - 32
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

@@ -165,40 +165,61 @@ public abstract class LuceneOperator extends SourceOperator {
     protected void additionalClose() { /* Override this method to add any additional cleanup logic if needed */ }
 
     LuceneScorer getCurrentOrLoadNextScorer() {
-        while (currentScorer == null || currentScorer.isDone()) {
-            if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
-                sliceIndex = 0;
-                currentSlice = sliceQueue.nextSlice(currentSlice);
-                if (currentSlice == null) {
-                    doneCollecting = true;
+        while (true) {
+            while (currentScorer == null || currentScorer.isDone()) {
+                var partialLeaf = nextPartialLeaf();
+                if (partialLeaf == null) {
+                    assert doneCollecting;
                     return null;
                 }
-                processedSlices++;
-                processedShards.add(currentSlice.shardContext().shardIdentifier());
-                int shardId = currentSlice.shardContext().index();
-                if (currentScorerShardRefCounted == null || currentScorerShardRefCounted.index() != shardId) {
-                    currentScorerShardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
-                }
+                logger.trace("Starting {}", partialLeaf);
+                loadScorerForNewPartialLeaf(partialLeaf);
             }
-            final PartialLeafReaderContext partialLeaf = currentSlice.getLeaf(sliceIndex++);
-            logger.trace("Starting {}", partialLeaf);
-            final LeafReaderContext leaf = partialLeaf.leafReaderContext();
-            if (currentScorer == null // First time
-                || currentScorer.leafReaderContext() != leaf // Moved to a new leaf
-                || currentScorer.weight != currentSlice.weight() // Moved to a new query
-            ) {
-                final Weight weight = currentSlice.weight();
-                processedQueries.add(weight.getQuery());
-                currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.tags(), leaf);
+            // Has the executing thread changed? If so, we need to reinitialize the scorer. The reinitialized bulkScorer
+            // can be null even if it was non-null previously, due to lazy initialization in Weight#bulkScorer.
+            // Hence, we need to check the previous condition again.
+            if (currentScorer.executingThread == Thread.currentThread()) {
+                return currentScorer;
+            } else {
+                currentScorer.reinitialize();
             }
-            assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
-            currentScorer.maxPosition = partialLeaf.maxDoc();
-            currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
         }
-        if (Thread.currentThread() != currentScorer.executingThread) {
-            currentScorer.reinitialize();
+    }
+
+    private PartialLeafReaderContext nextPartialLeaf() {
+        if (currentSlice == null || sliceIndex >= currentSlice.numLeaves()) {
+            sliceIndex = 0;
+            currentSlice = sliceQueue.nextSlice(currentSlice);
+            if (currentSlice == null) {
+                doneCollecting = true;
+                return null;
+            }
+            processedSlices++;
+            int shardId = currentSlice.shardContext().index();
+            if (currentScorerShardRefCounted == null || currentScorerShardRefCounted.index() != shardId) {
+                currentScorerShardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
+            }
+            processedShards.add(currentSlice.shardContext().shardIdentifier());
         }
-        return currentScorer;
+        return currentSlice.getLeaf(sliceIndex++);
+    }
+
+    private void loadScorerForNewPartialLeaf(PartialLeafReaderContext partialLeaf) {
+        final LeafReaderContext leaf = partialLeaf.leafReaderContext();
+        if (currentScorer != null
+            && currentScorer.query() == currentSlice.query()
+            && currentScorer.shardContext == currentSlice.shardContext()) {
+            if (currentScorer.leafReaderContext != leaf) {
+                currentScorer = new LuceneScorer(currentSlice.shardContext(), currentScorer.weight, currentSlice.queryAndTags(), leaf);
+            }
+        } else {
+            final var weight = currentSlice.createWeight();
+            currentScorer = new LuceneScorer(currentSlice.shardContext(), weight, currentSlice.queryAndTags(), leaf);
+            processedQueries.add(currentScorer.query());
+        }
+        assert currentScorer.maxPosition <= partialLeaf.maxDoc() : currentScorer.maxPosition + ">" + partialLeaf.maxDoc();
+        currentScorer.maxPosition = partialLeaf.maxDoc();
+        currentScorer.position = Math.max(currentScorer.position, partialLeaf.minDoc());
     }
 
     /**
@@ -214,18 +235,23 @@ public abstract class LuceneOperator extends SourceOperator {
     static final class LuceneScorer {
         private final ShardContext shardContext;
         private final Weight weight;
+        private final LuceneSliceQueue.QueryAndTags queryAndTags;
         private final LeafReaderContext leafReaderContext;
-        private final List<Object> tags;
 
         private BulkScorer bulkScorer;
         private int position;
         private int maxPosition;
         private Thread executingThread;
 
-        LuceneScorer(ShardContext shardContext, Weight weight, List<Object> tags, LeafReaderContext leafReaderContext) {
+        LuceneScorer(
+            ShardContext shardContext,
+            Weight weight,
+            LuceneSliceQueue.QueryAndTags queryAndTags,
+            LeafReaderContext leafReaderContext
+        ) {
             this.shardContext = shardContext;
             this.weight = weight;
-            this.tags = tags;
+            this.queryAndTags = queryAndTags;
             this.leafReaderContext = leafReaderContext;
             reinitialize();
         }
@@ -275,7 +301,11 @@ public abstract class LuceneOperator extends SourceOperator {
          * Tags to add to the data returned by this query.
          */
         List<Object> tags() {
-            return tags;
+            return queryAndTags.tags();
+        }
+
+        Query query() {
+            return queryAndTags.query();
         }
     }
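
`reinitialize()` itself is not part of this hunk; based on the fields
declared above, it presumably reacquires the per-thread state along
these lines (a hedged reconstruction, not the verbatim method):

```java
// Assumed shape of LuceneScorer#reinitialize(), inferred from the fields
// above; the real implementation may differ.
private void reinitialize() {
    this.executingThread = Thread.currentThread();
    try {
        // Weight#bulkScorer is lazily initialized and may legitimately
        // return null (no matches in this leaf), which is why the caller
        // loops and re-checks isDone() after a thread switch.
        this.bulkScorer = weight.bulkScorer(leafReaderContext);
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    }
}
```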
 

+ 24 - 2
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSlice.java

@@ -7,8 +7,12 @@
 
 package org.elasticsearch.compute.lucene;
 
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.List;
 
 /**
@@ -19,9 +23,10 @@ public record LuceneSlice(
     boolean queryHead,
     ShardContext shardContext,
     List<PartialLeafReaderContext> leaves,
-    Weight weight,
-    List<Object> tags
+    ScoreMode scoreMode,
+    LuceneSliceQueue.QueryAndTags queryAndTags
 ) {
+
     int numLeaves() {
         return leaves.size();
     }
@@ -29,4 +34,21 @@ public record LuceneSlice(
     PartialLeafReaderContext getLeaf(int index) {
         return leaves.get(index);
     }
+
+    Query query() {
+        return queryAndTags.query();
+    }
+
+    List<Object> tags() {
+        return queryAndTags.tags();
+    }
+
+    Weight createWeight() {
+        var searcher = shardContext.searcher();
+        try {
+            return searcher.createWeight(queryAndTags.query(), scoreMode, 1);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
 }
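
Since the slice now carries the Query and ScoreMode rather than a
prebuilt Weight, each consumer materializes its own Weight, as in the
TimeSeriesSourceOperator change below:

```java
// Each Driver creates its own Weight from the slice; no Weight instance
// is ever visible to more than one thread.
Weight weight = luceneSlice.createWeight();
processedQueries.add(weight.getQuery());
```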

+ 2 - 13
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSliceQueue.java

@@ -12,7 +12,6 @@ import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -209,12 +208,12 @@ public final class LuceneSliceQueue {
                 PartitioningStrategy partitioning = PartitioningStrategy.pick(dataPartitioning, autoStrategy, ctx, query);
                 partitioningStrategies.put(ctx.shardIdentifier(), partitioning);
                 List<List<PartialLeafReaderContext>> groups = partitioning.groups(ctx.searcher(), taskConcurrency);
-                Weight weight = weight(ctx, query, scoreMode);
+                var rewrittenQueryAndTag = new QueryAndTags(query, queryAndExtra.tags);
                 boolean queryHead = true;
                 for (List<PartialLeafReaderContext> group : groups) {
                     if (group.isEmpty() == false) {
                         final int slicePosition = nextSliceId++;
-                        slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, weight, queryAndExtra.tags));
+                        slices.add(new LuceneSlice(slicePosition, queryHead, ctx, group, scoreMode, rewrittenQueryAndTag));
                         queryHead = false;
                     }
                 }
@@ -316,16 +315,6 @@ public final class LuceneSliceQueue {
         }
     }
 
-    static Weight weight(ShardContext ctx, Query query, ScoreMode scoreMode) {
-        var searcher = ctx.searcher();
-        try {
-            Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
-            return searcher.createWeight(actualQuery, scoreMode, 1);
-        } catch (IOException e) {
-            throw new UncheckedIOException(e);
-        }
-    }
-
     static final class AdaptivePartitioner {
         final int desiredDocsPerSlice;
         final int maxDocsPerSlice;
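
Note that the deleted `weight()` helper was where non-scoring queries
were wrapped in `ConstantScoreQuery`, and `LuceneSlice#createWeight`
above does not do this. The wrapping has presumably moved earlier in the
query-rewrite path, before `QueryAndTags` is built (not shown in this
diff), which is why the profile assertions below now expect the
`ConstantScore(...)` form:

```java
// The wrapping the deleted helper performed; assumed to now happen before
// the rewritten query is stored in QueryAndTags.
Query actualQuery = scoreMode.needsScores() ? query : new ConstantScoreQuery(query);
```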

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSourceOperator.java

@@ -154,7 +154,7 @@ public final class TimeSeriesSourceOperator extends LuceneOperator {
                     return a.timeSeriesHash.compareTo(b.timeSeriesHash) < 0;
                 }
             };
-            Weight weight = luceneSlice.weight();
+            Weight weight = luceneSlice.createWeight();
             processedQueries.add(weight.getQuery());
             int maxSegmentOrd = 0;
             for (var leafReaderContext : luceneSlice.leaves()) {

+ 19 - 16
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSliceQueueTests.java

@@ -24,6 +24,8 @@ import org.apache.lucene.index.StoredFields;
 import org.apache.lucene.index.TermVectors;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.test.ESTestCase;
@@ -50,27 +52,28 @@ public class LuceneSliceQueueTests extends ESTestCase {
         LeafReaderContext leaf2 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf3 = new MockLeafReader(1000).getContext();
         LeafReaderContext leaf4 = new MockLeafReader(1000).getContext();
-        List<Object> query1 = List.of("1");
-        List<Object> query2 = List.of("q2");
+        LuceneSliceQueue.QueryAndTags t1 = new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of("q1"));
+        LuceneSliceQueue.QueryAndTags t2 = new LuceneSliceQueue.QueryAndTags(new MatchAllDocsQuery(), List.of("q2"));
+        var scoreMode = ScoreMode.COMPLETE_NO_SCORES;
         List<LuceneSlice> sliceList = List.of(
             // query1: new segment
-            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), null, query1),
-            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query1),
-            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query1),
+            new LuceneSlice(0, true, null, List.of(new PartialLeafReaderContext(leaf1, 0, 10)), scoreMode, t1),
+            new LuceneSlice(1, false, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), scoreMode, t1),
+            new LuceneSlice(2, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query1),
-            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query1),
-            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query1),
+            new LuceneSlice(3, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), scoreMode, t1),
+            new LuceneSlice(4, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), scoreMode, t1),
+            new LuceneSlice(5, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), scoreMode, t1),
             // query1: new segment
-            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), null, query1),
-            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), null, query1),
+            new LuceneSlice(6, false, null, List.of(new PartialLeafReaderContext(leaf4, 0, 10)), scoreMode, t1),
+            new LuceneSlice(7, false, null, List.of(new PartialLeafReaderContext(leaf4, 10, 20)), scoreMode, t1),
             // query2: new segment
-            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), null, query2),
-            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), null, query2),
+            new LuceneSlice(8, true, null, List.of(new PartialLeafReaderContext(leaf2, 0, 10)), scoreMode, t2),
+            new LuceneSlice(9, false, null, List.of(new PartialLeafReaderContext(leaf2, 10, 20)), scoreMode, t2),
            // query2: new segment
-            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), null, query2),
-            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), null, query2),
-            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), null, query2)
+            new LuceneSlice(10, false, null, List.of(new PartialLeafReaderContext(leaf3, 0, 20)), scoreMode, t2),
+            new LuceneSlice(11, false, null, List.of(new PartialLeafReaderContext(leaf3, 10, 20)), scoreMode, t2),
+            new LuceneSlice(12, false, null, List.of(new PartialLeafReaderContext(leaf3, 20, 30)), scoreMode, t2)
         );
         // single driver
         {
@@ -139,7 +142,7 @@ public class LuceneSliceQueueTests extends ESTestCase {
                         false,
                         mock(ShardContext.class),
                         List.of(new PartialLeafReaderContext(leafContext, minDoc, maxDoc)),
-                        null,
+                        ScoreMode.COMPLETE_NO_SCORES,
                         null
                     );
                     sliceList.add(slice);

+ 4 - 6
x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushQueriesIT.java

@@ -367,12 +367,10 @@ public class PushQueriesIT extends ESRestTestCase {
             matchesList().item(matchesMap().entry("name", "test").entry("type", anyOf(equalTo("text"), equalTo("keyword")))),
             equalTo(found ? List.of(List.of(value)) : List.of())
         );
-        Matcher<String> luceneQueryMatcher = anyOf(
-            () -> Iterators.map(
-                luceneQueryOptions.iterator(),
-                (String s) -> equalTo(s.replaceAll("%value", value).replaceAll("%different_value", differentValue))
-            )
-        );
+        Matcher<String> luceneQueryMatcher = anyOf(() -> Iterators.map(luceneQueryOptions.iterator(), (String s) -> {
+            String q = s.replaceAll("%value", value).replaceAll("%different_value", differentValue);
+            return equalTo("ConstantScore(" + q + ")");
+        }));
 
         @SuppressWarnings("unchecked")
         List<Map<String, Object>> profiles = (List<Map<String, Object>>) ((Map<String, Object>) result.get("profile")).get("drivers");

+ 2 - 2
x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/RestEsqlIT.java

@@ -912,7 +912,7 @@ public class RestEsqlIT extends RestEsqlTestCase {
                 .entry("pages_emitted", greaterThan(0))
                 .entry("rows_emitted", greaterThan(0))
                 .entry("process_nanos", greaterThan(0))
-                .entry("processed_queries", List.of("*:*"))
+                .entry("processed_queries", List.of("ConstantScore(*:*)"))
                 .entry("partitioning_strategies", matchesMap().entry("rest-esql-test:0", "SHARD"));
             case "ValuesSourceReaderOperator" -> basicProfile().entry("pages_received", greaterThan(0))
                 .entry("pages_emitted", greaterThan(0))
@@ -950,7 +950,7 @@ public class RestEsqlIT extends RestEsqlTestCase {
                 .entry("slice_max", 0)
                 .entry("slice_min", 0)
                 .entry("process_nanos", greaterThan(0))
-                .entry("processed_queries", List.of("*:*"))
+                .entry("processed_queries", List.of("ConstantScore(*:*)"))
                 .entry("slice_index", 0);
             default -> throw new AssertionError("unexpected status: " + o);
         };

+ 1 - 1
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java

@@ -111,7 +111,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
                         assertThat(description, equalTo("data"));
                         LuceneSourceOperator.Status oStatus = (LuceneSourceOperator.Status) o.status();
                         assertThat(oStatus.processedSlices(), lessThanOrEqualTo(oStatus.totalSlices()));
-                        assertThat(oStatus.processedQueries(), equalTo(Set.of("*:*")));
+                        assertThat(oStatus.processedQueries(), equalTo(Set.of("ConstantScore(*:*)")));
                         assertThat(oStatus.processedShards(), equalTo(Set.of("test:0")));
                         assertThat(oStatus.sliceIndex(), lessThanOrEqualTo(oStatus.totalSlices()));
                         assertThat(oStatus.sliceMin(), greaterThanOrEqualTo(0));