Преглед на файлове

ESQL: Delay finding field load infrastructure (#103821)

This optimizes loading fields across many, many indices by resolving the
field loading infrastructure when it's first needed rather than up
front. This speeds things up because, if you are loading from many many
shards, you often don't need to set up the field loading infrastructure
for all shards at all - often you'll just need to set it up for a couple
of the shards.
Nik Everett преди 1 година
родител
ревизия
fac60e5803
променени са 14 файла, в които са добавени 567 реда и са изтрити 266 реда
  1. 0 1
      benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/MultivalueDedupeBenchmark.java
  2. 50 25
      benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java
  3. 5 0
      docs/changelog/103821.yaml
  4. 1 0
      docs/reference/esql/functions/types/add.asciidoc
  5. 30 43
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/BlockReaderFactories.java
  6. 192 127
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java
  7. 7 6
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java
  8. 2 2
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java
  9. 1 1
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
  10. 5 4
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java
  11. 3 2
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java
  12. 241 49
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java
  13. 15 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java
  14. 15 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

+ 0 - 1
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/MultivalueDedupeBenchmark.java

@@ -45,7 +45,6 @@ import java.util.concurrent.TimeUnit;
 @State(Scope.Thread)
 @Fork(1)
 public class MultivalueDedupeBenchmark {
-    private static final BigArrays BIG_ARRAYS = BigArrays.NON_RECYCLING_INSTANCE;  // TODO real big arrays?
     private static final BlockFactory blockFactory = BlockFactory.getInstance(
         new NoopCircuitBreaker("noop"),
         BigArrays.NON_RECYCLING_INSTANCE

+ 50 - 25
benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesSourceReaderBenchmark.java

@@ -31,6 +31,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -96,8 +97,12 @@ public class ValuesSourceReaderBenchmark {
                     for (String name : ValuesSourceReaderBenchmark.class.getField("name").getAnnotationsByType(Param.class)[0].value()) {
                         benchmark.layout = layout;
                         benchmark.name = name;
-                        benchmark.setupPages();
-                        benchmark.benchmark();
+                        try {
+                            benchmark.setupPages();
+                            benchmark.benchmark();
+                        } catch (Exception e) {
+                            throw new AssertionError("error initializing [" + layout + "/" + name + "]", e);
+                        }
                     }
                 }
             } finally {
@@ -111,11 +116,11 @@ public class ValuesSourceReaderBenchmark {
     private static List<ValuesSourceReaderOperator.FieldInfo> fields(String name) {
         return switch (name) {
             case "3_stored_keywords" -> List.of(
-                new ValuesSourceReaderOperator.FieldInfo("keyword_1", List.of(blockLoader("stored_keyword_1"))),
-                new ValuesSourceReaderOperator.FieldInfo("keyword_2", List.of(blockLoader("stored_keyword_2"))),
-                new ValuesSourceReaderOperator.FieldInfo("keyword_3", List.of(blockLoader("stored_keyword_3")))
+                new ValuesSourceReaderOperator.FieldInfo("keyword_1", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_1")),
+                new ValuesSourceReaderOperator.FieldInfo("keyword_2", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_2")),
+                new ValuesSourceReaderOperator.FieldInfo("keyword_3", ElementType.BYTES_REF, shardIdx -> blockLoader("stored_keyword_3"))
             );
-            default -> List.of(new ValuesSourceReaderOperator.FieldInfo(name, List.of(blockLoader(name))));
+            default -> List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType(name), shardIdx -> blockLoader(name)));
         };
     }
 
@@ -125,29 +130,38 @@ public class ValuesSourceReaderBenchmark {
         STORED;
     }
 
-    private static BlockLoader blockLoader(String name) {
-        Where where = Where.DOC_VALUES;
-        if (name.startsWith("stored_")) {
-            name = name.substring("stored_".length());
-            where = Where.STORED;
-        } else if (name.startsWith("source_")) {
-            name = name.substring("source_".length());
-            where = Where.SOURCE;
-        }
+    private static ElementType elementType(String name) {
+        name = WhereAndBaseName.fromName(name).name;
         switch (name) {
             case "long":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.LONG);
+                return ElementType.LONG;
             case "int":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.INTEGER);
+                return ElementType.INT;
             case "double":
-                return numericBlockLoader(name, where, NumberFieldMapper.NumberType.DOUBLE);
-            case "keyword":
-                name = "keyword_1";
+                return ElementType.DOUBLE;
         }
         if (name.startsWith("keyword")) {
+            return ElementType.BYTES_REF;
+        }
+        throw new UnsupportedOperationException("no element type for [" + name + "]");
+    }
+
+    private static BlockLoader blockLoader(String name) {
+        WhereAndBaseName w = WhereAndBaseName.fromName(name);
+        switch (w.name) {
+            case "long":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.LONG);
+            case "int":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.INTEGER);
+            case "double":
+                return numericBlockLoader(w, NumberFieldMapper.NumberType.DOUBLE);
+            case "keyword":
+                w = new WhereAndBaseName(w.where, "keyword_1");
+        }
+        if (w.name.startsWith("keyword")) {
             boolean syntheticSource = false;
             FieldType ft = new FieldType(KeywordFieldMapper.Defaults.FIELD_TYPE);
-            switch (where) {
+            switch (w.where) {
                 case DOC_VALUES:
                     break;
                 case SOURCE:
@@ -161,7 +175,7 @@ public class ValuesSourceReaderBenchmark {
             }
             ft.freeze();
             return new KeywordFieldMapper.KeywordFieldType(
-                name,
+                w.name,
                 ft,
                 Lucene.KEYWORD_ANALYZER,
                 Lucene.KEYWORD_ANALYZER,
@@ -193,10 +207,21 @@ public class ValuesSourceReaderBenchmark {
         throw new IllegalArgumentException("can't read [" + name + "]");
     }
 
-    private static BlockLoader numericBlockLoader(String name, Where where, NumberFieldMapper.NumberType numberType) {
+    private record WhereAndBaseName(Where where, String name) {
+        static WhereAndBaseName fromName(String name) {
+            if (name.startsWith("stored_")) {
+                return new WhereAndBaseName(Where.STORED, name.substring("stored_".length()));
+            } else if (name.startsWith("source_")) {
+                return new WhereAndBaseName(Where.SOURCE, name.substring("source_".length()));
+            }
+            return new WhereAndBaseName(Where.DOC_VALUES, name);
+        }
+    }
+
+    private static BlockLoader numericBlockLoader(WhereAndBaseName w, NumberFieldMapper.NumberType numberType) {
         boolean stored = false;
         boolean docValues = true;
-        switch (where) {
+        switch (w.where) {
             case DOC_VALUES:
                 break;
             case SOURCE:
@@ -207,7 +232,7 @@ public class ValuesSourceReaderBenchmark {
                 throw new UnsupportedOperationException();
         }
         return new NumberFieldMapper.NumberFieldType(
-            name,
+            w.name,
             numberType,
             true,
             stored,

+ 5 - 0
docs/changelog/103821.yaml

@@ -0,0 +1,5 @@
+pr: 103821
+summary: "ESQL: Delay finding field load infrastructure"
+area: ES|QL
+type: enhancement
+issues: []

+ 1 - 0
docs/reference/esql/functions/types/add.asciidoc

@@ -2,6 +2,7 @@
 |===
 lhs | rhs | result
 date_period | date_period | date_period
+date_period | datetime | datetime
 datetime | date_period | datetime
 datetime | time_duration | datetime
 double | double | double

+ 30 - 43
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/BlockReaderFactories.java

@@ -11,11 +11,8 @@ import org.elasticsearch.common.logging.HeaderWarning;
 import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
-import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.lookup.SearchLookup;
 
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Set;
 
 /**
@@ -26,56 +23,46 @@ public final class BlockReaderFactories {
 
     /**
      * Resolves *how* ESQL loads field values.
-     * @param searchContexts a search context per search index we're loading
-     *                       field from
+     * @param ctx a search context for the index we're loading field from
      * @param fieldName the name of the field to load
      * @param asUnsupportedSource should the field be loaded as "unsupported"?
      *                            These will always have {@code null} values
      */
-    public static List<BlockLoader> loaders(List<SearchContext> searchContexts, String fieldName, boolean asUnsupportedSource) {
-        List<BlockLoader> loaders = new ArrayList<>(searchContexts.size());
-
-        for (SearchContext searchContext : searchContexts) {
-            SearchExecutionContext ctx = searchContext.getSearchExecutionContext();
-            if (asUnsupportedSource) {
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
-            }
-            MappedFieldType fieldType = ctx.getFieldType(fieldName);
-            if (fieldType == null) {
-                // the field does not exist in this context
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
+    public static BlockLoader loader(SearchExecutionContext ctx, String fieldName, boolean asUnsupportedSource) {
+        if (asUnsupportedSource) {
+            return BlockLoader.CONSTANT_NULLS;
+        }
+        MappedFieldType fieldType = ctx.getFieldType(fieldName);
+        if (fieldType == null) {
+            // the field does not exist in this context
+            return BlockLoader.CONSTANT_NULLS;
+        }
+        BlockLoader loader = fieldType.blockLoader(new MappedFieldType.BlockLoaderContext() {
+            @Override
+            public String indexName() {
+                return ctx.getFullyQualifiedIndex().getName();
             }
-            BlockLoader loader = fieldType.blockLoader(new MappedFieldType.BlockLoaderContext() {
-                @Override
-                public String indexName() {
-                    return ctx.getFullyQualifiedIndex().getName();
-                }
 
-                @Override
-                public SearchLookup lookup() {
-                    return ctx.lookup();
-                }
+            @Override
+            public SearchLookup lookup() {
+                return ctx.lookup();
+            }
 
-                @Override
-                public Set<String> sourcePaths(String name) {
-                    return ctx.sourcePath(name);
-                }
+            @Override
+            public Set<String> sourcePaths(String name) {
+                return ctx.sourcePath(name);
+            }
 
-                @Override
-                public String parentField(String field) {
-                    return ctx.parentPath(field);
-                }
-            });
-            if (loader == null) {
-                HeaderWarning.addWarning("Field [{}] cannot be retrieved, it is unsupported or not indexed; returning null", fieldName);
-                loaders.add(BlockLoader.CONSTANT_NULLS);
-                continue;
+            @Override
+            public String parentField(String field) {
+                return ctx.parentPath(field);
             }
-            loaders.add(loader);
+        });
+        if (loader == null) {
+            HeaderWarning.addWarning("Field [{}] cannot be retrieved, it is unsupported or not indexed; returning null", fieldName);
+            return BlockLoader.CONSTANT_NULLS;
         }
 
-        return loaders;
+        return loader;
     }
 }

+ 192 - 127
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperator.java

@@ -27,6 +27,7 @@ import org.elasticsearch.compute.data.SingletonOrdinalsBuilder;
 import org.elasticsearch.compute.operator.AbstractPageMappingOperator;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.Operator;
+import org.elasticsearch.core.Assertions;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.fieldvisitor.StoredFieldLoader;
@@ -43,6 +44,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.TreeMap;
+import java.util.function.IntFunction;
 import java.util.function.Supplier;
 
 /**
@@ -95,22 +97,25 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         }
     }
 
+    /**
+     * Configuration for a field to load.
+     *
+     * {@code blockLoader} maps shard index to the {@link BlockLoader}s
+     * which load the actual blocks.
+     */
+    public record FieldInfo(String name, ElementType type, IntFunction<BlockLoader> blockLoader) {}
+
     public record ShardContext(IndexReader reader, Supplier<SourceLoader> newSourceLoader) {}
 
-    private final List<FieldWork> fields;
+    private final FieldWork[] fields;
     private final List<ShardContext> shardContexts;
     private final int docChannel;
     private final BlockFactory blockFactory;
 
     private final Map<String, Integer> readersBuilt = new TreeMap<>();
 
-    /**
-     * Configuration for a field to load.
-     *
-     * {@code blockLoaders} is a list, one entry per shard, of
-     * {@link BlockLoader}s which load the actual blocks.
-     */
-    public record FieldInfo(String name, List<BlockLoader> blockLoaders) {}
+    int lastShard = -1;
+    int lastSegment = -1;
 
     /**
      * Creates a new extractor
@@ -118,7 +123,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
      * @param docChannel the channel containing the shard, leaf/segment and doc id
      */
     public ValuesSourceReaderOperator(BlockFactory blockFactory, List<FieldInfo> fields, List<ShardContext> shardContexts, int docChannel) {
-        this.fields = fields.stream().map(f -> new FieldWork(f)).toList();
+        this.fields = fields.stream().map(f -> new FieldWork(f)).toArray(FieldWork[]::new);
         this.shardContexts = shardContexts;
         this.docChannel = docChannel;
         this.blockFactory = blockFactory;
@@ -128,13 +133,21 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
     protected Page process(Page page) {
         DocVector docVector = page.<DocBlock>getBlock(docChannel).asVector();
 
-        Block[] blocks = new Block[fields.size()];
+        Block[] blocks = new Block[fields.length];
         boolean success = false;
         try {
             if (docVector.singleSegmentNonDecreasing()) {
                 loadFromSingleLeaf(blocks, docVector);
             } else {
-                loadFromManyLeaves(blocks, docVector);
+                try (LoadFromMany many = new LoadFromMany(blocks, docVector)) {
+                    many.run();
+                }
+            }
+            if (Assertions.ENABLED) {
+                for (int f = 0; f < fields.length; f++) {
+                    assert blocks[f].elementType() == ElementType.NULL || blocks[f].elementType() == fields[f].info.type
+                        : blocks[f].elementType() + " NOT IN (NULL, " + fields[f].info.type + ")";
+                }
             }
             success = true;
         } catch (IOException e) {
@@ -147,10 +160,51 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         return page.appendBlocks(blocks);
     }
 
+    private void positionFieldWork(int shard, int segment, int firstDoc) {
+        if (lastShard == shard) {
+            if (lastSegment == segment) {
+                for (FieldWork w : fields) {
+                    w.sameSegment(firstDoc);
+                }
+                return;
+            }
+            lastSegment = segment;
+            for (FieldWork w : fields) {
+                w.sameShardNewSegment();
+            }
+            return;
+        }
+        lastShard = shard;
+        lastSegment = segment;
+        for (FieldWork w : fields) {
+            w.newShard(shard);
+        }
+    }
+
+    private boolean positionFieldWorkDocGuarteedAscending(int shard, int segment) {
+        if (lastShard == shard) {
+            if (lastSegment == segment) {
+                return false;
+            }
+            lastSegment = segment;
+            for (FieldWork w : fields) {
+                w.sameShardNewSegment();
+            }
+            return true;
+        }
+        lastShard = shard;
+        lastSegment = segment;
+        for (FieldWork w : fields) {
+            w.newShard(shard);
+        }
+        return true;
+    }
+
     private void loadFromSingleLeaf(Block[] blocks, DocVector docVector) throws IOException {
         int shard = docVector.shards().getInt(0);
         int segment = docVector.segments().getInt(0);
         int firstDoc = docVector.docs().getInt(0);
+        positionFieldWork(shard, segment, firstDoc);
         IntVector docs = docVector.docs();
         BlockLoader.Docs loaderDocs = new BlockLoader.Docs() {
             @Override
@@ -164,24 +218,24 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
             }
         };
         StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
-        List<RowStrideReaderWork> rowStrideReaders = new ArrayList<>(fields.size());
+        List<RowStrideReaderWork> rowStrideReaders = new ArrayList<>(fields.length);
         ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(blockFactory, docs.getPositionCount());
+        LeafReaderContext ctx = ctx(shard, segment);
         try {
-            for (int b = 0; b < fields.size(); b++) {
-                FieldWork field = fields.get(b);
-                BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime.reader(shard, segment, firstDoc);
+            for (int f = 0; f < fields.length; f++) {
+                FieldWork field = fields[f];
+                BlockLoader.ColumnAtATimeReader columnAtATime = field.columnAtATime(ctx);
                 if (columnAtATime != null) {
-                    blocks[b] = (Block) columnAtATime.read(loaderBlockFactory, loaderDocs);
+                    blocks[f] = (Block) columnAtATime.read(loaderBlockFactory, loaderDocs);
                 } else {
-                    BlockLoader.RowStrideReader rowStride = field.rowStride.reader(shard, segment, firstDoc);
                     rowStrideReaders.add(
                         new RowStrideReaderWork(
-                            rowStride,
-                            (Block.Builder) field.info.blockLoaders.get(shard).builder(loaderBlockFactory, docs.getPositionCount()),
-                            b
+                            field.rowStride(ctx),
+                            (Block.Builder) field.loader.builder(loaderBlockFactory, docs.getPositionCount()),
+                            f
                         )
                     );
-                    storedFieldsSpec = storedFieldsSpec.merge(field.info.blockLoaders.get(shard).rowStrideStoredFieldSpec());
+                    storedFieldsSpec = storedFieldsSpec.merge(field.loader.rowStrideStoredFieldSpec());
                 }
             }
 
@@ -193,7 +247,6 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                     "found row stride readers [" + rowStrideReaders + "] without stored fields [" + storedFieldsSpec + "]"
                 );
             }
-            LeafReaderContext ctx = ctx(shard, segment);
             StoredFieldLoader storedFieldLoader;
             if (useSequentialStoredFieldsReader(docVector.docs())) {
                 storedFieldLoader = StoredFieldLoader.fromSpecSequential(storedFieldsSpec);
@@ -203,7 +256,6 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                 trackStoredFields(storedFieldsSpec, false);
             }
             BlockLoaderStoredFieldsFromLeafLoader storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
-                // TODO enable the optimization by passing non-null to docs if correct
                 storedFieldLoader.getLoader(ctx, null),
                 storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
             );
@@ -226,50 +278,91 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         }
     }
 
-    private void loadFromManyLeaves(Block[] blocks, DocVector docVector) throws IOException {
-        IntVector shards = docVector.shards();
-        IntVector segments = docVector.segments();
-        IntVector docs = docVector.docs();
-        Block.Builder[] builders = new Block.Builder[blocks.length];
-        int[] forwards = docVector.shardSegmentDocMapForwards();
-        ComputeBlockLoaderFactory loaderBlockFactory = new ComputeBlockLoaderFactory(blockFactory, docs.getPositionCount());
-        try {
-            for (int b = 0; b < fields.size(); b++) {
-                FieldWork field = fields.get(b);
-                builders[b] = builderFromFirstNonNull(loaderBlockFactory, field, docs.getPositionCount());
+    private class LoadFromMany implements Releasable {
+        private final Block[] target;
+        private final IntVector shards;
+        private final IntVector segments;
+        private final IntVector docs;
+        private final int[] forwards;
+        private final int[] backwards;
+        private final Block.Builder[] builders;
+        private final BlockLoader.RowStrideReader[] rowStride;
+
+        BlockLoaderStoredFieldsFromLeafLoader storedFields;
+
+        LoadFromMany(Block[] target, DocVector docVector) {
+            this.target = target;
+            shards = docVector.shards();
+            segments = docVector.segments();
+            docs = docVector.docs();
+            forwards = docVector.shardSegmentDocMapForwards();
+            backwards = docVector.shardSegmentDocMapBackwards();
+            builders = new Block.Builder[target.length];
+            rowStride = new BlockLoader.RowStrideReader[target.length];
+        }
+
+        void run() throws IOException {
+            for (int f = 0; f < fields.length; f++) {
+                /*
+                 * Important note: each block loader has a method to build an
+                 * optimized block loader, but we have *many* fields and some
+                 * of those block loaders may not be compatible with each other.
+                 * So! We take the least common denominator which is the loader
+                 * from the element expected element type.
+                 */
+                builders[f] = fields[f].info.type.newBlockBuilder(docs.getPositionCount(), blockFactory);
             }
-            int lastShard = -1;
-            int lastSegment = -1;
-            BlockLoaderStoredFieldsFromLeafLoader storedFields = null;
-            for (int i = 0; i < forwards.length; i++) {
-                int p = forwards[i];
-                int shard = shards.getInt(p);
-                int segment = segments.getInt(p);
-                int doc = docs.getInt(p);
-                if (shard != lastShard || segment != lastSegment) {
-                    lastShard = shard;
-                    lastSegment = segment;
-                    StoredFieldsSpec storedFieldsSpec = storedFieldsSpecForShard(shard);
-                    LeafReaderContext ctx = ctx(shard, segment);
-                    storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
-                        StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(ctx, null),
-                        storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
-                    );
-                    if (false == storedFieldsSpec.equals(StoredFieldsSpec.NO_REQUIREMENTS)) {
-                        trackStoredFields(storedFieldsSpec, false);
-                    }
+            int p = forwards[0];
+            int shard = shards.getInt(p);
+            int segment = segments.getInt(p);
+            int firstDoc = docs.getInt(p);
+            positionFieldWork(shard, segment, firstDoc);
+            LeafReaderContext ctx = ctx(shard, segment);
+            fieldsMoved(ctx, shard);
+            read(firstDoc);
+            for (int i = 1; i < forwards.length; i++) {
+                p = forwards[i];
+                shard = shards.getInt(p);
+                segment = segments.getInt(p);
+                boolean changedSegment = positionFieldWorkDocGuarteedAscending(shard, segment);
+                if (changedSegment) {
+                    ctx = ctx(shard, segment);
+                    fieldsMoved(ctx, shard);
                 }
-                storedFields.advanceTo(doc);
-                for (int r = 0; r < blocks.length; r++) {
-                    fields.get(r).rowStride.reader(shard, segment, doc).read(doc, storedFields, builders[r]);
+                read(docs.getInt(p));
+            }
+            for (int f = 0; f < builders.length; f++) {
+                try (Block orig = builders[f].build()) {
+                    target[f] = orig.filter(backwards);
                 }
             }
-            for (int r = 0; r < blocks.length; r++) {
-                try (Block orig = builders[r].build()) {
-                    blocks[r] = orig.filter(docVector.shardSegmentDocMapBackwards());
+        }
+
+        private void fieldsMoved(LeafReaderContext ctx, int shard) throws IOException {
+            StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
+            for (int f = 0; f < fields.length; f++) {
+                FieldWork field = fields[f];
+                rowStride[f] = field.rowStride(ctx);
+                storedFieldsSpec = storedFieldsSpec.merge(field.loader.rowStrideStoredFieldSpec());
+                storedFields = new BlockLoaderStoredFieldsFromLeafLoader(
+                    StoredFieldLoader.fromSpec(storedFieldsSpec).getLoader(ctx, null),
+                    storedFieldsSpec.requiresSource() ? shardContexts.get(shard).newSourceLoader.get().leaf(ctx.reader(), null) : null
+                );
+                if (false == storedFieldsSpec.equals(StoredFieldsSpec.NO_REQUIREMENTS)) {
+                    trackStoredFields(storedFieldsSpec, false);
                 }
             }
-        } finally {
+        }
+
+        private void read(int doc) throws IOException {
+            storedFields.advanceTo(doc);
+            for (int f = 0; f < builders.length; f++) {
+                rowStride[f].read(doc, storedFields, builders[f]);
+            }
+        }
+
+        @Override
+        public void close() {
             Releasables.closeExpectNoException(builders);
         }
     }
@@ -298,83 +391,55 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
         );
     }
 
-    /**
-     * Returns a builder from the first non - {@link BlockLoader#CONSTANT_NULLS} loader
-     * in the list. If they are all the null loader then returns a null builder.
-     */
-    private Block.Builder builderFromFirstNonNull(BlockLoader.BlockFactory loaderBlockFactory, FieldWork field, int positionCount) {
-        for (BlockLoader loader : field.info.blockLoaders) {
-            if (loader != BlockLoader.CONSTANT_NULLS) {
-                return (Block.Builder) loader.builder(loaderBlockFactory, positionCount);
-            }
-        }
-        // All null, just let the first one build the null block loader.
-        return (Block.Builder) field.info.blockLoaders.get(0).builder(loaderBlockFactory, positionCount);
-    }
-
-    private StoredFieldsSpec storedFieldsSpecForShard(int shard) {
-        StoredFieldsSpec storedFieldsSpec = StoredFieldsSpec.NO_REQUIREMENTS;
-        for (int b = 0; b < fields.size(); b++) {
-            FieldWork field = fields.get(b);
-            storedFieldsSpec = storedFieldsSpec.merge(field.info.blockLoaders.get(shard).rowStrideStoredFieldSpec());
-        }
-        return storedFieldsSpec;
-    }
-
     private class FieldWork {
         final FieldInfo info;
-        final GuardedReader<BlockLoader.ColumnAtATimeReader> columnAtATime = new GuardedReader<>() {
-            @Override
-            BlockLoader.ColumnAtATimeReader build(BlockLoader loader, LeafReaderContext ctx) throws IOException {
-                return loader.columnAtATimeReader(ctx);
-            }
 
-            @Override
-            String type() {
-                return "column_at_a_time";
-            }
-        };
+        BlockLoader loader;
+        BlockLoader.ColumnAtATimeReader columnAtATime;
+        BlockLoader.RowStrideReader rowStride;
 
-        final GuardedReader<BlockLoader.RowStrideReader> rowStride = new GuardedReader<>() {
-            @Override
-            BlockLoader.RowStrideReader build(BlockLoader loader, LeafReaderContext ctx) throws IOException {
-                return loader.rowStrideReader(ctx);
-            }
+        FieldWork(FieldInfo info) {
+            this.info = info;
+        }
 
-            @Override
-            String type() {
-                return "row_stride";
+        void sameSegment(int firstDoc) {
+            if (columnAtATime != null && columnAtATime.canReuse(firstDoc) == false) {
+                columnAtATime = null;
             }
-        };
+            if (rowStride != null && rowStride.canReuse(firstDoc) == false) {
+                rowStride = null;
+            }
+        }
 
-        FieldWork(FieldInfo info) {
-            this.info = info;
+        void sameShardNewSegment() {
+            columnAtATime = null;
+            rowStride = null;
         }
 
-        private abstract class GuardedReader<V extends BlockLoader.Reader> {
-            private int lastShard = -1;
-            private int lastSegment = -1;
-            V lastReader;
+        void newShard(int shard) {
+            loader = info.blockLoader.apply(shard);
+            columnAtATime = null;
+            rowStride = null;
+        }
 
-            V reader(int shard, int segment, int startingDocId) throws IOException {
-                if (lastShard == shard && lastSegment == segment) {
-                    if (lastReader == null) {
-                        return null;
-                    }
-                    if (lastReader.canReuse(startingDocId)) {
-                        return lastReader;
-                    }
-                }
-                lastShard = shard;
-                lastSegment = segment;
-                lastReader = build(info.blockLoaders.get(shard), ctx(shard, segment));
-                readersBuilt.merge(info.name + ":" + type() + ":" + lastReader, 1, (prev, one) -> prev + one);
-                return lastReader;
+        BlockLoader.ColumnAtATimeReader columnAtATime(LeafReaderContext ctx) throws IOException {
+            if (columnAtATime == null) {
+                columnAtATime = loader.columnAtATimeReader(ctx);
+                trackReader("column_at_a_time", this.columnAtATime);
             }
+            return columnAtATime;
+        }
 
-            abstract V build(BlockLoader loader, LeafReaderContext ctx) throws IOException;
+        BlockLoader.RowStrideReader rowStride(LeafReaderContext ctx) throws IOException {
+            if (rowStride == null) {
+                rowStride = loader.rowStrideReader(ctx);
+                trackReader("row_stride", this.rowStride);
+            }
+            return rowStride;
+        }
 
-            abstract String type();
+        private void trackReader(String type, BlockLoader.Reader reader) {
+            readersBuilt.merge(info.name + ":" + type + ":" + reader, 1, (prev, one) -> prev + one);
         }
     }
 
@@ -393,7 +458,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
     public String toString() {
         StringBuilder sb = new StringBuilder();
         sb.append("ValuesSourceReaderOperator[fields = [");
-        if (fields.size() < 10) {
+        if (fields.length < 10) {
             boolean first = true;
             for (FieldWork f : fields) {
                 if (first) {
@@ -404,7 +469,7 @@ public class ValuesSourceReaderOperator extends AbstractPageMappingOperator {
                 sb.append(f.info.name);
             }
         } else {
-            sb.append(fields.size()).append(" fields");
+            sb.append(fields.length).append(" fields");
         }
         return sb.append("]]").toString();
     }

+ 7 - 6
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java

@@ -42,6 +42,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.IntFunction;
 import java.util.function.Supplier;
 
 import static java.util.Objects.requireNonNull;
@@ -52,7 +53,7 @@ import static java.util.stream.Collectors.joining;
  */
 public class OrdinalsGroupingOperator implements Operator {
     public record OrdinalsGroupingOperatorFactory(
-        List<BlockLoader> blockLoaders,
+        IntFunction<BlockLoader> blockLoaders,
         List<ValuesSourceReaderOperator.ShardContext> shardContexts,
         ElementType groupingElementType,
         int docChannel,
@@ -83,7 +84,7 @@ public class OrdinalsGroupingOperator implements Operator {
         }
     }
 
-    private final List<BlockLoader> blockLoaders;
+    private final IntFunction<BlockLoader> blockLoaders;
     private final List<ValuesSourceReaderOperator.ShardContext> shardContexts;
     private final int docChannel;
     private final String groupingField;
@@ -102,7 +103,7 @@ public class OrdinalsGroupingOperator implements Operator {
     private ValuesAggregator valuesAggregator;
 
     public OrdinalsGroupingOperator(
-        List<BlockLoader> blockLoaders,
+        IntFunction<BlockLoader> blockLoaders,
         List<ValuesSourceReaderOperator.ShardContext> shardContexts,
         ElementType groupingElementType,
         int docChannel,
@@ -136,7 +137,7 @@ public class OrdinalsGroupingOperator implements Operator {
         requireNonNull(page, "page is null");
         DocVector docVector = page.<DocBlock>getBlock(docChannel).asVector();
         final int shardIndex = docVector.shards().getInt(0);
-        final var blockLoader = blockLoaders.get(shardIndex);
+        final var blockLoader = blockLoaders.apply(shardIndex);
         boolean pagePassed = false;
         try {
             if (docVector.singleSegmentNonDecreasing() && blockLoader.supportsOrdinals()) {
@@ -464,7 +465,7 @@ public class OrdinalsGroupingOperator implements Operator {
         private final HashAggregationOperator aggregator;
 
         ValuesAggregator(
-            List<BlockLoader> blockLoaders,
+            IntFunction<BlockLoader> blockLoaders,
             List<ValuesSourceReaderOperator.ShardContext> shardContexts,
             ElementType groupingElementType,
             int docChannel,
@@ -476,7 +477,7 @@ public class OrdinalsGroupingOperator implements Operator {
         ) {
             this.extractor = new ValuesSourceReaderOperator(
                 driverContext.blockFactory(),
-                List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, blockLoaders)),
+                List.of(new ValuesSourceReaderOperator.FieldInfo(groupingField, groupingElementType, blockLoaders)),
                 shardContexts,
                 docChannel
             );

+ 2 - 2
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/OperatorTests.java

@@ -228,7 +228,7 @@ public class OperatorTests extends MapperServiceTestCase {
                         }
                     },
                         new OrdinalsGroupingOperator(
-                            List.of(new KeywordFieldMapper.KeywordFieldType("g").blockLoader(null)),
+                            shardIdx -> new KeywordFieldMapper.KeywordFieldType("g").blockLoader(null),
                             List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                             ElementType.BYTES_REF,
                             0,
@@ -347,7 +347,7 @@ public class OperatorTests extends MapperServiceTestCase {
     }
 
     static LuceneOperator.Factory luceneOperatorFactory(IndexReader reader, Query query, int limit) {
-        final SearchContext searchContext = mockSearchContext(reader);
+        final SearchContext searchContext = mockSearchContext(reader, 0);
         return new LuceneSourceOperator.Factory(
             List.of(searchContext),
             ctx -> query,

+ 1 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java

@@ -83,7 +83,7 @@ public class LuceneCountOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }
 
-        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader);
+        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getIndexReader()).thenReturn(reader);
         final Query query;
         if (enableShortcut && randomBoolean()) {

+ 5 - 4
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java

@@ -18,6 +18,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.AnyOperatorTestCase;
@@ -96,7 +97,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }
 
-        SearchContext ctx = mockSearchContext(reader);
+        SearchContext ctx = mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getFieldType(anyString())).thenAnswer(inv -> {
             String name = inv.getArgument(0);
             return switch (name) {
@@ -176,7 +177,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
 
     private void testSimple(DriverContext ctx, int size, int limit) {
         LuceneSourceOperator.Factory factory = simple(ctx.bigArrays(), DataPartitioning.SHARD, size, limit);
-        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD);
+        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD, ElementType.LONG);
 
         List<Page> results = new ArrayList<>();
 
@@ -204,7 +205,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
      * Creates a mock search context with the given index reader.
      * The returned mock search context can be used to test with {@link LuceneOperator}.
      */
-    public static SearchContext mockSearchContext(IndexReader reader) {
+    public static SearchContext mockSearchContext(IndexReader reader, int shardId) {
         try {
             ContextIndexSearcher searcher = new ContextIndexSearcher(
                 reader,
@@ -218,7 +219,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
             SearchExecutionContext searchExecutionContext = mock(SearchExecutionContext.class);
             when(searchContext.getSearchExecutionContext()).thenReturn(searchExecutionContext);
             when(searchExecutionContext.getFullyQualifiedIndex()).thenReturn(new Index("test", "uid"));
-            when(searchExecutionContext.getShardId()).thenReturn(0);
+            when(searchExecutionContext.getShardId()).thenReturn(shardId);
             return searchContext;
         } catch (IOException e) {
             throw new UncheckedIOException(e);

+ 3 - 2
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java

@@ -17,6 +17,7 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.AnyOperatorTestCase;
@@ -87,7 +88,7 @@ public class LuceneTopNSourceOperatorTests extends AnyOperatorTestCase {
             throw new RuntimeException(e);
         }
 
-        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader);
+        SearchContext ctx = LuceneSourceOperatorTests.mockSearchContext(reader, 0);
         when(ctx.getSearchExecutionContext().getFieldType(anyString())).thenAnswer(inv -> {
             String name = inv.getArgument(0);
             return switch (name) {
@@ -173,7 +174,7 @@ public class LuceneTopNSourceOperatorTests extends AnyOperatorTestCase {
 
     private void testSimple(DriverContext ctx, int size, int limit) {
         LuceneTopNSourceOperator.Factory factory = simple(ctx.bigArrays(), DataPartitioning.SHARD, size, limit);
-        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD);
+        Operator.OperatorFactory readS = ValuesSourceReaderOperatorTests.factory(reader, S_FIELD, ElementType.LONG);
 
         List<Page> results = new ArrayList<>();
         OperatorTestCase.runDriver(

+ 241 - 49
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValuesSourceReaderOperatorTests.java

@@ -42,6 +42,7 @@ import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DoubleBlock;
 import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
@@ -69,12 +70,14 @@ import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.TextFieldMapper;
 import org.elasticsearch.index.mapper.TextSearchInfo;
 import org.elasticsearch.index.mapper.TsidExtractingIdFieldMapper;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.hamcrest.Matcher;
 import org.junit.After;
 
+import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -82,6 +85,9 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import static org.elasticsearch.compute.lucene.LuceneSourceOperatorTests.mockSearchContext;
@@ -129,19 +135,20 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 throw new RuntimeException(e);
             }
         }
-        return factory(reader, docValuesNumberField("long", NumberFieldMapper.NumberType.LONG));
+        return factory(reader, docValuesNumberField("long", NumberFieldMapper.NumberType.LONG), ElementType.LONG);
     }
 
-    static Operator.OperatorFactory factory(IndexReader reader, MappedFieldType ft) {
-        return factory(reader, ft.name(), ft.blockLoader(null));
+    static Operator.OperatorFactory factory(IndexReader reader, MappedFieldType ft, ElementType elementType) {
+        return factory(reader, ft.name(), elementType, ft.blockLoader(null));
     }
 
-    static Operator.OperatorFactory factory(IndexReader reader, String name, BlockLoader loader) {
-        return new ValuesSourceReaderOperator.Factory(
-            List.of(new ValuesSourceReaderOperator.FieldInfo(name, List.of(loader))),
-            List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
-            0
-        );
+    static Operator.OperatorFactory factory(IndexReader reader, String name, ElementType elementType, BlockLoader loader) {
+        return new ValuesSourceReaderOperator.Factory(List.of(new ValuesSourceReaderOperator.FieldInfo(name, elementType, shardIdx -> {
+            if (shardIdx != 0) {
+                fail("unexpected shardIdx [" + shardIdx + "]");
+            }
+            return loader;
+        })), List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)), 0);
     }
 
     @Override
@@ -160,7 +167,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             throw new RuntimeException(e);
         }
         var luceneFactory = new LuceneSourceOperator.Factory(
-            List.of(mockSearchContext(reader)),
+            List.of(mockSearchContext(reader, 0)),
             ctx -> new MatchAllDocsQuery(),
             DataPartitioning.SHARD,
             randomIntBetween(1, 10),
@@ -172,6 +179,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
 
     private void initIndex(int size, int commitEvery) throws IOException {
         keyToTags.clear();
+        reader = initIndex(directory, size, commitEvery);
+    }
+
+    private IndexReader initIndex(Directory directory, int size, int commitEvery) throws IOException {
         try (
             IndexWriter writer = new IndexWriter(
                 directory,
@@ -240,7 +251,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 }
             }
         }
-        reader = DirectoryReader.open(directory);
+        return DirectoryReader.open(directory);
     }
 
     @Override
@@ -308,12 +319,13 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         Checks checks = new Checks(Block.MvOrdering.DEDUPLICATED_AND_SORTED_ASCENDING);
         FieldCase testCase = new FieldCase(
             new KeywordFieldMapper.KeywordFieldType("kwd"),
+            ElementType.BYTES_REF,
             checks::tags,
             StatusChecks::keywordsFromDocValues
         );
         operators.add(
             new ValuesSourceReaderOperator.Factory(
-                List.of(testCase.info, fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER))),
+                List.of(testCase.info, fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT)),
                 List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                 0
             ).get(driverContext)
@@ -356,8 +368,17 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         loadSimpleAndAssert(driverContext, List.of(source), Block.MvOrdering.UNORDERED);
     }
 
-    private static ValuesSourceReaderOperator.FieldInfo fieldInfo(MappedFieldType ft) {
-        return new ValuesSourceReaderOperator.FieldInfo(ft.name(), List.of(ft.blockLoader(new MappedFieldType.BlockLoaderContext() {
+    private static ValuesSourceReaderOperator.FieldInfo fieldInfo(MappedFieldType ft, ElementType elementType) {
+        return new ValuesSourceReaderOperator.FieldInfo(ft.name(), elementType, shardIdx -> {
+            if (shardIdx != 0) {
+                fail("unexpected shardIdx [" + shardIdx + "]");
+            }
+            return ft.blockLoader(blContext());
+        });
+    }
+
+    private static MappedFieldType.BlockLoaderContext blContext() {
+        return new MappedFieldType.BlockLoaderContext() {
             @Override
             public String indexName() {
                 return "test_index";
@@ -377,7 +398,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             public String parentField(String field) {
                 return null;
             }
-        })));
+        };
     }
 
     private void loadSimpleAndAssert(DriverContext driverContext, List<Page> input, Block.MvOrdering docValuesMvOrdering) {
@@ -386,7 +407,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         List<Operator> operators = new ArrayList<>();
         operators.add(
             new ValuesSourceReaderOperator.Factory(
-                List.of(fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER))),
+                List.of(fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT)),
                 List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                 0
             ).get(driverContext)
@@ -439,13 +460,14 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
     }
 
     record FieldCase(ValuesSourceReaderOperator.FieldInfo info, CheckResults checkResults, CheckReadersWithName checkReaders) {
-        FieldCase(MappedFieldType ft, CheckResults checkResults, CheckReadersWithName checkReaders) {
-            this(fieldInfo(ft), checkResults, checkReaders);
+        FieldCase(MappedFieldType ft, ElementType elementType, CheckResults checkResults, CheckReadersWithName checkReaders) {
+            this(fieldInfo(ft, elementType), checkResults, checkReaders);
         }
 
-        FieldCase(MappedFieldType ft, CheckResults checkResults, CheckReaders checkReaders) {
+        FieldCase(MappedFieldType ft, ElementType elementType, CheckResults checkResults, CheckReaders checkReaders) {
             this(
                 ft,
+                elementType,
                 checkResults,
                 (name, forcedRowByRow, pageCount, segmentCount, readersBuilt) -> checkReaders.check(
                     forcedRowByRow,
@@ -506,11 +528,17 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         Checks checks = new Checks(docValuesMvOrdering);
         List<FieldCase> r = new ArrayList<>();
         r.add(
-            new FieldCase(docValuesNumberField("long", NumberFieldMapper.NumberType.LONG), checks::longs, StatusChecks::longsFromDocValues)
+            new FieldCase(
+                docValuesNumberField("long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
+                checks::longs,
+                StatusChecks::longsFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::mvLongsFromDocValues,
                 StatusChecks::mvLongsFromDocValues
             )
@@ -518,26 +546,39 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(sourceNumberField("source_long", NumberFieldMapper.NumberType.LONG), checks::longs, StatusChecks::longsFromSource)
+            new FieldCase(
+                sourceNumberField("source_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
+                checks::longs,
+                StatusChecks::longsFromSource
+            )
         );
         r.add(
             new FieldCase(
                 sourceNumberField("mv_source_long", NumberFieldMapper.NumberType.LONG),
+                ElementType.LONG,
                 checks::mvLongsUnordered,
                 StatusChecks::mvLongsFromSource
             )
         );
         r.add(
-            new FieldCase(docValuesNumberField("int", NumberFieldMapper.NumberType.INTEGER), checks::ints, StatusChecks::intsFromDocValues)
+            new FieldCase(
+                docValuesNumberField("int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
+                checks::ints,
+                StatusChecks::intsFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::mvIntsFromDocValues,
                 StatusChecks::mvIntsFromDocValues
             )
@@ -545,16 +586,23 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(sourceNumberField("source_int", NumberFieldMapper.NumberType.INTEGER), checks::ints, StatusChecks::intsFromSource)
+            new FieldCase(
+                sourceNumberField("source_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
+                checks::ints,
+                StatusChecks::intsFromSource
+            )
         );
         r.add(
             new FieldCase(
                 sourceNumberField("mv_source_int", NumberFieldMapper.NumberType.INTEGER),
+                ElementType.INT,
                 checks::mvIntsUnordered,
                 StatusChecks::mvIntsFromSource
             )
@@ -562,6 +610,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::shorts,
                 StatusChecks::shortsFromDocValues
             )
@@ -569,6 +618,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::mvShorts,
                 StatusChecks::mvShortsFromDocValues
             )
@@ -576,16 +626,23 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_short", NumberFieldMapper.NumberType.SHORT),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
         r.add(
-            new FieldCase(docValuesNumberField("byte", NumberFieldMapper.NumberType.BYTE), checks::bytes, StatusChecks::bytesFromDocValues)
+            new FieldCase(
+                docValuesNumberField("byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
+                checks::bytes,
+                StatusChecks::bytesFromDocValues
+            )
         );
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
                 checks::mvBytes,
                 StatusChecks::mvBytesFromDocValues
             )
@@ -593,6 +650,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_byte", NumberFieldMapper.NumberType.BYTE),
+                ElementType.INT,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
@@ -600,6 +658,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::doubles,
                 StatusChecks::doublesFromDocValues
             )
@@ -607,6 +666,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("mv_double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::mvDoubles,
                 StatusChecks::mvDoublesFromDocValues
             )
@@ -614,39 +674,106 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 docValuesNumberField("missing_double", NumberFieldMapper.NumberType.DOUBLE),
+                ElementType.DOUBLE,
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
         );
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("bool"), checks::bools, StatusChecks::boolFromDocValues));
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("mv_bool"), checks::mvBools, StatusChecks::mvBoolFromDocValues));
-        r.add(new FieldCase(new BooleanFieldMapper.BooleanFieldType("missing_bool"), checks::constantNulls, StatusChecks::constantNulls));
-        r.add(new FieldCase(new KeywordFieldMapper.KeywordFieldType("kwd"), checks::tags, StatusChecks::keywordsFromDocValues));
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("bool"),
+                ElementType.BOOLEAN,
+                checks::bools,
+                StatusChecks::boolFromDocValues
+            )
+        );
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("mv_bool"),
+                ElementType.BOOLEAN,
+                checks::mvBools,
+                StatusChecks::mvBoolFromDocValues
+            )
+        );
+        r.add(
+            new FieldCase(
+                new BooleanFieldMapper.BooleanFieldType("missing_bool"),
+                ElementType.BOOLEAN,
+                checks::constantNulls,
+                StatusChecks::constantNulls
+            )
+        );
+        r.add(
+            new FieldCase(
+                new KeywordFieldMapper.KeywordFieldType("kwd"),
+                ElementType.BYTES_REF,
+                checks::tags,
+                StatusChecks::keywordsFromDocValues
+            )
+        );
         r.add(
             new FieldCase(
                 new KeywordFieldMapper.KeywordFieldType("mv_kwd"),
+                ElementType.BYTES_REF,
                 checks::mvStringsFromDocValues,
                 StatusChecks::mvKeywordsFromDocValues
             )
         );
-        r.add(new FieldCase(new KeywordFieldMapper.KeywordFieldType("missing_kwd"), checks::constantNulls, StatusChecks::constantNulls));
-        r.add(new FieldCase(storedKeywordField("stored_kwd"), checks::strings, StatusChecks::keywordsFromStored));
-        r.add(new FieldCase(storedKeywordField("mv_stored_kwd"), checks::mvStringsUnordered, StatusChecks::mvKeywordsFromStored));
-        r.add(new FieldCase(sourceKeywordField("source_kwd"), checks::strings, StatusChecks::keywordsFromSource));
-        r.add(new FieldCase(sourceKeywordField("mv_source_kwd"), checks::mvStringsUnordered, StatusChecks::mvKeywordsFromSource));
-        r.add(new FieldCase(new TextFieldMapper.TextFieldType("source_text", false), checks::strings, StatusChecks::textFromSource));
+        r.add(
+            new FieldCase(
+                new KeywordFieldMapper.KeywordFieldType("missing_kwd"),
+                ElementType.BYTES_REF,
+                checks::constantNulls,
+                StatusChecks::constantNulls
+            )
+        );
+        r.add(new FieldCase(storedKeywordField("stored_kwd"), ElementType.BYTES_REF, checks::strings, StatusChecks::keywordsFromStored));
+        r.add(
+            new FieldCase(
+                storedKeywordField("mv_stored_kwd"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvKeywordsFromStored
+            )
+        );
+        r.add(new FieldCase(sourceKeywordField("source_kwd"), ElementType.BYTES_REF, checks::strings, StatusChecks::keywordsFromSource));
+        r.add(
+            new FieldCase(
+                sourceKeywordField("mv_source_kwd"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvKeywordsFromSource
+            )
+        );
+        r.add(
+            new FieldCase(
+                new TextFieldMapper.TextFieldType("source_text", false),
+                ElementType.BYTES_REF,
+                checks::strings,
+                StatusChecks::textFromSource
+            )
+        );
         r.add(
             new FieldCase(
                 new TextFieldMapper.TextFieldType("mv_source_text", false),
+                ElementType.BYTES_REF,
                 checks::mvStringsUnordered,
                 StatusChecks::mvTextFromSource
             )
         );
-        r.add(new FieldCase(storedTextField("stored_text"), checks::strings, StatusChecks::textFromStored));
-        r.add(new FieldCase(storedTextField("mv_stored_text"), checks::mvStringsUnordered, StatusChecks::mvTextFromStored));
+        r.add(new FieldCase(storedTextField("stored_text"), ElementType.BYTES_REF, checks::strings, StatusChecks::textFromStored));
+        r.add(
+            new FieldCase(
+                storedTextField("mv_stored_text"),
+                ElementType.BYTES_REF,
+                checks::mvStringsUnordered,
+                StatusChecks::mvTextFromStored
+            )
+        );
         r.add(
             new FieldCase(
                 textFieldWithDelegate("text_with_delegate", new KeywordFieldMapper.KeywordFieldType("kwd")),
+                ElementType.BYTES_REF,
                 checks::tags,
                 StatusChecks::textWithDelegate
             )
@@ -654,6 +781,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 textFieldWithDelegate("mv_text_with_delegate", new KeywordFieldMapper.KeywordFieldType("mv_kwd")),
+                ElementType.BYTES_REF,
                 checks::mvStringsFromDocValues,
                 StatusChecks::mvTextWithDelegate
             )
@@ -661,22 +789,27 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         r.add(
             new FieldCase(
                 textFieldWithDelegate("missing_text_with_delegate", new KeywordFieldMapper.KeywordFieldType("missing_kwd")),
+                ElementType.BYTES_REF,
                 checks::constantNulls,
                 StatusChecks::constantNullTextWithDelegate
             )
         );
-        r.add(new FieldCase(new ProvidedIdFieldMapper(() -> false).fieldType(), checks::ids, StatusChecks::id));
-        r.add(new FieldCase(TsidExtractingIdFieldMapper.INSTANCE.fieldType(), checks::ids, StatusChecks::id));
+        r.add(new FieldCase(new ProvidedIdFieldMapper(() -> false).fieldType(), ElementType.BYTES_REF, checks::ids, StatusChecks::id));
+        r.add(new FieldCase(TsidExtractingIdFieldMapper.INSTANCE.fieldType(), ElementType.BYTES_REF, checks::ids, StatusChecks::id));
         r.add(
             new FieldCase(
-                new ValuesSourceReaderOperator.FieldInfo("constant_bytes", List.of(BlockLoader.constantBytes(new BytesRef("foo")))),
+                new ValuesSourceReaderOperator.FieldInfo(
+                    "constant_bytes",
+                    ElementType.BYTES_REF,
+                    shardIdx -> BlockLoader.constantBytes(new BytesRef("foo"))
+                ),
                 checks::constantBytes,
                 StatusChecks::constantBytes
             )
         );
         r.add(
             new FieldCase(
-                new ValuesSourceReaderOperator.FieldInfo("null", List.of(BlockLoader.CONSTANT_NULLS)),
+                new ValuesSourceReaderOperator.FieldInfo("null", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS),
                 checks::constantNulls,
                 StatusChecks::constantNulls
             )
@@ -1149,7 +1282,7 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
 
         DriverContext driverContext = driverContext();
         var luceneFactory = new LuceneSourceOperator.Factory(
-            List.of(mockSearchContext(reader)),
+            List.of(mockSearchContext(reader, 0)),
             ctx -> new MatchAllDocsQuery(),
             randomFrom(DataPartitioning.values()),
             randomIntBetween(1, 10),
@@ -1161,10 +1294,10 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 driverContext,
                 luceneFactory.get(driverContext),
                 List.of(
-                    factory(reader, intFt).get(driverContext),
-                    factory(reader, longFt).get(driverContext),
-                    factory(reader, doubleFt).get(driverContext),
-                    factory(reader, kwFt).get(driverContext)
+                    factory(reader, intFt, ElementType.INT).get(driverContext),
+                    factory(reader, longFt, ElementType.LONG).get(driverContext),
+                    factory(reader, doubleFt, ElementType.DOUBLE).get(driverContext),
+                    factory(reader, kwFt, ElementType.BYTES_REF).get(driverContext)
                 ),
                 new PageConsumerOperator(page -> {
                     try {
@@ -1286,8 +1419,8 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
                 List.of(
                     new ValuesSourceReaderOperator.Factory(
                         List.of(
-                            new ValuesSourceReaderOperator.FieldInfo("null1", List.of(BlockLoader.CONSTANT_NULLS)),
-                            new ValuesSourceReaderOperator.FieldInfo("null2", List.of(BlockLoader.CONSTANT_NULLS))
+                            new ValuesSourceReaderOperator.FieldInfo("null1", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS),
+                            new ValuesSourceReaderOperator.FieldInfo("null2", ElementType.NULL, shardIdx -> BlockLoader.CONSTANT_NULLS)
                         ),
                         List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
                         0
@@ -1331,8 +1464,8 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
         assertTrue(source.get(0).<DocBlock>getBlock(0).asVector().singleSegmentNonDecreasing());
         Operator op = new ValuesSourceReaderOperator.Factory(
             List.of(
-                fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER)),
-                fieldInfo(storedTextField("stored_text"))
+                fieldInfo(docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER), ElementType.INT),
+                fieldInfo(storedTextField("stored_text"), ElementType.BYTES_REF)
             ),
             List.of(new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE)),
             0
@@ -1368,4 +1501,63 @@ public class ValuesSourceReaderOperatorTests extends OperatorTestCase {
             assertThat(op.toString(), equalTo("ValuesSourceReaderOperator[fields = [" + cases.size() + " fields]]"));
         }
     }
+
+    public void testManyShards() throws IOException {
+        int shardCount = between(2, 10);
+        int size = between(100, 1000);
+        Directory[] dirs = new Directory[shardCount];
+        IndexReader[] readers = new IndexReader[shardCount];
+        Closeable[] closeMe = new Closeable[shardCount * 2];
+        Set<Integer> seenShards = new TreeSet<>();
+        Map<Integer, Integer> keyCounts = new TreeMap<>();
+        try {
+            for (int d = 0; d < dirs.length; d++) {
+                closeMe[d * 2 + 1] = dirs[d] = newDirectory();
+                closeMe[d * 2] = readers[d] = initIndex(dirs[d], size, between(10, size * 2));
+            }
+            List<SearchContext> contexts = new ArrayList<>();
+            List<ValuesSourceReaderOperator.ShardContext> readerShardContexts = new ArrayList<>();
+            for (int s = 0; s < shardCount; s++) {
+                contexts.add(mockSearchContext(readers[s], s));
+                readerShardContexts.add(new ValuesSourceReaderOperator.ShardContext(readers[s], () -> SourceLoader.FROM_STORED_SOURCE));
+            }
+            var luceneFactory = new LuceneSourceOperator.Factory(
+                contexts,
+                ctx -> new MatchAllDocsQuery(),
+                DataPartitioning.SHARD,
+                randomIntBetween(1, 10),
+                1000,
+                LuceneOperator.NO_LIMIT
+            );
+            MappedFieldType ft = docValuesNumberField("key", NumberFieldMapper.NumberType.INTEGER);
+            var readerFactory = new ValuesSourceReaderOperator.Factory(
+                List.of(new ValuesSourceReaderOperator.FieldInfo("key", ElementType.INT, shardIdx -> {
+                    seenShards.add(shardIdx);
+                    return ft.blockLoader(blContext());
+                })),
+                readerShardContexts,
+                0
+            );
+            DriverContext driverContext = driverContext();
+            List<Page> results = drive(
+                readerFactory.get(driverContext),
+                CannedSourceOperator.collectPages(luceneFactory.get(driverContext)).iterator(),
+                driverContext
+            );
+            assertThat(seenShards, equalTo(IntStream.range(0, shardCount).boxed().collect(Collectors.toCollection(TreeSet::new))));
+            for (Page p : results) {
+                IntBlock keyBlock = p.getBlock(1);
+                IntVector keys = keyBlock.asVector();
+                for (int i = 0; i < keys.getPositionCount(); i++) {
+                    keyCounts.merge(keys.getInt(i), 1, (prev, one) -> prev + one);
+                }
+            }
+            assertThat(keyCounts.keySet(), hasSize(size));
+            for (int k = 0; k < size; k++) {
+                assertThat(keyCounts.get(k), equalTo(shardCount));
+            }
+        } finally {
+            IOUtils.close(closeMe);
+        }
+    }
 }

+ 15 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java

@@ -42,6 +42,7 @@ import org.elasticsearch.compute.operator.SourceOperator;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.shard.ShardId;
@@ -273,12 +274,23 @@ public class EnrichLookupService {
                 NamedExpression extractField = extractFields.get(i);
                 final ElementType elementType = PlannerUtils.toElementType(extractField.dataType());
                 mergingTypes[i] = elementType;
-                var loaders = BlockReaderFactories.loaders(
-                    List.of(searchContext),
+                BlockLoader loader = BlockReaderFactories.loader(
+                    searchContext.getSearchExecutionContext(),
                     extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(),
                     EsqlDataTypes.isUnsupported(extractField.dataType())
                 );
-                fields.add(new ValuesSourceReaderOperator.FieldInfo(extractField.name(), loaders));
+                fields.add(
+                    new ValuesSourceReaderOperator.FieldInfo(
+                        extractField.name(),
+                        PlannerUtils.toElementType(extractField.dataType()),
+                        shardIdx -> {
+                            if (shardIdx != 0) {
+                                throw new IllegalStateException("only one shard");
+                            }
+                            return loader;
+                        }
+                    )
+                );
             }
             intermediateOperators.add(
                 new ValuesSourceReaderOperator(

+ 15 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

@@ -43,6 +43,7 @@ import org.elasticsearch.xpack.ql.type.DataType;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Function;
+import java.util.function.IntFunction;
 
 import static org.elasticsearch.common.lucene.search.Queries.newNonNestedFilter;
 import static org.elasticsearch.compute.lucene.LuceneSourceOperator.NO_LIMIT;
@@ -74,9 +75,15 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
             }
             layout.append(attr);
             DataType dataType = attr.dataType();
+            ElementType elementType = PlannerUtils.toElementType(dataType);
             String fieldName = attr.name();
-            List<BlockLoader> loaders = BlockReaderFactories.loaders(searchContexts, fieldName, EsqlDataTypes.isUnsupported(dataType));
-            fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, loaders));
+            boolean isSupported = EsqlDataTypes.isUnsupported(dataType);
+            IntFunction<BlockLoader> loader = s -> BlockReaderFactories.loader(
+                searchContexts.get(s).getSearchExecutionContext(),
+                fieldName,
+                isSupported
+            );
+            fields.add(new ValuesSourceReaderOperator.FieldInfo(fieldName, elementType, loader));
         }
         return source.with(new ValuesSourceReaderOperator.Factory(fields, readers, docChannel), layout.build());
     }
@@ -165,8 +172,13 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
             .toList();
         // The grouping-by values are ready, let's group on them directly.
         // Costin: why are they ready and not already exposed in the layout?
+        boolean isUnsupported = EsqlDataTypes.isUnsupported(attrSource.dataType());
         return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory(
-            BlockReaderFactories.loaders(searchContexts, attrSource.name(), EsqlDataTypes.isUnsupported(attrSource.dataType())),
+            shardIdx -> BlockReaderFactories.loader(
+                searchContexts.get(shardIdx).getSearchExecutionContext(),
+                attrSource.name(),
+                isUnsupported
+            ),
             shardContexts,
             groupElementType,
             docChannel,