Browse Source

Refactor LuceneQueryEvaluator to use Blocks instead of Vectors (#133246)

Kathleen DeRusso 1 month ago
parent
commit
2b65b0ff04

+ 30 - 22
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java

@@ -24,7 +24,6 @@ import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
@@ -44,12 +43,12 @@ import java.util.function.Consumer;
  * It's much faster to push queries to the {@link LuceneSourceOperator} or the like, but sometimes this isn't possible. So
  * this class is here to save the day.
  */
-public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements Releasable {
+public abstract class LuceneQueryEvaluator<T extends Block.Builder> implements Releasable {
 
     public record ShardConfig(Query query, IndexSearcher searcher) {}
 
     private final BlockFactory blockFactory;
-    private final ShardConfig[] shards;
+    protected final ShardConfig[] shards;
 
     private final List<ShardState> perShardState;
 
@@ -67,9 +66,9 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
         DocVector docs = (DocVector) block.asVector();
         try {
             if (docs.singleSegmentNonDecreasing()) {
-                return evalSingleSegmentNonDecreasing(docs).asBlock();
+                return evalSingleSegmentNonDecreasing(docs);
             } else {
-                return evalSlow(docs).asBlock();
+                return evalSlow(docs);
             }
         } catch (IOException e) {
             throw new UncheckedIOException(e);
@@ -106,15 +105,15 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
      *     common.
      * </p>
      */
-    private Vector evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
+    private Block evalSingleSegmentNonDecreasing(DocVector docs) throws IOException {
         ShardState shardState = shardState(docs.shards().getInt(0));
         SegmentState segmentState = shardState.segmentState(docs.segments().getInt(0));
         int min = docs.docs().getInt(0);
         int max = docs.docs().getInt(docs.getPositionCount() - 1);
         int length = max - min + 1;
-        try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
+        try (T scoreBuilder = createBlockBuilder(blockFactory, docs.getPositionCount())) {
             if (length == docs.getPositionCount() && length > 1) {
-                return segmentState.scoreDense(scoreBuilder, min, max);
+                return segmentState.scoreDense(scoreBuilder, min, max, docs.getPositionCount());
             }
             return segmentState.scoreSparse(scoreBuilder, docs.docs());
         }
@@ -134,13 +133,13 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
      *     the order that the {@link DocVector} came in.
      * </p>
      */
-    private Vector evalSlow(DocVector docs) throws IOException {
+    private Block evalSlow(DocVector docs) throws IOException {
         int[] map = docs.shardSegmentDocMapForwards();
         // Clear any state flags from the previous run
         int prevShard = -1;
         int prevSegment = -1;
         SegmentState segmentState = null;
-        try (T scoreBuilder = createVectorBuilder(blockFactory, docs.getPositionCount())) {
+        try (T scoreBuilder = createBlockBuilder(blockFactory, docs.getPositionCount())) {
             for (int i = 0; i < docs.getPositionCount(); i++) {
                 int shard = docs.shards().getInt(docs.shards().getInt(map[i]));
                 int segment = docs.segments().getInt(map[i]);
@@ -156,7 +155,7 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
                     segmentState.scoreSingleDocWithScorer(scoreBuilder, docs.docs().getInt(map[i]));
                 }
             }
-            try (Vector outOfOrder = scoreBuilder.build()) {
+            try (Block outOfOrder = scoreBuilder.build()) {
                 return outOfOrder.filter(docs.shardSegmentDocMapBackwards());
             }
         }
@@ -247,9 +246,9 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
          * Score a range using the {@link BulkScorer}. This should be faster
          * than using {@link #scoreSparse} for dense doc ids.
          */
-        Vector scoreDense(T scoreBuilder, int min, int max) throws IOException {
+        Block scoreDense(T scoreBuilder, int min, int max, int positionCount) throws IOException {
             if (noMatch) {
-                return createNoMatchVector(blockFactory, max - min + 1);
+                return createNoMatchBlock(blockFactory, max - min + 1);
             }
             if (bulkScorer == null ||  // The bulkScorer wasn't initialized
                 Thread.currentThread() != bulkScorerThread // The bulkScorer was initialized on a different thread
@@ -258,7 +257,7 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
                 bulkScorer = weight.bulkScorer(ctx);
                 if (bulkScorer == null) {
                     noMatch = true;
-                    return createNoMatchVector(blockFactory, max - min + 1);
+                    return createNoMatchBlock(blockFactory, positionCount);
                 }
             }
             try (
@@ -266,11 +265,14 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
                     min,
                     max,
                     scoreBuilder,
+                    ctx,
                     LuceneQueryEvaluator.this::appendNoMatch,
-                    LuceneQueryEvaluator.this::appendMatch
+                    LuceneQueryEvaluator.this::appendMatch,
+                    weight.getQuery()
                 )
             ) {
                 bulkScorer.score(collector, ctx.reader().getLiveDocs(), min, max + 1);
+                collector.finish();
                 return collector.build();
             }
         }
@@ -279,10 +281,10 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
          * Score a vector of doc ids using {@link Scorer}. If you have a dense range of
          * doc ids it'd be faster to use {@link #scoreDense}.
          */
-        Vector scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
+        Block scoreSparse(T scoreBuilder, IntVector docs) throws IOException {
             initScorer(docs.getInt(0));
             if (noMatch) {
-                return createNoMatchVector(blockFactory, docs.getPositionCount());
+                return createNoMatchBlock(blockFactory, docs.getPositionCount());
             }
             for (int i = 0; i < docs.getPositionCount(); i++) {
                 scoreSingleDocWithScorer(scoreBuilder, docs.getInt(i));
@@ -326,11 +328,13 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
      * doc ids are sent to {@link LeafCollector#collect(int)} in ascending order
      * which isn't documented, but @jpountz swears is true.
      */
-    static class DenseCollector<U extends Vector.Builder> implements LeafCollector, Releasable {
+    static class DenseCollector<U extends Block.Builder> implements LeafCollector, Releasable {
         private final U scoreBuilder;
         private final int max;
+        private final LeafReaderContext leafReaderContext;
         private final Consumer<U> appendNoMatch;
         private final CheckedBiConsumer<U, Scorable, IOException> appendMatch;
+        private final Query query;
 
         private Scorable scorer;
         int next;
@@ -339,14 +343,18 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
             int min,
             int max,
             U scoreBuilder,
+            LeafReaderContext leafReaderContext,
             Consumer<U> appendNoMatch,
-            CheckedBiConsumer<U, Scorable, IOException> appendMatch
+            CheckedBiConsumer<U, Scorable, IOException> appendMatch,
+            Query query
         ) {
             this.scoreBuilder = scoreBuilder;
             this.max = max;
             next = min;
+            this.leafReaderContext = leafReaderContext;
             this.appendNoMatch = appendNoMatch;
             this.appendMatch = appendMatch;
+            this.query = query;
         }
 
         @Override
@@ -362,7 +370,7 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
             appendMatch.accept(scoreBuilder, scorer);
         }
 
-        public Vector build() {
+        public Block build() {
             return scoreBuilder.build();
         }
 
@@ -387,12 +395,12 @@ public abstract class LuceneQueryEvaluator<T extends Vector.Builder> implements
     /**
      * Creates a vector where all positions correspond to elements that don't match the query
      */
-    protected abstract Vector createNoMatchVector(BlockFactory blockFactory, int size);
+    protected abstract Block createNoMatchBlock(BlockFactory blockFactory, int size);
 
     /**
      * Creates the corresponding vector builder to store the results of evaluating the query
      */
-    protected abstract T createVectorBuilder(BlockFactory blockFactory, int size);
+    protected abstract T createBlockBuilder(BlockFactory blockFactory, int size);
 
     /**
      * Appends a matching result to a builder created by @link createVectorBuilder}

+ 8 - 10
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluator.java

@@ -12,9 +12,9 @@ import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.BooleanVector;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.EvalOperator;
 
@@ -26,9 +26,7 @@ import java.io.IOException;
  * a {@link BooleanVector}.
  * @see LuceneQueryScoreEvaluator
  */
-public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanVector.Builder>
-    implements
-        EvalOperator.ExpressionEvaluator {
+public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<BooleanBlock.Builder> implements EvalOperator.ExpressionEvaluator {
 
     LuceneQueryExpressionEvaluator(BlockFactory blockFactory, ShardConfig[] shards) {
         super(blockFactory, shards);
@@ -45,22 +43,22 @@ public class LuceneQueryExpressionEvaluator extends LuceneQueryEvaluator<Boolean
     }
 
     @Override
-    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
-        return blockFactory.newConstantBooleanVector(false, size);
+    protected Block createNoMatchBlock(BlockFactory blockFactory, int size) {
+        return blockFactory.newConstantBooleanBlockWith(false, size);
     }
 
     @Override
-    protected BooleanVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
-        return blockFactory.newBooleanVectorFixedBuilder(size);
+    protected BooleanBlock.Builder createBlockBuilder(BlockFactory blockFactory, int size) {
+        return blockFactory.newBooleanBlockBuilder(size);
     }
 
     @Override
-    protected void appendNoMatch(BooleanVector.Builder builder) {
+    protected void appendNoMatch(BooleanBlock.Builder builder) {
         builder.appendBoolean(false);
     }
 
     @Override
-    protected void appendMatch(BooleanVector.Builder builder, Scorable scorer) throws IOException {
+    protected void appendMatch(BooleanBlock.Builder builder, Scorable scorer) throws IOException {
         builder.appendBoolean(true);
     }
 

+ 7 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryScoreEvaluator.java

@@ -14,7 +14,6 @@ import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.ScoreOperator;
 
@@ -27,7 +26,7 @@ import java.io.IOException;
  * Elements that don't match will have a score of {@link #NO_MATCH_SCORE}.
  * @see LuceneQueryScoreEvaluator
  */
-public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleVector.Builder> implements ScoreOperator.ExpressionScorer {
+public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleBlock.Builder> implements ScoreOperator.ExpressionScorer {
 
     public static final double NO_MATCH_SCORE = 0.0;
 
@@ -46,22 +45,22 @@ public class LuceneQueryScoreEvaluator extends LuceneQueryEvaluator<DoubleVector
     }
 
     @Override
-    protected Vector createNoMatchVector(BlockFactory blockFactory, int size) {
-        return blockFactory.newConstantDoubleVector(NO_MATCH_SCORE, size);
+    protected DoubleBlock createNoMatchBlock(BlockFactory blockFactory, int size) {
+        return blockFactory.newConstantDoubleBlockWith(NO_MATCH_SCORE, size);
     }
 
     @Override
-    protected DoubleVector.Builder createVectorBuilder(BlockFactory blockFactory, int size) {
-        return blockFactory.newDoubleVectorFixedBuilder(size);
+    protected DoubleBlock.Builder createBlockBuilder(BlockFactory blockFactory, int size) {
+        return blockFactory.newDoubleBlockBuilder(size);
     }
 
     @Override
-    protected void appendNoMatch(DoubleVector.Builder builder) {
+    protected void appendNoMatch(DoubleBlock.Builder builder) {
         builder.appendDouble(NO_MATCH_SCORE);
     }
 
     @Override
-    protected void appendMatch(DoubleVector.Builder builder, Scorable scorer) throws IOException {
+    protected void appendMatch(DoubleBlock.Builder builder, Scorable scorer) throws IOException {
         builder.appendDouble(scorer.score());
     }
 

+ 4 - 5
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluatorTests.java

@@ -25,14 +25,13 @@ import org.apache.lucene.tests.store.BaseDirectoryWrapper;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.compute.OperatorTests;
+import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.data.Vector;
 import org.elasticsearch.compute.lucene.read.ValuesSourceReaderOperator;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -59,7 +58,7 @@ import static org.hamcrest.Matchers.equalTo;
 /**
  * Base class for testing Lucene query evaluators.
  */
-public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vector.Builder> extends ComputeTestCase {
+public abstract class LuceneQueryEvaluatorTests<T extends Block, U extends Block.Builder> extends ComputeTestCase {
 
     private static final String FIELD = "g";
 
@@ -168,9 +167,9 @@ public abstract class LuceneQueryEvaluatorTests<T extends Vector, U extends Vect
         int matchCount = 0;
         for (Page page : results) {
             int initialBlockIndex = termsBlockIndex(page);
-            BytesRefVector terms = page.<BytesRefBlock>getBlock(initialBlockIndex).asVector();
+            BytesRefBlock terms = page.<BytesRefBlock>getBlock(initialBlockIndex);
             @SuppressWarnings("unchecked")
-            T resultVector = (T) page.getBlock(resultsBlockIndex(page)).asVector();
+            T resultVector = (T) page.getBlock(resultsBlockIndex(page));
             for (int i = 0; i < page.getPositionCount(); i++) {
                 BytesRef termAtPosition = terms.getBytesRef(i, new BytesRef());
                 boolean isMatch = matching.contains(termAtPosition.utf8ToString());

+ 9 - 7
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryExpressionEvaluatorTests.java

@@ -9,7 +9,7 @@ package org.elasticsearch.compute.lucene;
 
 import org.apache.lucene.search.Scorable;
 import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.BooleanVector;
+import org.elasticsearch.compute.data.BooleanBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.lucene.LuceneQueryEvaluator.DenseCollector;
 import org.elasticsearch.compute.operator.EvalOperator;
@@ -17,18 +17,20 @@ import org.elasticsearch.compute.operator.Operator;
 
 import static org.hamcrest.Matchers.equalTo;
 
-public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTests<BooleanVector, BooleanVector.Builder> {
+public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTests<BooleanBlock, BooleanBlock.Builder> {
 
     private final boolean useScoring = randomBoolean();
 
     @Override
-    protected DenseCollector<BooleanVector.Builder> createDenseCollector(int min, int max) {
+    protected DenseCollector<BooleanBlock.Builder> createDenseCollector(int min, int max) {
         return new LuceneQueryEvaluator.DenseCollector<>(
             min,
             max,
-            blockFactory().newBooleanVectorFixedBuilder(max - min + 1),
+            blockFactory().newBooleanBlockBuilder(max - min + 1),
+            null,
             b -> b.appendBoolean(false),
-            (b, s) -> b.appendBoolean(true)
+            (b, s) -> b.appendBoolean(true),
+            null
         );
     }
 
@@ -54,12 +56,12 @@ public class LuceneQueryExpressionEvaluatorTests extends LuceneQueryEvaluatorTes
     }
 
     @Override
-    protected void assertCollectedResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
+    protected void assertCollectedResultMatch(BooleanBlock resultVector, int position, boolean isMatch) {
         assertThat(resultVector.getBoolean(position), equalTo(isMatch));
     }
 
     @Override
-    protected void assertTermResultMatch(BooleanVector resultVector, int position, boolean isMatch) {
+    protected void assertTermResultMatch(BooleanBlock resultVector, int position, boolean isMatch) {
         assertThat(resultVector.getBoolean(position), equalTo(isMatch));
     }
 }

+ 9 - 7
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneQueryScoreEvaluatorTests.java

@@ -9,7 +9,7 @@ package org.elasticsearch.compute.lucene;
 
 import org.apache.lucene.search.Scorable;
 import org.elasticsearch.compute.data.BlockFactory;
-import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.compute.operator.ScoreOperator;
@@ -20,19 +20,21 @@ import static org.elasticsearch.compute.lucene.LuceneQueryScoreEvaluator.NO_MATC
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 
-public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<DoubleVector, DoubleVector.Builder> {
+public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<DoubleBlock, DoubleBlock.Builder> {
 
     private static final float TEST_SCORE = 1.5f;
     private static final Double DEFAULT_SCORE = 1.0;
 
     @Override
-    protected LuceneQueryEvaluator.DenseCollector<DoubleVector.Builder> createDenseCollector(int min, int max) {
+    protected LuceneQueryEvaluator.DenseCollector<DoubleBlock.Builder> createDenseCollector(int min, int max) {
         return new LuceneQueryEvaluator.DenseCollector<>(
             min,
             max,
-            blockFactory().newDoubleVectorFixedBuilder(max - min + 1),
+            blockFactory().newDoubleBlockBuilder(max - min + 1),
+            null,
             b -> b.appendDouble(NO_MATCH_SCORE),
-            (b, s) -> b.appendDouble(s.score())
+            (b, s) -> b.appendDouble(s.score()),
+            null
         );
     }
 
@@ -63,7 +65,7 @@ public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<Do
     }
 
     @Override
-    protected void assertCollectedResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
+    protected void assertCollectedResultMatch(DoubleBlock resultVector, int position, boolean isMatch) {
         if (isMatch) {
             assertThat(resultVector.getDouble(position), equalTo((double) TEST_SCORE));
         } else {
@@ -73,7 +75,7 @@ public class LuceneQueryScoreEvaluatorTests extends LuceneQueryEvaluatorTests<Do
     }
 
     @Override
-    protected void assertTermResultMatch(DoubleVector resultVector, int position, boolean isMatch) {
+    protected void assertTermResultMatch(DoubleBlock resultVector, int position, boolean isMatch) {
         if (isMatch) {
             assertThat(resultVector.getDouble(position), greaterThan(DEFAULT_SCORE));
         } else {