ESQL: Delay finding field load infrastructure (#103821)

This optimizes loading fields across many, many indices by resolving the
field loading infrastructure when it's first needed rather than up front.
This speeds things up because, when loading from many shards, you usually
don't need to set up the field loading infrastructure for all of them -
often a couple of shards is enough.
Nik Everett, 1 year ago
parent / commit fac60e5803
14 changed files with 567 additions and 266 deletions
  1. + 0 - 1      benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/MultivalueDedupeBenchmark.java
  2. + 50 - 25    benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java
  3. + 5 - 0      docs/changelog/103821.yaml
  4. + 1 - 0      docs/reference/esql/functions/types/add.asciidoc
  5. + 30 - 43    x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/BlockReaderFactories.java
  6. + 192 - 127  x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java
  7. + 7 - 6      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
  8. + 2 - 2      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
  9. + 1 - 1      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
 10. + 5 - 4      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
 11. + 3 - 2      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java
 12. + 241 - 49   x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java
 13. + 15 - 3     x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java
 14. + 15 - 3     x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
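
Before the per-file diffs, the shape of the change in one place: a field's
BlockLoader used to be resolved for every shard when the operator was built;
after this commit the operator holds a shardIdx -> BlockLoader function and
invokes it only when a page from that shard actually arrives. A minimal,
self-contained sketch of that pattern (Loader and LazyPerShard are
hypothetical names; the real operator caches only the most recently seen
shard rather than keeping a map):

import java.util.HashMap;
import java.util.Map;
import java.util.function.IntFunction;

class LazyPerShard {
    interface Loader {}

    private final IntFunction<Loader> resolve; // cheap to hold, expensive to call
    private final Map<Integer, Loader> resolved = new HashMap<>();

    LazyPerShard(IntFunction<Loader> resolve) {
        this.resolve = resolve;
    }

    Loader forShard(int shardIdx) {
        // Resolution cost is paid only for shards a page actually touches.
        return resolved.computeIfAbsent(shardIdx, resolve::apply);
    }
}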

+ 0 - 1
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/MultivalueDedupeBenchmark.java

@@ -45,7 +45,6 @@ import java.util.concurrent.TimeUnit;
 @State(Scope.Thread)
 @Fork(1)
 public class MultivalueDedupeBenchmark {
-    private static final BigArrays BIG_ARRAYS = BigArrays.NON_RECYCLING_INSTANCE;  // TODO real big arrays?
     private static final BlockFactory blockFactory = BlockFactory.getInstance(
         new NoopCircuitBreaker("noop"),
         BigArrays.NON_RECYCLING_INSTANCE

+ 50 - 25
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java

@@ -31,6 +31,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -96,8 +97,12 @@ public class ValuesSourceReaderBenchmark {
                     for (String name : ValuesSourceReaderBenchmark.class.getField("name").getAnnotationsByType(Param.class)[0].value()) {
                         benchmark.layout = layout;
                         benchmark.name = name;
-                        benchmark.setupPages();
-                        benchmark.benchmark();
+                        try {
+                            benchmark.setupPages();
+                            benchmark.benchmark();
+                        } catch (Exception e) {
+                            throw new AssertionError("error initializing [" + layout + "/" + name + "]", e);
+                        }
                     }
                 }
             } finally {
@@ -111,11 +116,11 @@ public class ValuesSourceReaderBenchmark {
     private static List<ValuesSourceReaderOperator.FieldInfo> fields(String name) {
         return switch (name) {
             case "3_stored_keywords" -> List.of(
-                new ValuesSourceReaderOperator.FieldInfo("keyword_1", List.of(blockLoader("stored_keyword_1"))),
-                new ValuesSourceReaderOperator.FieldInfo("keyword_2", List.of(blockLoader("stored_keyword_2"))),
-                new ValuesSourceReaderOperator.FieldInfo("keyword_3", List.of(blockLoader("stored_keyword_3")))
+                new ValuesSourceReaderOperator.FieldInfo("keyword_1", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_1")),
+                new ValuesSourceReaderOperator.FieldInfo("keyword_2", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_2")),
+                new ValuesSourceReaderOperator.FieldInfo("keyword_3", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_3"))
             );
-            default -> List.of(new ValuesSourceReaderOperator.FieldInfo(name, List.of(blockLoader(name))));
+            default -> List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType(name), shardIdx -> blockLoader(name)));
         };
     }

@@ -125,29 +130,38 @@ public class ValuesSourceReaderBenchmark {
         STORED;
     }

-    private static BlockLoader blockLoader(String name) {
-        Where where = Where.DOC_VALUES;
-        if (name.startsWith("stored_")) {
-            name = name.substring("stored_".length());
-            where = Where.STORED;
-        } else if (name.startsWith("source_")) {
-            name = name.substring("source_".length());
-            where = Where.SOURCE;
-        }
+    private static ElementType elementType(String name) {
+        name = WhereAndBaseName.fromName(name).name;
         switch (name) {
             case "long":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.LONG);
+                return ElementType.LONG;
             case "int":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.INTEGER);
+                return ElementType.INT;
             case "double":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.DOUBLE);
-            case "keyword":
-                name = "keyword_1";
+                return ElementType.DOUBLE;
         }
         if (name.startsWith("keyword")) {
+            return ElementType.BYTES_REF;
+        }
+        throw new UnsupportedOperationException("no element type for [" + name + "]");
+    }
+
+    private static BlockLoader blockLoader(String name) {
+        WhereAndBaseName w = WhereAndBaseName.fromName(name);
+        switch (w.name) {
+            case "long":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.LONG);
+            case "int":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.INTEGER);
+            case "double":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.DOUBLE);
+            case "keyword":
+                w = new WhereAndBaseName(w.where, "keyword_1");
+        }
+        if (w.name.startsWith("keyword")) {
             boolean syntheticSource = false;
             FieldType ft = new FieldType(KeywordFieldMapper.Defaults.FIELD_TYPE);
-            switch (where) {
+            switch (w.where) {
                 case DOC_VALUES:
                     break;
                 case SOURCE:
@@ -161,7 +175,7 @@ public class ValuesSourceReaderBenchmark {
             }
             ft.freeze();
             return new KeywordFieldMapper.KeywordFieldType(
-                name,
+                w.name,
                 ft,
                 Lucene.KEYWORD_ANALYZER,
                 Lucene.KEYWORD_ANALYZER,
@@ -193,10 +207,21 @@ public class ValuesSourceReaderBenchmark {
         throw new IllegalArgumentException("can't read [" + name + "]");
     }

-    private static BlockLoader numericBlockLoader(String name, Where where, NumberFieldMapper.NumberType numberType) {
+    private record WhereAndBaseName(Where where, String name) {
+        static WhereAndBaseName fromName(String name) {
+            if (name.startsWith("stored_")) {
+                return new WhereAndBaseName(Where.STORED, name.substring("stored_".length()));
+            } else if (name.startsWith("source_")) {
+                return new WhereAndBaseName(Where.SOURCE, name.substring("source_".length()));
+            }
+            return new WhereAndBaseName(Where.DOC_VALUES, name);
+        }
+    }
+
+    private static BlockLoader numericBlockLoader(WhereAndBaseName w, NumberFieldMapper.NumberType numberType) {
         boolean stored = false;
         boolean docValues = true;
-        switch (where) {
+        switch (w.where) {
             case DOC_VALUES:
                 break;
             case SOURCE:
@@ -207,7 +232,7 @@ public class ValuesSourceReaderBenchmark {
                 throw new UnsupportedOperationException();
         }
         return new NumberFieldMapper.NumberFieldType(
-            name,
+            w.name,
             numberType,
             true,
             stored,
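
The WhereAndBaseName record above exists because two methods now need the
same prefix parsing: elementType(name) and blockLoader(name). A standalone
copy of that parsing with a small demo (the record and prefixes mirror the
diff; the wrapper class and main() are only for illustration):

public class WhereAndBaseNameDemo {
    enum Where { DOC_VALUES, SOURCE, STORED }

    record WhereAndBaseName(Where where, String name) {
        static WhereAndBaseName fromName(String name) {
            if (name.startsWith("stored_")) {
                return new WhereAndBaseName(Where.STORED, name.substring("stored_".length()));
            } else if (name.startsWith("source_")) {
                return new WhereAndBaseName(Where.SOURCE, name.substring("source_".length()));
            }
            return new WhereAndBaseName(Where.DOC_VALUES, name);
        }
    }

    public static void main(String[] args) {
        // prints WhereAndBaseName[where=STORED, name=keyword_1]
        System.out.println(WhereAndBaseName.fromName("stored_keyword_1"));
        // prints WhereAndBaseName[where=DOC_VALUES, name=long]
        System.out.println(WhereAndBaseName.fromName("long"));
    }
}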

+ 5 - 0
docs/changelog/103821.yaml

@@ -0,0 +1,5 @@
+pr: 103821
+summary: "ESQL: Delay finding field load infrastructure"
+area: ES|QL
+type: enhancement
+issues: []

+ 1 - 0
docs/reference/esql/functions/types/add.asciidoc

@@ -2,6 +2,7 @@
 |===
 lhs | rhs | result
 date_period | date_period | date_period
+date_period | datetime | datetime
 datetime | date_period | datetime
 datetime | time_duration | datetime
 double | double | double

+ 30 - 43
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/BlockReaderFactories.java

@@ -11,11 +11,8 @@ import org.elasticsearch.common.logging.HeaderWarning;
 import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
-import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.lookup.SearchLookup;

-import java.util.ArrayList;
-import java.util.List;
 import java.util.Set;

 /**
@@ -26,56 +23,46 @@ public final class BlockReaderFactories {

     /**
      * Resolves *how* ESQL loads field values.
-     * @param searchContexts a search context per search index we're loading
-     *                       field from
+     * @param ctx a search context for the index we're loading field from
      * @param fieldName the name of the field to load
      * @param asUnsupportedSource should the field be loaded as "unsupported"?
      *                            These will always have {@code null} values
      */
-    public static List<BlockLoader> loaders(List<SearchContext> searchContexts, String fieldName, boolean asUnsupportedSource) {
-        List<BlockLoader> loaders = new ArrayList<>(searchContexts.size());
-
-        for (SearchContext searchContext : searchContexts) {
-            SearchExecutionContext ctx = searchContext.getSearchExecutionContext();
-            if (asUnsupportedSource) {
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
-            }
-            MappedFieldType fieldType = ctx.getFieldType(fieldName);
-            if (fieldType == null) {
-                // the field does not exist in this context
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
+    public static BlockLoader loader(SearchExecutionContext ctx, String fieldName, boolean asUnsupportedSource) {
+        if (asUnsupportedSource) {
+            return BlockLoader.CONSTANT_NULLS;
+        }
+        MappedFieldType fieldType = ctx.getFieldType(fieldName);
+        if (fieldType == null) {
+            // the field does not exist in this context
+            return BlockLoader.CONSTANT_NULLS;
+        }
+        BlockLoader loader = fieldType.blockLoader(new MappedFieldType.BlockLoaderContext() {
+            @Override
+            public String indexName() {
+                return ctx.getFullyQualifiedIndex().getName();
             }
-            BlockLoader loader = fieldType.blockLoader(new MappedFieldType.BlockLoaderContext() {
-                @Override
-                public String indexName() {
-                    return ctx.getFullyQualifiedIndex().getName();
-                }

-                @Override
-                public SearchLookup lookup() {
-                    return ctx.lookup();
-                }
+            @Override
+            public SearchLookup lookup() {
+                return ctx.lookup();
+            }

-                @Override
-                public Set<String> sourcePaths(String name) {
-                    return ctx.sourcePath(name);
-                }
+            @Override
+            public Set<String> sourcePaths(String name) {
+                return ctx.sourcePath(name);
+            }

-                @Override
-                public String parentField(String field) {
-                    return ctx.parentPath(field);
-                }
-            });
-            if (loader == null) {
-                HeaderWarning.addWarning("Field [{}] cannot be retrieved, it is unsupported or not indexed; returning null", fieldName);
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
+            @Override
+            public String parentField(String field) {
+                return ctx.parentPath(field);
             }
-            loaders.add(loader);
+        });
+        if (loader == null) {
+            HeaderWarning.addWarning("Field [{}] cannot be retrieved, it is unsupported or not indexed; returning null", fieldName);
+            return BlockLoader.CONSTANT_NULLS;
         }

-        return loaders;
+        return loader;
     }
 }
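
With loaders(List<SearchContext>, ...) gone, a caller that fans out over
many shards wraps the new single-context loader() in a shard-indexed
function, the pattern EsPhysicalOperationProviders adopts later in this
commit. A simplified sketch with stand-in types (loaderFor stands in for
BlockReaderFactories.loader):

import java.util.List;
import java.util.function.IntFunction;

class PerShardLoaders {
    interface SearchExecutionContext {}
    interface BlockLoader {}

    // Stand-in for BlockReaderFactories.loader(ctx, fieldName, asUnsupportedSource).
    static BlockLoader loaderFor(SearchExecutionContext ctx, String fieldName) {
        return new BlockLoader() {};
    }

    static IntFunction<BlockLoader> perShard(List<SearchExecutionContext> contexts, String fieldName) {
        // Nothing is resolved here; work happens only when apply(shardIdx) runs.
        return shardIdx -> loaderFor(contexts.get(shardIdx), fieldName);
    }
}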

+ 192 - 127
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java

@@ -27,6 +27,7 @@ import org.elasticsearch.compute.data.SingletonOrdinalsBuilder;
 import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Assertions;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.fieldvisitor.StoredFieldLoader;
@@ -43,6 +44,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
+import java.util.function.IntFunction;
 import java.util.function.Supplier;

 /**
@@ -95,22 +97,25 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         }
     }

+    /**
+     * Configuration for a field to load.
+     *
+     * {@code blockLoader} maps the shard index to the {@link BlockLoader}
+     * which loads the actual blocks.
+     */
+    public record FieldInfo(String name, ElementType type, IntFunction<BlockLoader> blockLoader) {}
+
     public record ShardContext(IndexReader reader, Supplier<SourceLoader> newSourceLoader) {}

-    private final List<FieldWork> fields;
+    private final FieldWork[] fields;
     private final List<ShardContext> shardContexts;
     private final int docChannel;
     private final BlockFactory blockFactory;

     private final Map<String, Integer> readersBuilt = new TreeMap<>();

-    /**
-     * Configuration for a field to load.
-     *
-     * {@code blockLoaders} is a list, one entry per shard, of
-     * {@link BlockLoader}s which load the actual blocks.
-     */
-    public record FieldInfo(String name, List<BlockLoader> blockLoaders) {}
+    int lastShard = -1;
+    int lastSegment = -1;

     /**
      * Creates a new extractor
@@ -118,7 +123,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
      * @param docChannel the channel containing the shard, leaf/segment and doc id
      */
     public ValuesSourceReaderOperator(BlockFactory blockFactory, List<FieldInfo> fields, List<ShardContext> shardContexts, int docChannel) {
-        this.fields = fields.stream().map(f -> new FieldWork(f)).toList();
+        this.fields = fields.stream().map(f -> new FieldWork(f)).toArray(FieldWork[]::new);
         this.shardContexts = shardContexts;
         this.docChannel = docChannel;
         this.blockFactory = blockFactory;
@@ -128,13 +133,21 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
     protected Page process(Page page) {
         DocVector docVector = page.<DocBlock>getBlock(docChannel).asVector();

-        Block[] blocks = new Block[fields.size()];
+        Block[] blocks = new Block[fields.length];
         boolean success = false;
         try {
             if (docVector.singleSegmentNonDecreasing()) {
                 loadFromSingleLeaf(blocks, docVector);
             } else {
-                loadFromManyLeaves(blocks, docVector);
+                try (LoadFromMany many = new LoadFromMany(blocks, docVector)) {
+                    many.run();
+                }
+            }
+            if (Assertions.ENABLED) {
+                for (int f = 0; f < fields.length; f++) {
+                    assert blocks[f].elementType() == ElementType.NULL || blocks[f].elementType() == fields[f].info.type
+                        : blocks[f].elementType() + " NOT IN (NULL, " + fields[f].info.type + ")";
+                }
             }
             success = true;
         } catch (IOException e) {
@@ -147,10 +160,51 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         return page.appendBlocks(blocks);
     }

+    private void positionFieldWork(int shard, int segment, int firstDoc) {
+        if (lastShard == shard) {
+            if (lastSegment == segment) {
+                for (FieldWork w : fields) {
+                    w.sameSegment(firstDoc);
+                }
+                return;
+            }
+            lastSegment = segment;
+            for (FieldWork w : fields) {
+                w.sameShardNewSegment();
+            }
+            return;
+        }
+        lastShard = shard;
+        lastSegment = segment;
+        for (FieldWork w : fields) {
+            w.newShard(shard);
+        }
+    }
+
+    private boolean positionFieldWorkDocGuarteedAscending(int shard, int segment) {
+        if (lastShard == shard) {
+            if (lastSegment == segment) {
+                return false;
+            }
+            lastSegment = segment;
+            for (FieldWork w : fields) {
+                w.sameShardNewSegment();
+            }
+            return true;
+        }
+        lastShard = shard;
+        lastSegment = segment;
+        for (FieldWork w : fields) {
+            w.newShard(shard);
+        }
+        return true;
+    }
+
     private void loadFromSingleLeaf(Block[] blocks, DocVector docVector) throws IOException {
         int shard = docVector.shards().getInt(0);
         int segment = docVector.segments().getInt(0);
         int firstDoc = docVector.docs().getInt(0);
+        positionFieldWork(shard, segment, firstDoc);
         IntVector docs = docVector.docs();
         BlockLoader.Docs loaderDocs = new BlockLoader.Docs() {
             @Override
@@ -164,24 +218,24 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
             }
         };
         StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
-        List<RowStrideReaderWork> rowStrideReaders = new ArrayList<>(fields.size());
+        List<RowStrideReaderWork> rowStrideReaders = new ArrayList<>(fields.length);
         ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(blockFactory, docs.getPositionCount());
+        LeafReaderContext ctx = ctx(shard, segment);
         try {
-            for (int b = 0; b < fields.size(); b++) {
-                FieldWork field = fields.get(b);
-                BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime.reader(shard, segment, firstDoc);
+            for (int f = 0; f < fields.length; f++) {
+                FieldWork field = fields[f];
+                BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime(ctx);
                 if (columnAtATime != null) {
-                    blocks[b] = (Block) columnAtATime.read(loaderBlockFactory, loaderDocs);
+                    blocks[f] = (Block) columnAtATime.read(loaderBlockFactory, loaderDocs);
                 } else {
-                    BlockLoader.RowStrideReader rowStride = field.rowStride.reader(shard, segment, firstDoc);
                     rowStrideReaders.add(
                         new RowStrideReaderWork(
-                            rowStride,
-                            (Block.Builder) field.info.blockLoaders.get(shard).builder(loaderBlockFactory, docs.getPositionCount()),
-                            b
+                            field.rowStride(ctx),
+                            (Block.Builder) field.loader.builder(loaderBlockFactory, docs.getPositionCount()),
+                            f
                         )
                     );
-                    storedFieldsSpec = storedFieldsSpec.merge(field.info.blockLoaders.get(shard).rowStrideStoredFieldSpec());
+                    storedFieldsSpec = storedFieldsSpec.merge(field.loader.rowStrideStoredFieldSpec());
                 }
             }

@@ -193,7 +247,6 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                     "found row stride readers [" + rowStrideReaders + "] without stored fields [" + storedFieldsSpec + "]"
                     "found row stride readers [" + rowStrideReaders + "] without stored fields [" + storedFieldsSpec + "]"
                 );
                 );
             }
             }
-            LeafReaderContext ctx = ctx(shard, segment);
             StoredFieldLoader storedFieldLoader;
             StoredFieldLoader storedFieldLoader;
             if (useSequentialStoredFieldsReader(docVector.docs())) {
             if (useSequentialStoredFieldsReader(docVector.docs())) {
                 storedFieldLoader = StoredFieldLoader.fromSpecSequential(storedFieldsSpec);
                 storedFieldLoader = StoredFieldLoader.fromSpecSequential(storedFieldsSpec);
@@ -203,7 +256,6 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                 trackStoredFields(storedFieldsSpec, false);
             }
             BlockLoaderStoredFieldsFromLeafLoader storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
-                // TODO enable the optimization by passing non-null to docs if correct
                 storedFieldLoader.getLoader(ctx, null),
                 storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
             );
@@ -226,50 +278,91 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         }
     }

-    private void loadFromManyLeaves(Block[] blocks, DocVector docVector) throws IOException {
-        IntVector shards = docVector.shards();
-        IntVector segments = docVector.segments();
-        IntVector docs = docVector.docs();
-        Block.Builder[] builders = new Block.Builder[blocks.length];
-        int[] forwards = docVector.shardSegmentDocMapForwards();
-        ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(blockFactory, docs.getPositionCount());
-        try {
-            for (int b = 0; b < fields.size(); b++) {
-                FieldWork field = fields.get(b);
-                builders[b] = builderFromFirstNonNull(loaderBlockFactory, field, docs.getPositionCount());
+    private class LoadFromMany implements Releasable {
+        private final Block[] target;
+        private final IntVector shards;
+        private final IntVector segments;
+        private final IntVector docs;
+        private final int[] forwards;
+        private final int[] backwards;
+        private final Block.Builder[] builders;
+        private final BlockLoader.RowStrideReader[] rowStride;
+
+        BlockLoaderStoredFieldsFromLeafLoader storedFields;
+
+        LoadFromMany(Block[] target, DocVector docVector) {
+            this.target = target;
+            shards = docVector.shards();
+            segments = docVector.segments();
+            docs = docVector.docs();
+            forwards = docVector.shardSegmentDocMapForwards();
+            backwards = docVector.shardSegmentDocMapBackwards();
+            builders = new Block.Builder[target.length];
+            rowStride = new BlockLoader.RowStrideReader[target.length];
+        }
+
+        void run() throws IOException {
+            for (int f = 0; f < fields.length; f++) {
+                /*
+                 * Important note: each block loader has a method to build an
+                 * optimized block loader, but we have *many* fields and some
+                 * of those block loaders may not be compatible with each other.
+                 * So! We take the least common denominator which is the loader
+                 * from the expected element type.
+                 */
+                builders[f] = fields[f].info.type.newBlockBuilder(docs.getPositionCount(), blockFactory);
             }
-            int lastShard = -1;
-            int lastSegment = -1;
-            BlockLoaderStoredFieldsFromLeafLoader storedFields = null;
-            for (int i = 0; i < forwards.length; i++) {
-                int p = forwards[i];
-                int shard = shards.getInt(p);
-                int segment = segments.getInt(p);
-                int doc = docs.getInt(p);
-                if (shard != lastShard || segment != lastSegment) {
-                    lastShard = shard;
-                    lastSegment = segment;
-                    StoredFieldsSpec storedFieldsSpec = storedFieldsSpecForShard(shard);
-                    LeafReaderContext ctx = ctx(shard, segment);
-                    storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
-                        StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(ctx, null),
-                        storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
-                    );
-                    if (false == storedFieldsSpec.equals(StoredFieldsSpec.NO_REQUIREMENTS)) {
-                        trackStoredFields(storedFieldsSpec, false);
-                    }
+            int p = forwards[0];
+            int shard = shards.getInt(p);
+            int segment = segments.getInt(p);
+            int firstDoc = docs.getInt(p);
+            positionFieldWork(shard, segment, firstDoc);
+            LeafReaderContext ctx = ctx(shard, segment);
+            fieldsMoved(ctx, shard);
+            read(firstDoc);
+            for (int i = 1; i < forwards.length; i++) {
+                p = forwards[i];
+                shard = shards.getInt(p);
+                segment = segments.getInt(p);
+                boolean changedSegment = positionFieldWorkDocGuarteedAscending(shard, segment);
+                if (changedSegment) {
+                    ctx = ctx(shard, segment);
+                    fieldsMoved(ctx, shard);
                 }
-                storedFields.advanceTo(doc);
-                for (int r = 0; r < blocks.length; r++) {
-                    fields.get(r).rowStride.reader(shard, segment, doc).read(doc, storedFields, builders[r]);
+                read(docs.getInt(p));
+            }
+            for (int f = 0; f < builders.length; f++) {
+                try (Block orig = builders[f].build()) {
+                    target[f] = orig.filter(backwards);
                 }
             }
-            for (int r = 0; r < blocks.length; r++) {
-                try (Block orig = builders[r].build()) {
-                    blocks[r] = orig.filter(docVector.shardSegmentDocMapBackwards());
+        }
+
+        private void fieldsMoved(LeafReaderContext ctx, int shard) throws IOException {
+            StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
+            for (int f = 0; f < fields.length; f++) {
+                FieldWork field = fields[f];
+                rowStride[f] = field.rowStride(ctx);
+                storedFieldsSpec = storedFieldsSpec.merge(field.loader.rowStrideStoredFieldSpec());
+                storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
+                    StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(ctx, null),
+                    storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
+                );
+                if (false == storedFieldsSpec.equals(StoredFieldsSpec.NO_REQUIREMENTS)) {
+                    trackStoredFields(storedFieldsSpec, false);
                 }
             }
-        } finally {
+        }
+
+        private void read(int doc) throws IOException {
+            storedFields.advanceTo(doc);
+            for (int f = 0; f < builders.length; f++) {
+                rowStride[f].read(doc, storedFields, builders[f]);
+            }
+        }
+
+        @Override
+        public void close() {
             Releasables.closeExpectNoException(builders);
         }
     }
@@ -298,83 +391,55 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         );
     }

-    /**
-     * Returns a builder from the first non - {@link BlockLoader#CONSTANT_NULLS} loader
-     * in the list. If they are all the null loader then returns a null builder.
-     */
-    private Block.Builder builderFromFirstNonNull(BlockLoader.BlockFactory loaderBlockFactory, FieldWork field, int positionCount) {
-        for (BlockLoader loader : field.info.blockLoaders) {
-            if (loader != BlockLoader.CONSTANT_NULLS) {
-                return (Block.Builder) loader.builder(loaderBlockFactory, positionCount);
-            }
-        }
-        // All null, just let the first one build the null block loader.
-        return (Block.Builder) field.info.blockLoaders.get(0).builder(loaderBlockFactory, positionCount);
-    }
-
-    private StoredFieldsSpec storedFieldsSpecForShard(int shard) {
-        StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
-        for (int b = 0; b < fields.size(); b++) {
-            FieldWork field = fields.get(b);
-            storedFieldsSpec = storedFieldsSpec.merge(field.info.blockLoaders.get(shard).rowStrideStoredFieldSpec());
-        }
-        return storedFieldsSpec;
-    }
-
     private class FieldWork {
         final FieldInfo info;
-        final GuardedReader<BlockLoader.ColumnAtATimeReader> columnAtATime = new GuardedReader<>() {
-            @Override
-            BlockLoader.ColumnAtATimeReader build(BlockLoader loader, LeafReaderContext ctx) throws IOException {
-                return loader.columnAtATimeReader(ctx);
-            }

-            @Override
-            String type() {
-                return "column_at_a_time";
-            }
-        };
+        BlockLoader loader;
+        BlockLoader.ColumnAtATimeReader columnAtATime;
+        BlockLoader.RowStrideReader rowStride;

-        final GuardedReader<BlockLoader.RowStrideReader> rowStride = new GuardedReader<>() {
-            @Override
-            BlockLoader.RowStrideReader build(BlockLoader loader, LeafReaderContext ctx) throws IOException {
-                return loader.rowStrideReader(ctx);
-            }
+        FieldWork(FieldInfo info) {
+            this.info = info;
+        }

-            @Override
-            String type() {
-                return "row_stride";
+        void sameSegment(int firstDoc) {
+            if (columnAtATime != null && columnAtATime.canReuse(firstDoc) == false) {
+                columnAtATime = null;
             }
-        };
+            if (rowStride != null && rowStride.canReuse(firstDoc) == false) {
+                rowStride = null;
+            }
+        }

-        FieldWork(FieldInfo info) {
-            this.info = info;
+        void sameShardNewSegment() {
+            columnAtATime = null;
+            rowStride = null;
         }

-        private abstract class GuardedReader<V extends BlockLoader.Reader> {
-            private int lastShard = -1;
-            private int lastSegment = -1;
-            V lastReader;
+        void newShard(int shard) {
+            loader = info.blockLoader.apply(shard);
+            columnAtATime = null;
+            rowStride = null;
+        }

-            V reader(int shard, int segment, int startingDocId) throws IOException {
-                if (lastShard == shard && lastSegment == segment) {
-                    if (lastReader == null) {
-                        return null;
-                    }
-                    if (lastReader.canReuse(startingDocId)) {
-                        return lastReader;
-                    }
-                }
-                lastShard = shard;
-                lastSegment = segment;
-                lastReader = build(info.blockLoaders.get(shard), ctx(shard, segment));
-                readersBuilt.merge(info.name + ":" + type() + ":" + lastReader, 1, (prev, one) -> prev + one);
-                return lastReader;
+        BlockLoader.ColumnAtATimeReader columnAtATime(LeafReaderContext ctx) throws IOException {
+            if (columnAtATime == null) {
+                columnAtATime = loader.columnAtATimeReader(ctx);
+                trackReader("column_at_a_time", this.columnAtATime);
             }
+            return columnAtATime;
+        }

-            abstract V build(BlockLoader loader, LeafReaderContext ctx) throws IOException;
+        BlockLoader.RowStrideReader rowStride(LeafReaderContext ctx) throws IOException {
+            if (rowStride == null) {
+                rowStride = loader.rowStrideReader(ctx);
+                trackReader("row_stride", this.rowStride);
+            }
+            return rowStride;
+        }

-            abstract String type();
+        private void trackReader(String type, BlockLoader.Reader reader) {
+            readersBuilt.merge(info.name + ":" + type + ":" + reader, 1, (prev, one) -> prev + one);
         }
     }

@@ -393,7 +458,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
     public String toString() {
         StringBuilder sb = new StringBuilder();
         sb.append("ValuesSourceReaderOperator[fields = [");
-        if (fields.size() < 10) {
+        if (fields.length < 10) {
             boolean first = true;
             for (FieldWork f : fields) {
                 if (first) {
@@ -404,7 +469,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                 sb.append(f.info.name);
             }
         } else {
-            sb.append(fields.size()).append(" fields");
+            sb.append(fields.length).append(" fields");
         }
         return sb.append("]]").toString();
     }
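
The operator now tracks a single lastShard/lastSegment pair and classifies
each move before touching any reader. A distilled model of the three
transitions in positionFieldWork above (ReaderState is a hypothetical
stand-in; the real code spreads the work across FieldWork's sameSegment,
sameShardNewSegment and newShard):

class ReaderState {
    int lastShard = -1;
    int lastSegment = -1;

    // Returns the cheapest work each FieldWork must redo for this page.
    String position(int shard, int segment) {
        if (lastShard == shard && lastSegment == segment) {
            return "reuse readers if canReuse(firstDoc)";              // cheapest
        }
        if (lastShard == shard) {
            lastSegment = segment;
            return "rebuild readers, keep the resolved BlockLoader";
        }
        lastShard = shard;
        lastSegment = segment;
        return "re-resolve the BlockLoader via blockLoader.apply(shard)"; // priciest
    }
}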

+ 7 - 6
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java

@@ -42,6 +42,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.IntFunction;
 import java.util.function.Supplier;

 import static java.util.Objects.requireNonNull;
@@ -52,7 +53,7 @@ import static java.util.stream.Collectors.joining;
  */
 public class OrdinalsGroupingOperator implements Operator {
     public record OrdinalsGroupingOperatorFactory(
-        List<BlockLoader> blockLoaders,
+        IntFunction<BlockLoader> blockLoaders,
         List<ValuesSourceReaderOperator.ShardContext> shardContexts,
         ElementType groupingElementType,
         int docChannel,
@@ -83,7 +84,7 @@ public class OrdinalsGroupingOperator implements Operator {
         }
     }

-    private final List<BlockLoader> blockLoaders;
+    private final IntFunction<BlockLoader> blockLoaders;
     private final List<ValuesSourceReaderOperator.ShardContext> shardContexts;
     private final int docChannel;
     private final String groupingField;
@@ -102,7 +103,7 @@ public class OrdinalsGroupingOperator implements Operator {
     private ValuesAggregator valuesAggregator;

     public OrdinalsGroupingOperator(
-        List<BlockLoader> blockLoaders,
+        IntFunction<BlockLoader> blockLoaders,
         List<ValuesSourceReaderOperator.ShardContext> shardContexts,
         ElementType groupingElementType,
         int docChannel,
@@ -136,7 +137,7 @@ public class OrdinalsGroupingOperator implements Operator {
         requireNonNull(page, "page is null");
         DocVector docVector = page.<DocBlock>getBlock(docChannel).asVector();
         final int shardIndex = docVector.shards().getInt(0);
-        final var blockLoader = blockLoaders.get(shardIndex);
+        final var blockLoader = blockLoaders.apply(shardIndex);
         boolean pagePassed = false;
         try {
             if (docVector.singleSegmentNonDecreasing() && blockLoader.supportsOrdinals()) {
@@ -464,7 +465,7 @@ public class OrdinalsGroupingOperator implements Operator {
         private final HashAggregationOperator aggregator;

         ValuesAggregator(
-            List<BlockLoader> blockLoaders,
+            IntFunction<BlockLoader> blockLoaders,
             List<ValuesSourceReaderOperator.ShardContext> shardContexts,
             ElementType groupingElementType,
             int docChannel,
@@ -476,7 +477,7 @@ public class OrdinalsGroupingOperator implements Operator {
         ) {
             this.extractor = new ValuesSourceReaderOperator(
                 driverContext.blockFactory(),
-                List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, blockLoaders)),
+                List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, groupingElementType, blockLoaders)),
                 shardContexts,
                 docChannel
             );
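
For existing OrdinalsGroupingOperator callers the migration is mechanical: a
list that was already materialized still satisfies the new signature through
a method reference, though the point of the commit is to stop materializing
it. A hypothetical bridge:

import java.util.List;
import java.util.function.IntFunction;

class GroupingLoaderBridge {
    interface BlockLoader {}

    // Old shape kept alive: every shard's loader was resolved up front anyway.
    static IntFunction<BlockLoader> fromList(List<BlockLoader> eager) {
        return eager::get;
    }
}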

+ 2 - 2
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java

@@ -228,7 +228,7 @@ public class OperatorTests extends MapperServiceTestCase {
                         }
                     },
                         new OrdinalsGroupingOperator(
-                            List.of(new KeywordFieldMapper.KeywordFieldType("g").blockLoader(null)),
+                            shardIdx -> new KeywordFieldMapper.KeywordFieldType("g").blockLoader(null),
                             List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                             ElementType.BYTES_REF,
                             0,
@@ -347,7 +347,7 @@ public class OperatorTests extends MapperServiceTestCase {
     }

     static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, int limit) {
-        final SearchContext searchContext = mockSearchContext(reader);
+        final SearchContext searchContext = mockSearchContext(reader, 0);
         return new LuceneSourceOperator.Factory(
             List.of(searchContext),
             ctx -> query,

+ 1 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java

@@ -83,7 +83,7 @@ public class LuceneCountOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }

-        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader);
+        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getIndexReader()).thenReturn(reader);
         final Query query;
         if (enableShortcut && randomBoolean()) {

+ 5 - 4
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java

@@ -18,6 +18,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.AnyOperatorTestCase;
@@ -96,7 +97,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }

-        SearchContext ctx = mockSearchContext(reader);
+        SearchContext ctx = mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getFieldType(anyString())).thenAnswer(inv -> {
             String name = inv.getArgument(0);
             return switch (name) {
@@ -176,7 +177,7 @@

     private void testSimple(DriverContext ctx, int size, int limit) {
         LuceneSourceOperator.Factory factory = simple(ctx.bigArrays(), DataPartitioning.SHARD, size, limit);
-        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD);
+        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD, ElementType.LONG);

         List<Page> results = new ArrayList<>();

@@ -204,7 +205,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
      * Creates a mock search context with the given index reader.
      * The returned mock search context can be used to test with {@link LuceneOperator}.
      */
-    public static SearchContext mockSearchContext(IndexReader reader) {
+    public static SearchContext mockSearchContext(IndexReader reader, int shardId) {
         try {
             ContextIndexSearcher searcher = new ContextIndexSearcher(
                 reader,
@@ -218,7 +219,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
             SearchExecutionContext searchExecutionContext = mock(SearchExecutionContext.class);
             when(searchContext.getSearchExecutionContext()).thenReturn(searchExecutionContext);
             when(searchExecutionContext.getFullyQualifiedIndex()).thenReturn(new Index("test", "uid"));
-            when(searchExecutionContext.getShardId()).thenReturn(0);
+            when(searchExecutionContext.getShardId()).thenReturn(shardId);
             return searchContext;
         } catch (IOException e) {
             throw new UncheckedIOException(e);

+ 3 - 2
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java

@@ -17,6 +17,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.AnyOperatorTestCase;
@@ -87,7 +88,7 @@ public class LuceneTopNSourceOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }

-        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader);
+        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getFieldType(anyString())).thenAnswer(inv -> {
             String name = inv.getArgument(0);
             return switch (name) {
@@ -173,7 +174,7 @@

     private void testSimple(DriverContext ctx, int size, int limit) {
         LuceneTopNSourceOperator.Factory factory = simple(ctx.bigArrays(), DataPartitioning.SHARD, size, limit);
-        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD);
+        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD, ElementType.LONG);

         List<Page> results = new ArrayList<>();
         OperatorTestCase.runDriver(

+ 241 - 49
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java

@@ -42,6 +42,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -69,12 +70,14 @@ import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.TextFieldMapper;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.TsidExtractingIdFieldMapper;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.hamcrest.Matcher;
 import org.junit.After;

+import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -82,6 +85,9 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.IntStream;
 
 
 import static org.elasticsearch.compute.lucene.LuceneSourceOperatorTests.mockSearchContext;
 import static org.elasticsearch.compute.lucene.LuceneSourceOperatorTests.mockSearchContext;
@@ -129,19 +135,20 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 throw new RuntimeException(e);
             }
         }
-        return factory(reader, docValuesNumberField("long", NumberFieldMapper.NumberType.LONG));
+        return factory(reader, docValuesNumberField("long", NumberFieldMapper.NumberType.LONG), ElementType.LONG);
     }
 
-    static Operator.OperatorFactory factory(IndexReader reader, MappedFieldType ft) {
-        return factory(reader, ft.name(), ft.blockLoader(null));
+    static Operator.OperatorFactory factory(IndexReader reader, MappedFieldType ft, ElementType elementType) {
+        return factory(reader, ft.name(), elementType, ft.blockLoader(null));
     }
 
-    static Operator.OperatorFactory factory(IndexReader reader, String name, BlockLoader loader) {
-        return new ValuesSourceReaderOperator.Factory(
-            List.of(new ValuesSourceReaderOperator.FieldInfo(name, List.of(loader))),
-            List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
-            0
-        );
+    static Operator.OperatorFactory factory(IndexReader reader, String name, ElementType elementType, BlockLoader loader) {
+        return new ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> {
+            if (shardIdx != 0) {
+                fail("unexpected shardIdx [" + shardIdx + "]");
+            }
+            return loader;
+        })), List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)), 0);
     }
 
     @Override
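
For orientation, the helper above pins down the new API shape: a FieldInfo now carries the field name, its ElementType, and a shard-index-to-BlockLoader function, so the loader is only resolved for shards that are actually read. A minimal single-shard sketch of the same construction outside the test harness (the helper name is hypothetical; the constructors and SourceLoader.FROM_STORED_SOURCE come straight from the diff):

    import java.util.List;
    import org.apache.lucene.index.IndexReader;
    import org.elasticsearch.compute.data.ElementType;
    import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
    import org.elasticsearch.index.mapper.BlockLoader;
    import org.elasticsearch.index.mapper.SourceLoader;

    // Hypothetical helper: builds a reader factory for a single shard. The
    // ElementType is declared eagerly; the IntFunction body runs only when
    // shard 0 is first read, which is the point of this change.
    static ValuesSourceReaderOperator.Factory singleShardFactory(IndexReader reader, String name, ElementType type, BlockLoader loader) {
        return new ValuesSourceReaderOperator.Factory(
            List.of(new ValuesSourceReaderOperator.FieldInfo(name, type, shardIdx -> loader)),
            List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
            0 // channel holding the incoming doc vector
        );
    }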
@@ -160,7 +167,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             throw new RuntimeException(e);
         }
         var luceneFactory = new LuceneSourceOperator.Factory(
-            List.of(mockSearchContext(reader)),
+            List.of(mockSearchContext(reader, 0)),
             ctx -> new MatchAllDocsQuery(),
             DataPartitioning.SHARD,
             randomIntBetween(1, 10),
@@ -172,6 +179,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
 
     private void initIndex(int size, int commitEvery) throws IOException {
         keyToTags.clear();
+        reader = initIndex(directory, size, commitEvery);
+    }
+
+    private IndexReader initIndex(Directory directory, int size, int commitEvery) throws IOException {
         try (
             IndexWriter writer = new IndexWriter(
                 directory,
@@ -240,7 +251,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 }
             }
         }
-        reader = DirectoryReader.open(directory);
+        return DirectoryReader.open(directory);
     }
 
     @Override
@@ -308,12 +319,13 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         Checks checks = new Checks(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING);
         FieldCase testCase = new FieldCase(
             new KeywordFieldMapper.KeywordFieldType("kwd"),
+            ElementType.BYTES_REF,
             checks::tags,
             StatusChecks::keywordsFromDocValues
         );
         operators.add(
             new ValuesSourceReaderOperator.Factory(
-                List.of(testCase.info, fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER))),
+                List.of(testCase.info, fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT)),
                 List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                 0
             ).get(driverContext)
@@ -356,8 +368,17 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         loadSimpleAndAssert(driverContext, List.of(source), Block.MvOrdering.UNORDERED);
     }
 
-    private static ValuesSourceReaderOperator.FieldInfo fieldInfo(MappedFieldType ft) {
-        return new ValuesSourceReaderOperator.FieldInfo(ft.name(), List.of(ft.blockLoader(new MappedFieldType.BlockLoaderContext() {
+    private static ValuesSourceReaderOperator.FieldInfo fieldInfo(MappedFieldType ft, ElementType elementType) {
+        return new ValuesSourceReaderOperator.FieldInfo(ft.name(), elementType, shardIdx -> {
+            if (shardIdx != 0) {
+                fail("unexpected shardIdx [" + shardIdx + "]");
+            }
+            return ft.blockLoader(blContext());
+        });
+    }
+
+    private static MappedFieldType.BlockLoaderContext blContext() {
+        return new MappedFieldType.BlockLoaderContext() {
             @Override
             public String indexName() {
                 return "test_index";
@@ -377,7 +398,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             public String parentField(String field) {
                 return null;
             }
-        })));
+        };
     }
 
     private void loadSimpleAndAssert(DriverContext driverContext, List<Page> input, Block.MvOrdering docValuesMvOrdering) {
@@ -386,7 +407,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         List<Operator> operators = new ArrayList<>();
         operators.add(
             new ValuesSourceReaderOperator.Factory(
-                List.of(fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER))),
+                List.of(fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT)),
                 List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                 0
             ).get(driverContext)
@@ -439,13 +460,14 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
     }
 
     record FieldCase(ValuesSourceReaderOperator.FieldInfo info, CheckResults checkResults, CheckReadersWithName checkReaders) {
-        FieldCase(MappedFieldType ft, CheckResults checkResults, CheckReadersWithName checkReaders) {
-            this(fieldInfo(ft), checkResults, checkReaders);
+        FieldCase(MappedFieldType ft, ElementType elementType, CheckResults checkResults, CheckReadersWithName checkReaders) {
+            this(fieldInfo(ft, elementType), checkResults, checkReaders);
         }
 
-        FieldCase(MappedFieldType ft, CheckResults checkResults, CheckReaders checkReaders) {
+        FieldCase(MappedFieldType ft, ElementType elementType, CheckResults checkResults, CheckReaders checkReaders) {
             this(
                 ft,
+                elementType,
                 checkResults,
                 (name, forcedRowByRow, pageCount, segmentCount, readersBuilt) -> checkReaders.check(
                     forcedRowByRow,
@@ -506,11 +528,17 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         Checks checks = new Checks(docValuesMvOrdering);
         List<FieldCase> r = new ArrayList<>();
         r.add(
-            new FieldCase(docValuesNumberField("long", NumberFieldMapper.NumberType.LONG), checks::longs, StatusChecks::longsFromDocValues)
+            new FieldCase(
+                docValuesNumberField("long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
+                checks::longs,
+                StatusChecks::longsFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::mvLongsFromDocValues,
                 StatusChecks::mvLongsFromDocValues
             )
@@ -518,26 +546,39 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(sourceNumberField("source_long", NumberFieldMapper.NumberType.LONG), checks::longs, StatusChecks::longsFromSource)
+            new FieldCase(
+                sourceNumberField("source_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
+                checks::longs,
+                StatusChecks::longsFromSource
+            )
         );
         r.add(
             new FieldCase(
                 sourceNumberField("mv_source_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::mvLongsUnordered,
                 StatusChecks::mvLongsFromSource
             )
         );
         r.add(
-            new FieldCase(docValuesNumberField("int", NumberFieldMapper.NumberType.INTEGER), checks::ints, StatusChecks::intsFromDocValues)
+            new FieldCase(
+                docValuesNumberField("int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
+                checks::ints,
+                StatusChecks::intsFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::mvIntsFromDocValues,
                 StatusChecks::mvIntsFromDocValues
             )
@@ -545,16 +586,23 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(sourceNumberField("source_int", NumberFieldMapper.NumberType.INTEGER), checks::ints, StatusChecks::intsFromSource)
+            new FieldCase(
+                sourceNumberField("source_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
+                checks::ints,
+                StatusChecks::intsFromSource
+            )
         );
         r.add(
             new FieldCase(
                 sourceNumberField("mv_source_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::mvIntsUnordered,
                 StatusChecks::mvIntsFromSource
             )
@@ -562,6 +610,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::shorts,
                 StatusChecks::shortsFromDocValues
             )
@@ -569,6 +618,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::mvShorts,
                 StatusChecks::mvShortsFromDocValues
             )
@@ -576,16 +626,23 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(docValuesNumberField("byte", NumberFieldMapper.NumberType.BYTE), checks::bytes, StatusChecks::bytesFromDocValues)
+            new FieldCase(
+                docValuesNumberField("byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
+                checks::bytes,
+                StatusChecks::bytesFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
                 checks::mvBytes,
                 StatusChecks::mvBytesFromDocValues
             )
@@ -593,6 +650,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
@@ -600,6 +658,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::doubles,
                 StatusChecks::doublesFromDocValues
             )
@@ -607,6 +666,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::mvDoubles,
                 StatusChecks::mvDoublesFromDocValues
             )
@@ -614,39 +674,106 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("bool"), checks::bools, StatusChecks::boolFromDocValues));
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("mv_bool"), checks::mvBools, StatusChecks::mvBoolFromDocValues));
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("missing_bool"), checks::constantNulls, StatusChecks::constantNulls));
-        r.add(new FieldCase(new KeywordFieldMapper.KeywordFieldType("kwd"), checks::tags, StatusChecks::keywordsFromDocValues));
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("bool"),
+                ElementType.BOOLEAN,
+                checks::bools,
+                StatusChecks::boolFromDocValues
+            )
+        );
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("mv_bool"),
+                ElementType.BOOLEAN,
+                checks::mvBools,
+                StatusChecks::mvBoolFromDocValues
+            )
+        );
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("missing_bool"),
+                ElementType.BOOLEAN,
+                checks::constantNulls,
+                StatusChecks::constantNulls
+            )
+        );
+        r.add(
+            new FieldCase(
+                new KeywordFieldMapper.KeywordFieldType("kwd"),
+                ElementType.BYTES_REF,
+                checks::tags,
+                StatusChecks::keywordsFromDocValues
+            )
+        );
         r.add(
             new FieldCase(
                 new KeywordFieldMapper.KeywordFieldType("mv_kwd"),
+                ElementType.BYTES_REF,
                 checks::mvStringsFromDocValues,
                 StatusChecks::mvKeywordsFromDocValues
             )
         );
-        r.add(new FieldCase(new KeywordFieldMapper.KeywordFieldType("missing_kwd"), checks::constantNulls, StatusChecks::constantNulls));
-        r.add(new FieldCase(storedKeywordField("stored_kwd"), checks::strings, StatusChecks::keywordsFromStored));
-        r.add(new FieldCase(storedKeywordField("mv_stored_kwd"), checks::mvStringsUnordered, StatusChecks::mvKeywordsFromStored));
-        r.add(new FieldCase(sourceKeywordField("source_kwd"), checks::strings, StatusChecks::keywordsFromSource));
-        r.add(new FieldCase(sourceKeywordField("mv_source_kwd"), checks::mvStringsUnordered, StatusChecks::mvKeywordsFromSource));
-        r.add(new FieldCase(new TextFieldMapper.TextFieldType("source_text", false), checks::strings, StatusChecks::textFromSource));
+        r.add(
+            new FieldCase(
+                new KeywordFieldMapper.KeywordFieldType("missing_kwd"),
+                ElementType.BYTES_REF,
+                checks::constantNulls,
+                StatusChecks::constantNulls
+            )
+        );
+        r.add(new FieldCase(storedKeywordField("stored_kwd"), ElementType.BYTES_REF, checks::strings, StatusChecks::keywordsFromStored));
+        r.add(
+            new FieldCase(
+                storedKeywordField("mv_stored_kwd"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvKeywordsFromStored
+            )
+        );
+        r.add(new FieldCase(sourceKeywordField("source_kwd"), ElementType.BYTES_REF, checks::strings, StatusChecks::keywordsFromSource));
+        r.add(
+            new FieldCase(
+                sourceKeywordField("mv_source_kwd"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvKeywordsFromSource
+            )
+        );
+        r.add(
+            new FieldCase(
+                new TextFieldMapper.TextFieldType("source_text", false),
+                ElementType.BYTES_REF,
+                checks::strings,
+                StatusChecks::textFromSource
+            )
+        );
         r.add(
             new FieldCase(
                 new TextFieldMapper.TextFieldType("mv_source_text", false),
+                ElementType.BYTES_REF,
                 checks::mvStringsUnordered,
                 StatusChecks::mvTextFromSource
             )
         );
-        r.add(new FieldCase(storedTextField("stored_text"), checks::strings, StatusChecks::textFromStored));
-        r.add(new FieldCase(storedTextField("mv_stored_text"), checks::mvStringsUnordered, StatusChecks::mvTextFromStored));
+        r.add(new FieldCase(storedTextField("stored_text"), ElementType.BYTES_REF, checks::strings, StatusChecks::textFromStored));
+        r.add(
+            new FieldCase(
+                storedTextField("mv_stored_text"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvTextFromStored
+            )
+        );
         r.add(
             new FieldCase(
                 textFieldWithDelegate("text_with_delegate", new KeywordFieldMapper.KeywordFieldType("kwd")),
+                ElementType.BYTES_REF,
                 checks::tags,
                 StatusChecks::textWithDelegate
             )
@@ -654,6 +781,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 textFieldWithDelegate("mv_text_with_delegate", new KeywordFieldMapper.KeywordFieldType("mv_kwd")),
+                ElementType.BYTES_REF,
                 checks::mvStringsFromDocValues,
                 StatusChecks::mvTextWithDelegate
             )
@@ -661,22 +789,27 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 textFieldWithDelegate("missing_text_with_delegate", new KeywordFieldMapper.KeywordFieldType("missing_kwd")),
+                ElementType.BYTES_REF,
                 checks::constantNulls,
                 StatusChecks::constantNullTextWithDelegate
             )
         );
-        r.add(new FieldCase(new ProvidedIdFieldMapper(() -> false).fieldType(), checks::ids, StatusChecks::id));
-        r.add(new FieldCase(TsidExtractingIdFieldMapper.INSTANCE.fieldType(), checks::ids, StatusChecks::id));
+        r.add(new FieldCase(new ProvidedIdFieldMapper(() -> false).fieldType(), ElementType.BYTES_REF, checks::ids, StatusChecks::id));
+        r.add(new FieldCase(TsidExtractingIdFieldMapper.INSTANCE.fieldType(), ElementType.BYTES_REF, checks::ids, StatusChecks::id));
         r.add(
             new FieldCase(
-                new ValuesSourceReaderOperator.FieldInfo("constant_bytes", List.of(BlockLoader.constantBytes(new BytesRef("foo")))),
+                new ValuesSourceReaderOperator.FieldInfo(
+                    "constant_bytes",
+                    ElementType.BYTES_REF,
+                    shardIdx -> BlockLoader.constantBytes(new BytesRef("foo"))
+                ),
                 checks::constantBytes,
                 StatusChecks::constantBytes
             )
         );
         r.add(
             new FieldCase(
-                new ValuesSourceReaderOperator.FieldInfo("null", List.of(BlockLoader.CONSTANT_NULLS)),
+                new ValuesSourceReaderOperator.FieldInfo("null", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS),
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
@@ -1149,7 +1282,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
 
         DriverContext driverContext = driverContext();
         var luceneFactory = new LuceneSourceOperator.Factory(
-            List.of(mockSearchContext(reader)),
+            List.of(mockSearchContext(reader, 0)),
             ctx -> new MatchAllDocsQuery(),
             randomFrom(DataPartitioning.values()),
             randomIntBetween(1, 10),
@@ -1161,10 +1294,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 driverContext,
                 luceneFactory.get(driverContext),
                 List.of(
-                    factory(reader, intFt).get(driverContext),
-                    factory(reader, longFt).get(driverContext),
-                    factory(reader, doubleFt).get(driverContext),
-                    factory(reader, kwFt).get(driverContext)
+                    factory(reader, intFt, ElementType.INT).get(driverContext),
+                    factory(reader, longFt, ElementType.LONG).get(driverContext),
+                    factory(reader, doubleFt, ElementType.DOUBLE).get(driverContext),
+                    factory(reader, kwFt, ElementType.BYTES_REF).get(driverContext)
                 ),
                 new PageConsumerOperator(page -> {
                     try {
@@ -1286,8 +1419,8 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 List.of(
                     new ValuesSourceReaderOperator.Factory(
                         List.of(
-                            new ValuesSourceReaderOperator.FieldInfo("null1", List.of(BlockLoader.CONSTANT_NULLS)),
-                            new ValuesSourceReaderOperator.FieldInfo("null2", List.of(BlockLoader.CONSTANT_NULLS))
+                            new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS),
+                            new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS)
                         ),
                         List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                         0
@@ -1331,8 +1464,8 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         assertTrue(source.get(0).<DocBlock>getBlock(0).asVector().singleSegmentNonDecreasing());
         Operator op = new ValuesSourceReaderOperator.Factory(
             List.of(
-                fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER)),
-                fieldInfo(storedTextField("stored_text"))
+                fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT),
+                fieldInfo(storedTextField("stored_text"), ElementType.BYTES_REF)
             ),
             List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
             0
@@ -1368,4 +1501,63 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             assertThat(op.toString(), equalTo("ValuesSourceReaderOperator[fields = [" + cases.size() + " fields]]"));
         }
     }
+
+    public void testManyShards() throws IOException {
+        int shardCount = between(2, 10);
+        int size = between(100, 1000);
+        Directory[] dirs = new Directory[shardCount];
+        IndexReader[] readers = new IndexReader[shardCount];
+        Closeable[] closeMe = new Closeable[shardCount * 2];
+        Set<Integer> seenShards = new TreeSet<>();
+        Map<Integer, Integer> keyCounts = new TreeMap<>();
+        try {
+            for (int d = 0; d < dirs.length; d++) {
+                closeMe[d * 2 + 1] = dirs[d] = newDirectory();
+                closeMe[d * 2] = readers[d] = initIndex(dirs[d], size, between(10, size * 2));
+            }
+            List<SearchContext> contexts = new ArrayList<>();
+            List<ValuesSourceReaderOperator.ShardContext> readerShardContexts = new ArrayList<>();
+            for (int s = 0; s < shardCount; s++) {
+                contexts.add(mockSearchContext(readers[s], s));
+                readerShardContexts.add(new ValuesSourceReaderOperator.ShardContext(readers[s], () -> SourceLoader.FROM_STORED_SOURCE));
+            }
+            var luceneFactory = new LuceneSourceOperator.Factory(
+                contexts,
+                ctx -> new MatchAllDocsQuery(),
+                DataPartitioning.SHARD,
+                randomIntBetween(1, 10),
+                1000,
+                LuceneOperator.NO_LIMIT
+            );
+            MappedFieldType ft = docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER);
+            var readerFactory = new ValuesSourceReaderOperator.Factory(
+                List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> {
+                    seenShards.add(shardIdx);
+                    return ft.blockLoader(blContext());
+                })),
+                readerShardContexts,
+                0
+            );
+            DriverContext driverContext = driverContext();
+            List<Page> results = drive(
+                readerFactory.get(driverContext),
+                CannedSourceOperator.collectPages(luceneFactory.get(driverContext)).iterator(),
+                driverContext
+            );
+            assertThat(seenShards, equalTo(IntStream.range(0, shardCount).boxed().collect(Collectors.toCollection(TreeSet::new))));
+            for (Page p : results) {
+                IntBlock keyBlock = p.getBlock(1);
+                IntVector keys = keyBlock.asVector();
+                for (int i = 0; i < keys.getPositionCount(); i++) {
+                    keyCounts.merge(keys.getInt(i), 1, (prev, one) -> prev + one);
+                }
+            }
+            assertThat(keyCounts.keySet(), hasSize(size));
+            for (int k = 0; k < size; k++) {
+                assertThat(keyCounts.get(k), equalTo(shardCount));
+            }
+        } finally {
+            IOUtils.close(closeMe);
+        }
+    }
 }

+ 15 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java

@@ -42,6 +42,7 @@ import org.elasticsearch.compute.operator.SourceOperator;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.shard.ShardId;
@@ -273,12 +274,23 @@ public class EnrichLookupService {
                 NamedExpression extractField = extractFields.get(i);
                 final ElementType elementType = PlannerUtils.toElementType(extractField.dataType());
                 mergingTypes[i] = elementType;
-                var loaders = BlockReaderFactories.loaders(
-                    List.of(searchContext),
+                BlockLoader loader = BlockReaderFactories.loader(
+                    searchContext.getSearchExecutionContext(),
                     extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(),
                     EsqlDataTypes.isUnsupported(extractField.dataType())
                 );
-                fields.add(new ValuesSourceReaderOperator.FieldInfo(extractField.name(), loaders));
+                fields.add(
+                    new ValuesSourceReaderOperator.FieldInfo(
+                        extractField.name(),
+                        PlannerUtils.toElementType(extractField.dataType()),
+                        shardIdx -> {
+                            if (shardIdx != 0) {
+                                throw new IllegalStateException("only one shard");
+                            }
+                            return loader;
+                        }
+                    )
+                );
             }
             intermediateOperators.add(
                 new ValuesSourceReaderOperator(
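
Enrich lookups always target exactly one shard, so the hunk above resolves the BlockLoader eagerly from the single SearchExecutionContext and wraps it in a guard that rejects any other shard index. A condensed sketch of that pattern; the helper name and parameter names are hypothetical, while the BlockReaderFactories.loader call is taken from the diff:

    import java.util.function.IntFunction;
    import org.elasticsearch.compute.lucene.BlockReaderFactories;
    import org.elasticsearch.index.mapper.BlockLoader;
    import org.elasticsearch.index.query.SearchExecutionContext;

    // Hypothetical helper mirroring the enrich path: one shard, one loader.
    static IntFunction<BlockLoader> singleShardLoader(SearchExecutionContext ctx, String fieldName, boolean asUnsupportedSource) {
        BlockLoader loader = BlockReaderFactories.loader(ctx, fieldName, asUnsupportedSource); // resolved once, up front
        return shardIdx -> {
            if (shardIdx != 0) {
                throw new IllegalStateException("only one shard");
            }
            return loader;
        };
    }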

+ 15 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

@@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ql.type.DataType;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Function;
+import java.util.function.IntFunction;
 
 import static org.elasticsearch.common.lucene.search.Queries.newNonNestedFilter;
 import static org.elasticsearch.compute.lucene.LuceneSourceOperator.NO_LIMIT;
@@ -74,9 +75,15 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
             }
             layout.append(attr);
             DataType dataType = attr.dataType();
+            ElementType elementType = PlannerUtils.toElementType(dataType);
             String fieldName = attr.name();
-            List<BlockLoader> loaders = BlockReaderFactories.loaders(searchContexts, fieldName, EsqlDataTypes.isUnsupported(dataType));
-            fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, loaders));
+            boolean isUnsupported = EsqlDataTypes.isUnsupported(dataType);
+            IntFunction<BlockLoader> loader = s -> BlockReaderFactories.loader(
+                searchContexts.get(s).getSearchExecutionContext(),
+                fieldName,
+                isUnsupported
+            );
+            fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, elementType, loader));
         }
         return source.with(new ValuesSourceReaderOperator.Factory(fields, readers, docChannel), layout.build());
     }
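
The payoff of the IntFunction<BlockLoader> indirection above is easiest to see in isolation: nothing shard-specific is resolved at plan time, and only shard indices that actually show up in incoming pages pay the setup cost. A toy, Elasticsearch-free illustration of that deferral (all names here are made up for the example):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.IntFunction;

    public class LazyLoaderDemo {
        public static void main(String[] args) {
            Map<Integer, String> resolved = new HashMap<>();
            // Stand-in for the per-shard loader lookup: the expensive setup
            // runs only when a shard index is actually requested.
            IntFunction<String> loaderFor = shardIdx ->
                resolved.computeIfAbsent(shardIdx, s -> "loader-for-shard-" + s);

            // 1000 shards may be planned, but if only shard 7 produces
            // documents, exactly one loader is ever built.
            System.out.println(loaderFor.apply(7) + "; built " + resolved.size() + " loader(s)");
        }
    }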
@@ -165,8 +172,13 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
             .toList();
         // The grouping-by values are ready, let's group on them directly.
         // Costin: why are they ready and not already exposed in the layout?
+        boolean isUnsupported = EsqlDataTypes.isUnsupported(attrSource.dataType());
         return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory(
-            BlockReaderFactories.loaders(searchContexts, attrSource.name(), EsqlDataTypes.isUnsupported(attrSource.dataType())),
+            shardIdx -> BlockReaderFactories.loader(
+                searchContexts.get(shardIdx).getSearchExecutionContext(),
+                attrSource.name(),
+                isUnsupported
+            ),
             shardContexts,
             groupElementType,
             docChannel,