Browse Source

Bulk loading enrich fields in ESQL (#106796)

Today, the enrich lookup processes input terms one by one: querying one 
term, then loading enrich fields for matching documents of that term
immediately. However, this approach can add significant overhead, such
as the driver run loop, creating/releasing many pages, and especially
excessive number of I/O seeks during loading _source, fields.

This PR accumulates matching documents up to 256 before loading enrich 
fields. The 256 limit is chosen to avoid a significant sorting cost and
long waits for cancellation.
Nhat Nguyen 1 year ago
parent
commit
96b513a7de

+ 5 - 0
docs/changelog/106796.yaml

@@ -0,0 +1,5 @@
+pr: 106796
+summary: Bulk loading enrich fields in ESQL
+area: ES|QL
+type: enhancement
+issues: []

+ 1 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java

@@ -270,6 +270,7 @@ public class EnrichLookupService {
             };
             var queryOperator = new EnrichQuerySourceOperator(
                 driverContext.blockFactory(),
+                EnrichQuerySourceOperator.DEFAULT_MAX_PAGE_SIZE,
                 queryList,
                 searchExecutionContext.getIndexReader()
             );

+ 86 - 49
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java

@@ -15,7 +15,6 @@ import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Weight;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.IntBlock;
@@ -36,14 +35,17 @@ final class EnrichQuerySourceOperator extends SourceOperator {
 
     private final BlockFactory blockFactory;
     private final QueryList queryList;
-    private int queryPosition;
-    private Weight weight = null;
+    private int queryPosition = -1;
     private final IndexReader indexReader;
-    private int leafIndex = 0;
     private final IndexSearcher searcher;
+    private final int maxPageSize;
 
-    EnrichQuerySourceOperator(BlockFactory blockFactory, QueryList queryList, IndexReader indexReader) {
+    // using smaller pages enables quick cancellation and reduces sorting costs
+    static final int DEFAULT_MAX_PAGE_SIZE = 256;
+
+    EnrichQuerySourceOperator(BlockFactory blockFactory, int maxPageSize, QueryList queryList, IndexReader indexReader) {
         this.blockFactory = blockFactory;
+        this.maxPageSize = maxPageSize;
         this.queryList = queryList;
         this.indexReader = indexReader;
         this.searcher = new IndexSearcher(indexReader);
@@ -59,62 +61,96 @@ final class EnrichQuerySourceOperator extends SourceOperator {
 
     @Override
     public Page getOutput() {
-        if (leafIndex == indexReader.leaves().size()) {
-            queryPosition++;
-            leafIndex = 0;
-            weight = null;
-        }
-        if (isFinished()) {
-            return null;
-        }
-        if (weight == null) {
-            Query query = queryList.getQuery(queryPosition);
-            if (query != null) {
-                try {
-                    query = searcher.rewrite(new ConstantScoreQuery(query));
-                    weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
-                } catch (IOException e) {
-                    throw new UncheckedIOException(e);
-                }
+        int estimatedSize = Math.min(maxPageSize, queryList.getPositionCount() - queryPosition);
+        IntVector.Builder positionsBuilder = null;
+        IntVector.Builder docsBuilder = null;
+        IntVector.Builder segmentsBuilder = null;
+        try {
+            positionsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
+            docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
+            if (indexReader.leaves().size() > 1) {
+                segmentsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
             }
+            int totalMatches = 0;
+            do {
+                Query query = nextQuery();
+                if (query == null) {
+                    assert isFinished();
+                    break;
+                }
+                query = searcher.rewrite(new ConstantScoreQuery(query));
+                final var weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
+                if (weight == null) {
+                    continue;
+                }
+                for (LeafReaderContext leaf : indexReader.leaves()) {
+                    var scorer = weight.bulkScorer(leaf);
+                    if (scorer == null) {
+                        continue;
+                    }
+                    final DocCollector collector = new DocCollector(docsBuilder);
+                    scorer.score(collector, leaf.reader().getLiveDocs());
+                    int matches = collector.matches;
+
+                    if (segmentsBuilder != null) {
+                        for (int i = 0; i < matches; i++) {
+                            segmentsBuilder.appendInt(leaf.ord);
+                        }
+                    }
+                    for (int i = 0; i < matches; i++) {
+                        positionsBuilder.appendInt(queryPosition);
+                    }
+                    totalMatches += matches;
+                }
+            } while (totalMatches < maxPageSize);
+
+            return buildPage(totalMatches, positionsBuilder, segmentsBuilder, docsBuilder);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        } finally {
+            Releasables.close(docsBuilder, segmentsBuilder, positionsBuilder);
         }
+    }
+
+    Page buildPage(int positions, IntVector.Builder positionsBuilder, IntVector.Builder segmentsBuilder, IntVector.Builder docsBuilder) {
+        IntVector positionsVector = null;
+        IntVector shardsVector = null;
+        IntVector segmentsVector = null;
+        IntVector docsVector = null;
+        Page page = null;
         try {
-            return queryOneLeaf(weight, leafIndex++);
-        } catch (IOException ex) {
-            throw new UncheckedIOException(ex);
+            positionsVector = positionsBuilder.build();
+            shardsVector = blockFactory.newConstantIntVector(0, positions);
+            if (segmentsBuilder == null) {
+                segmentsVector = blockFactory.newConstantIntVector(0, positions);
+            } else {
+                segmentsVector = segmentsBuilder.build();
+            }
+            docsVector = docsBuilder.build();
+            page = new Page(new DocVector(shardsVector, segmentsVector, docsVector, null).asBlock(), positionsVector.asBlock());
+        } finally {
+            if (page == null) {
+                Releasables.close(positionsBuilder, segmentsVector, docsBuilder, positionsVector, shardsVector, docsVector);
+            }
         }
+        return page;
     }
 
-    private Page queryOneLeaf(Weight weight, int leafIndex) throws IOException {
-        if (weight == null) {
-            return null;
-        }
-        LeafReaderContext leafReaderContext = indexReader.leaves().get(leafIndex);
-        var scorer = weight.bulkScorer(leafReaderContext);
-        if (scorer == null) {
-            return null;
-        }
-        IntVector docs = null, segments = null, shards = null, positions = null;
-        boolean success = false;
-        try (IntVector.Builder docsBuilder = blockFactory.newIntVectorBuilder(1)) {
-            scorer.score(new DocCollector(docsBuilder), leafReaderContext.reader().getLiveDocs());
-            docs = docsBuilder.build();
-            final int positionCount = docs.getPositionCount();
-            segments = blockFactory.newConstantIntVector(leafIndex, positionCount);
-            shards = blockFactory.newConstantIntVector(0, positionCount);
-            positions = blockFactory.newConstantIntVector(queryPosition, positionCount);
-            Page page = new Page(new DocVector(shards, segments, docs, true).asBlock(), positions.asBlock());
-            success = true;
-            return page;
-        } finally {
-            if (success == false) {
-                Releasables.close(docs, shards, segments, positions);
+    private Query nextQuery() {
+        ++queryPosition;
+        while (isFinished() == false) {
+            Query query = queryList.getQuery(queryPosition);
+            if (query != null) {
+                return query;
             }
+            ++queryPosition;
         }
+        return null;
     }
 
     private static class DocCollector implements LeafCollector {
         final IntVector.Builder docIds;
+        int matches = 0;
 
         DocCollector(IntVector.Builder docIds) {
             this.docIds = docIds;
@@ -127,6 +163,7 @@ final class EnrichQuerySourceOperator extends SourceOperator {
 
         @Override
         public void collect(int doc) {
+            ++matches;
             docIds.appendInt(doc);
         }
     }

+ 24 - 55
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java

@@ -48,6 +48,7 @@ import java.util.Set;
 
 import static org.elasticsearch.xpack.ql.type.DataTypes.KEYWORD;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.mockito.Mockito.mock;
 
 public class EnrichQuerySourceOperatorTests extends ESTestCase {
@@ -120,60 +121,26 @@ public class EnrichQuerySourceOperatorTests extends ESTestCase {
         // 3 -> [] -> []
         // 4 -> [a1] -> [3]
         // 5 -> [] -> []
-        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader);
-        {
-            Page p0 = queryOperator.getOutput();
-            assertNotNull(p0);
-            assertThat(p0.getPositionCount(), equalTo(2));
-            IntVector docs = getDocVector(p0, 0);
-            assertThat(docs.getInt(0), equalTo(1));
-            assertThat(docs.getInt(1), equalTo(4));
-            Block positions = p0.getBlock(1);
-            assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0));
-            assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0));
-            p0.releaseBlocks();
-        }
-        {
-            Page p1 = queryOperator.getOutput();
-            assertNotNull(p1);
-            assertThat(p1.getPositionCount(), equalTo(3));
-            IntVector docs = getDocVector(p1, 0);
-            assertThat(docs.getInt(0), equalTo(0));
-            assertThat(docs.getInt(1), equalTo(1));
-            assertThat(docs.getInt(2), equalTo(2));
-            Block positions = p1.getBlock(1);
-            assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(1));
-            assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(1));
-            assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1));
-            p1.releaseBlocks();
-        }
-        {
-            Page p2 = queryOperator.getOutput();
-            assertNull(p2);
-        }
-        {
-            Page p3 = queryOperator.getOutput();
-            assertNull(p3);
-        }
-        {
-            Page p4 = queryOperator.getOutput();
-            assertNotNull(p4);
-            assertThat(p4.getPositionCount(), equalTo(1));
-            IntVector docs = getDocVector(p4, 0);
-            assertThat(docs.getInt(0), equalTo(3));
-            Block positions = p4.getBlock(1);
-            assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(4));
-            p4.releaseBlocks();
-        }
-        {
-            Page p5 = queryOperator.getOutput();
-            assertNull(p5);
-        }
-        {
-            assertFalse(queryOperator.isFinished());
-            Page p6 = queryOperator.getOutput();
-            assertNull(p6);
-        }
+        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, 128, queryList, reader);
+        Page p0 = queryOperator.getOutput();
+        assertNotNull(p0);
+        assertThat(p0.getPositionCount(), equalTo(6));
+        IntVector docs = getDocVector(p0, 0);
+        assertThat(docs.getInt(0), equalTo(1));
+        assertThat(docs.getInt(1), equalTo(4));
+        assertThat(docs.getInt(2), equalTo(0));
+        assertThat(docs.getInt(3), equalTo(1));
+        assertThat(docs.getInt(4), equalTo(2));
+        assertThat(docs.getInt(5), equalTo(3));
+
+        Block positions = p0.getBlock(1);
+        assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0));
+        assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0));
+        assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 3), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 4), equalTo(1));
+        assertThat(BlockUtils.toJavaObject(positions, 5), equalTo(4));
+        p0.releaseBlocks();
         assertTrue(queryOperator.isFinished());
         IOUtils.close(reader, dir, inputTerms);
     }
@@ -220,13 +187,15 @@ public class EnrichQuerySourceOperatorTests extends ESTestCase {
         }
         MappedFieldType uidField = new KeywordFieldMapper.KeywordFieldType("uid");
         var queryList = QueryList.termQueryList(uidField, mock(SearchExecutionContext.class), inputTerms, KEYWORD);
-        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader);
+        int maxPageSize = between(1, 256);
+        EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, maxPageSize, queryList, reader);
         Map<Integer, Set<Integer>> actualPositions = new HashMap<>();
         while (queryOperator.isFinished() == false) {
             Page page = queryOperator.getOutput();
             if (page != null) {
                 IntVector docs = getDocVector(page, 0);
                 IntBlock positions = page.getBlock(1);
+                assertThat(positions.getPositionCount(), lessThanOrEqualTo(maxPageSize));
                 for (int i = 0; i < page.getPositionCount(); i++) {
                     int doc = docs.getInt(i);
                     int position = positions.getInt(i);