Jelajahi Sumber

[ES|QL] Infer the score mode to use from the Lucene collector (#125930)

This change uses the Lucene collector to infer which score mode to use
when the topN collector is used.
Jim Ferenczi 6 bulan lalu
induk
melakukan
42b7b78a31

+ 5 - 0
docs/changelog/125930.yaml

@@ -0,0 +1,5 @@
+pr: 125930
+summary: Infer the score mode to use from the Lucene collector
+area: "ES|QL"
+type: enhancement
+issues: []

+ 1 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java

@@ -49,7 +49,7 @@ public class LuceneCountOperator extends LuceneOperator {
             int taskConcurrency,
             int limit
         ) {
-            super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+            super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
         }
 
         @Override

+ 3 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java

@@ -23,6 +23,8 @@ import java.io.IOException;
 import java.util.List;
 import java.util.function.Function;
 
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
  */
@@ -121,7 +123,7 @@ public final class LuceneMaxFactory extends LuceneOperator.Factory {
         NumberType numberType,
         int limit
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
         this.fieldName = fieldName;
         this.numberType = numberType;
     }

+ 3 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java

@@ -23,6 +23,8 @@ import java.io.IOException;
 import java.util.List;
 import java.util.function.Function;
 
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
  */
@@ -121,7 +123,7 @@ public final class LuceneMinFactory extends LuceneOperator.Factory {
         NumberType numberType,
         int limit
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
         this.fieldName = fieldName;
         this.numberType = numberType;
     }

+ 5 - 7
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

@@ -11,7 +11,6 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
@@ -84,28 +83,27 @@ public abstract class LuceneOperator extends SourceOperator {
         protected final DataPartitioning dataPartitioning;
         protected final int taskConcurrency;
         protected final int limit;
-        protected final ScoreMode scoreMode;
+        protected final boolean needsScore;
         protected final LuceneSliceQueue sliceQueue;
 
         /**
          * Build the factory.
          *
-         * @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
+         * @param needsScore Whether the score is needed.
          */
         protected Factory(
             List<? extends ShardContext> contexts,
-            Function<ShardContext, Query> queryFunction,
+            Function<ShardContext, Weight> weightFunction,
             DataPartitioning dataPartitioning,
             int taskConcurrency,
             int limit,
-            ScoreMode scoreMode
+            boolean needsScore
         ) {
             this.limit = limit;
-            this.scoreMode = scoreMode;
             this.dataPartitioning = dataPartitioning;
-            var weightFunction = weightFunction(queryFunction, scoreMode);
             this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
             this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
+            this.needsScore = needsScore;
         }
 
         public final int taskConcurrency() {

+ 14 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java

@@ -11,7 +11,6 @@ import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorable;
-import org.apache.lucene.search.ScoreMode;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
@@ -56,9 +55,16 @@ public class LuceneSourceOperator extends LuceneOperator {
             int taskConcurrency,
             int maxPageSize,
             int limit,
-            boolean scoring
+            boolean needsScore
         ) {
-            super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
+            super(
+                contexts,
+                weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
+                dataPartitioning,
+                taskConcurrency,
+                limit,
+                needsScore
+            );
             this.maxPageSize = maxPageSize;
             // TODO: use a single limiter for multiple stage execution
             this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
@@ -66,7 +72,7 @@ public class LuceneSourceOperator extends LuceneOperator {
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
+            return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
         }
 
         public int maxPageSize() {
@@ -81,8 +87,8 @@ public class LuceneSourceOperator extends LuceneOperator {
                 + maxPageSize
                 + ", limit = "
                 + limit
-                + ", scoreMode = "
-                + scoreMode
+                + ", needsScore = "
+                + needsScore
                 + "]";
         }
     }
@@ -94,7 +100,7 @@ public class LuceneSourceOperator extends LuceneOperator {
         LuceneSliceQueue sliceQueue,
         int limit,
         Limiter limiter,
-        ScoreMode scoreMode
+        boolean needsScore
     ) {
         super(blockFactory, maxPageSize, sliceQueue);
         this.minPageSize = Math.max(1, maxPageSize / 2);
@@ -104,7 +110,7 @@ public class LuceneSourceOperator extends LuceneOperator {
         boolean success = false;
         try {
             this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
-            if (scoreMode.needsScores()) {
+            if (needsScore) {
                 scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
                 this.leafCollector = new ScoringCollector();
             } else {

+ 54 - 41
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

@@ -14,12 +14,12 @@ import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocsCollector;
 import org.apache.lucene.search.TopFieldCollectorManager;
 import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
@@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.sort.SortBuilder;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -43,9 +44,6 @@ import java.util.Optional;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES;
-
 /**
  * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
  */
@@ -63,16 +61,16 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
             int maxPageSize,
             int limit,
             List<SortBuilder<?>> sorts,
-            boolean scoring
+            boolean needsScore
         ) {
-            super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS);
+            super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
             this.maxPageSize = maxPageSize;
             this.sorts = sorts;
         }
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
+            return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
         }
 
         public int maxPageSize() {
@@ -88,8 +86,8 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
                 + maxPageSize
                 + ", limit = "
                 + limit
-                + ", scoreMode = "
-                + scoreMode
+                + ", needsScore = "
+                + needsScore
                 + ", sorts = ["
                 + notPrettySorts
                 + "]]";
@@ -108,7 +106,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
     private PerShardCollector perShardCollector;
     private final List<SortBuilder<?>> sorts;
     private final int limit;
-    private final ScoreMode scoreMode;
+    private final boolean needsScore;
 
     public LuceneTopNSourceOperator(
         BlockFactory blockFactory,
@@ -116,12 +114,12 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         List<SortBuilder<?>> sorts,
         int limit,
         LuceneSliceQueue sliceQueue,
-        ScoreMode scoreMode
+        boolean needsScore
     ) {
         super(blockFactory, maxPageSize, sliceQueue);
         this.sorts = sorts;
         this.limit = limit;
-        this.scoreMode = scoreMode;
+        this.needsScore = needsScore;
     }
 
     @Override
@@ -163,7 +161,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         try {
             if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
                 // TODO: share the bottom between shardCollectors
-                perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
+                perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
             }
             var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
             scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -261,7 +259,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
     }
 
     private DoubleVector.Builder scoreVectorOrNull(int size) {
-        if (scoreMode.needsScores()) {
+        if (needsScore) {
             return blockFactory.newDoubleVectorFixedBuilder(size);
         } else {
             return null;
@@ -271,37 +269,11 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
     @Override
     protected void describe(StringBuilder sb) {
         sb.append(", limit = ").append(limit);
-        sb.append(", scoreMode = ").append(scoreMode);
+        sb.append(", needsScore = ").append(needsScore);
         String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
         sb.append(", sorts = [").append(notPrettySorts).append("]");
     }
 
-    PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
-        Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
-        if (sortAndFormats.isEmpty()) {
-            throw new IllegalStateException("sorts must not be disabled in TopN");
-        }
-        if (scoreMode.needsScores() == false) {
-            return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
-        } else {
-            SortField[] sortFields = sortAndFormats.get().sort.getSort();
-            if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
-                // SORT _score DESC
-                return new ScoringPerShardCollector(shardContext, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
-            } else {
-                // SORT ..., _score, ...
-                var sort = new Sort();
-                if (sortFields != null) {
-                    var l = new ArrayList<>(Arrays.asList(sortFields));
-                    l.add(SortField.FIELD_DOC);
-                    l.add(SortField.FIELD_SCORE);
-                    sort = new Sort(l.toArray(SortField[]::new));
-                }
-                return new ScoringPerShardCollector(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
-            }
-        }
-    }
-
     abstract static class PerShardCollector {
         private final ShardContext shardContext;
         private final TopDocsCollector<?> collector;
@@ -336,4 +308,45 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
             super(shardContext, topDocsCollector);
         }
     }
+
+    private static Function<ShardContext, Weight> weightFunction(
+        Function<ShardContext, Query> queryFunction,
+        List<SortBuilder<?>> sorts,
+        boolean needsScore
+    ) {
+        return ctx -> {
+            final var query = queryFunction.apply(ctx);
+            final var searcher = ctx.searcher();
+            try {
+                // we create a collector with a limit of 1 to determine the appropriate score mode to use.
+                var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
+                return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        };
+    }
+
+    private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
+        throws IOException {
+        Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
+        if (sortAndFormats.isEmpty()) {
+            throw new IllegalStateException("sorts must not be disabled in TopN");
+        }
+        if (needsScore == false) {
+            return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
+        }
+        Sort sort = sortAndFormats.get().sort;
+        if (Sort.RELEVANCE.equals(sort)) {
+            // SORT _score DESC
+            return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
+        }
+
+        // SORT ..., _score, ...
+        var l = new ArrayList<>(Arrays.asList(sort.getSort()));
+        l.add(SortField.FIELD_DOC);
+        l.add(SortField.FIELD_SCORE);
+        sort = new Sort(l.toArray(SortField[]::new));
+        return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
+    }
 }

+ 3 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSortedSourceOperatorFactory.java

@@ -33,6 +33,8 @@ import java.io.UncheckedIOException;
 import java.util.List;
 import java.util.function.Function;
 
+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
  * <p>
@@ -56,7 +58,7 @@ public class TimeSeriesSortedSourceOperatorFactory extends LuceneOperator.Factor
         int maxPageSize,
         int limit
     ) {
-        super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
         this.maxPageSize = maxPageSize;
     }
 

+ 1 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java

@@ -120,7 +120,7 @@ public class LuceneSourceOperatorTests extends AnyOperatorTestCase {
     protected Matcher<String> expectedDescriptionOfSimple() {
         return matchesRegex(
             "LuceneSourceOperator"
-                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
+                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
         );
     }
 

+ 2 - 3
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorScoringTests.java

@@ -110,8 +110,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperat
     @Override
     protected Matcher<String> expectedToStringOfSimple() {
         return matchesRegex(
-            "LuceneTopNSourceOperator\\[shards = \\[test], "
-                + "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
+            "LuceneTopNSourceOperator\\[shards = \\[test], " + "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
         );
     }
 
@@ -119,7 +118,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperat
     protected Matcher<String> expectedDescriptionOfSimple() {
         return matchesRegex(
             "LuceneTopNSourceOperator\\[dataPartitioning = (DOC|SHARD|SEGMENT), "
-                + "maxPageSize = \\d+, limit = 100, scoreMode = TOP_DOCS_WITH_SCORES, sorts = \\[\\{.+}]]"
+                + "maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
         );
     }
 

+ 5 - 5
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java

@@ -114,19 +114,19 @@ public class LuceneTopNSourceOperatorTests extends AnyOperatorTestCase {
 
     @Override
     protected Matcher<String> expectedToStringOfSimple() {
-        var s = scoring ? "TOP_DOCS_WITH_SCORES" : "TOP_DOCS";
         return matchesRegex(
-            "LuceneTopNSourceOperator\\[shards = \\[test], maxPageSize = \\d+, limit = 100, scoreMode = " + s + ", sorts = \\[\\{.+}]]"
+            "LuceneTopNSourceOperator\\[shards = \\[test], maxPageSize = \\d+, limit = 100, needsScore = "
+                + scoring
+                + ", sorts = \\[\\{.+}]]"
         );
     }
 
     @Override
     protected Matcher<String> expectedDescriptionOfSimple() {
-        var s = scoring ? "TOP_DOCS_WITH_SCORES" : "TOP_DOCS";
         return matchesRegex(
             "LuceneTopNSourceOperator"
-                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = "
-                + s
+                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = "
+                + scoring
                 + ", sorts = \\[\\{.+}]]"
         );
     }

+ 7 - 5
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java

@@ -167,7 +167,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
                         \\_AggregationOperator[mode = INITIAL, aggs = sum of longs]
                         \\_ExchangeSinkOperator""".replace(
                         "sourceStatus",
-                        "dataPartitioning = SHARD, maxPageSize = " + pageSize() + ", limit = 2147483647, scoreMode = COMPLETE_NO_SCORES"
+                        "dataPartitioning = SHARD, maxPageSize = " + pageSize() + ", limit = 2147483647, needsScore = false"
                     )
                 )
             );
@@ -502,7 +502,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
                 [{"pause_me":{"order":"asc","missing":"_last","unmapped_type":"long"}}]""";
             String sourceStatus = "dataPartitioning = SHARD, maxPageSize = "
                 + pageSize()
-                + ", limit = 1000, scoreMode = TOP_DOCS, sorts = "
+                + ", limit = 1000, needsScore = false, sorts = "
                 + sortStatus;
             assertThat(dataTasks(tasks).get(0).description(), equalTo("""
                 \\_LuceneTopNSourceOperator[sourceStatus]
@@ -545,7 +545,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
             scriptPermits.release(pageSize() - prereleasedDocs);
             List<TaskInfo> tasks = getTasksRunning();
             assertThat(dataTasks(tasks).get(0).description(), equalTo("""
-                \\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = limit(), scoreMode = COMPLETE_NO_SCORES]
+                \\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = limit(), needsScore = false]
                 \\_ValuesSourceReaderOperator[fields = [pause_me]]
                 \\_ProjectOperator[projection = [1]]
                 \\_ExchangeSinkOperator""".replace("pageSize()", Integer.toString(pageSize())).replace("limit()", limit)));
@@ -573,8 +573,10 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
             logger.info("unblocking script");
             scriptPermits.release(pageSize());
             List<TaskInfo> tasks = getTasksRunning();
-            String sourceStatus = "dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647, scoreMode = COMPLETE_NO_SCORES"
-                .replace("pageSize()", Integer.toString(pageSize()));
+            String sourceStatus = "dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647, needsScore = false".replace(
+                "pageSize()",
+                Integer.toString(pageSize())
+            );
             assertThat(
                 dataTasks(tasks).get(0).description(),
                 equalTo(

+ 22 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java

@@ -31,6 +31,8 @@ import org.elasticsearch.index.IndexMode;
 import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
 import org.elasticsearch.index.mapper.MapperServiceTestCase;
 import org.elasticsearch.node.Node;
+import org.elasticsearch.plugins.ExtensiblePlugin;
+import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.ContextIndexSearcher;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
@@ -46,11 +48,14 @@ import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
 import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
 import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
 import org.elasticsearch.xpack.esql.session.Configuration;
+import org.elasticsearch.xpack.spatial.SpatialPlugin;
 import org.hamcrest.Matcher;
 import org.junit.After;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -58,6 +63,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 public class LocalExecutionPlannerTests extends MapperServiceTestCase {
+
     @ParametersFactory
     public static Iterable<Object[]> parameters() throws Exception {
         List<Object[]> params = new ArrayList<>();
@@ -78,6 +84,19 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
         this.estimatedRowSizeIsHuge = estimatedRowSizeIsHuge;
     }
 
+    @Override
+    protected Collection<Plugin> getPlugins() {
+        var plugin = new SpatialPlugin();
+        plugin.loadExtensions(new ExtensiblePlugin.ExtensionLoader() {
+            @Override
+            public <T> List<T> loadExtensions(Class<T> extensionPointType) {
+                return List.of();
+            }
+        });
+
+        return Collections.singletonList(plugin);
+    }
+
     @After
     public void closeIndex() throws IOException {
         IOUtils.close(reader, directory, () -> Releasables.close(releasables), releasables::clear);
@@ -251,11 +270,9 @@ public class LocalExecutionPlannerTests extends MapperServiceTestCase {
         );
         for (int i = 0; i < numShards; i++) {
             shardContexts.add(
-                new EsPhysicalOperationProviders.DefaultShardContext(
-                    i,
-                    createSearchExecutionContext(createMapperService(mapping(b -> {})), searcher),
-                    AliasFilter.EMPTY
-                )
+                new EsPhysicalOperationProviders.DefaultShardContext(i, createSearchExecutionContext(createMapperService(mapping(b -> {
+                    b.startObject("point").field("type", "geo_point").endObject();
+                })), searcher), AliasFilter.EMPTY)
             );
         }
         releasables.add(searcher);