Browse Source

ESQL: Syntax support and operator for count all (#99602)

Introduce physical plan for representing query stats
Use internal aggs when pushing down count
Add support for count all outside Lucene
Costin Leau 2 years ago
parent
commit
f883dd9856
26 changed files with 1180 additions and 263 deletions
  1. 12 3
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java
  2. 39 25
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java
  3. 163 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java
  4. 9 3
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java
  5. 155 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java
  6. 25 12
      x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
  7. 1 1
      x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java
  8. 36 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec
  9. 42 1
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java
  10. 5 1
      x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
  11. 4 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  12. 75 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java
  13. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer.java
  14. 1 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp
  15. 249 177
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java
  16. 12 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java
  17. 7 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java
  18. 14 4
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java
  19. 8 2
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java
  20. 11 6
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java
  21. 128 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java
  22. 31 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
  23. 17 5
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java
  24. 38 17
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java
  25. 91 2
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java
  26. 6 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java

+ 12 - 3
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountAggregatorFunction.java

@@ -49,6 +49,7 @@ public class CountAggregatorFunction implements AggregatorFunction {
 
     private final LongState state;
     private final List<Integer> channels;
+    private final boolean countAll;
 
     public static CountAggregatorFunction create(List<Integer> inputChannels) {
         return new CountAggregatorFunction(inputChannels, new LongState());
@@ -57,6 +58,8 @@ public class CountAggregatorFunction implements AggregatorFunction {
     private CountAggregatorFunction(List<Integer> channels, LongState state) {
         this.channels = channels;
         this.state = state;
+        // no channels specified means count-all/count(*)
+        this.countAll = channels.isEmpty();
     }
 
     @Override
@@ -64,17 +67,23 @@ public class CountAggregatorFunction implements AggregatorFunction {
         return intermediateStateDesc().size();
     }
 
+    private int blockIndex() {
+        return countAll ? 0 : channels.get(0);
+    }
+
     @Override
     public void addRawInput(Page page) {
-        Block block = page.getBlock(channels.get(0));
+        Block block = page.getBlock(blockIndex());
         LongState state = this.state;
-        state.longValue(state.longValue() + block.getTotalValueCount());
+        int count = countAll ? block.getPositionCount() : block.getTotalValueCount();
+        state.longValue(state.longValue() + count);
     }
 
     @Override
     public void addIntermediateInput(Page page) {
         assert channels.size() == intermediateBlockCount();
-        assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+        var blockIndex = blockIndex();
+        assert page.getBlockCount() >= blockIndex + intermediateStateDesc().size();
         LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
         BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();
         assert count.getPositionCount() == 1;

+ 39 - 25
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/CountGroupingAggregatorFunction.java

@@ -30,6 +30,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
 
     private final LongArrayState state;
     private final List<Integer> channels;
+    private final boolean countAll;
 
     public static CountGroupingAggregatorFunction create(BigArrays bigArrays, List<Integer> inputChannels) {
         return new CountGroupingAggregatorFunction(inputChannels, new LongArrayState(bigArrays, 0));
@@ -42,6 +43,11 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
     private CountGroupingAggregatorFunction(List<Integer> channels, LongArrayState state) {
         this.channels = channels;
         this.state = state;
+        this.countAll = channels.isEmpty();
+    }
+
+    private int blockIndex() {
+        return countAll ? 0 : channels.get(0);
     }
 
     @Override
@@ -51,33 +57,35 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
 
     @Override
     public AddInput prepareProcessPage(SeenGroupIds seenGroupIds, Page page) {
-        Block valuesBlock = page.getBlock(channels.get(0));
-        if (valuesBlock.areAllValuesNull()) {
-            state.enableGroupIdTracking(seenGroupIds);
-            return new AddInput() { // TODO return null meaning "don't collect me" and skip those
-                @Override
-                public void add(int positionOffset, IntBlock groupIds) {}
-
-                @Override
-                public void add(int positionOffset, IntVector groupIds) {}
-            };
-        }
-        Vector valuesVector = valuesBlock.asVector();
-        if (valuesVector == null) {
-            if (valuesBlock.mayHaveNulls()) {
+        Block valuesBlock = page.getBlock(blockIndex());
+        if (countAll == false) {
+            if (valuesBlock.areAllValuesNull()) {
                 state.enableGroupIdTracking(seenGroupIds);
-            }
-            return new AddInput() {
-                @Override
-                public void add(int positionOffset, IntBlock groupIds) {
-                    addRawInput(positionOffset, groupIds, valuesBlock);
-                }
+                return new AddInput() { // TODO return null meaning "don't collect me" and skip those
+                    @Override
+                    public void add(int positionOffset, IntBlock groupIds) {}
 
-                @Override
-                public void add(int positionOffset, IntVector groupIds) {
-                    addRawInput(positionOffset, groupIds, valuesBlock);
+                    @Override
+                    public void add(int positionOffset, IntVector groupIds) {}
+                };
+            }
+            Vector valuesVector = valuesBlock.asVector();
+            if (valuesVector == null) {
+                if (valuesBlock.mayHaveNulls()) {
+                    state.enableGroupIdTracking(seenGroupIds);
                 }
-            };
+                return new AddInput() {
+                    @Override
+                    public void add(int positionOffset, IntBlock groupIds) {
+                        addRawInput(positionOffset, groupIds, valuesBlock);
+                    }
+
+                    @Override
+                    public void add(int positionOffset, IntVector groupIds) {
+                        addRawInput(positionOffset, groupIds, valuesBlock);
+                    }
+                };
+            }
         }
         return new AddInput() {
             @Override
@@ -121,6 +129,9 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
         }
     }
 
+    /**
+     * This method is called for count all.
+     */
     private void addRawInput(IntVector groups) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int groupId = Math.toIntExact(groups.getInt(groupPosition));
@@ -128,6 +139,9 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
         }
     }
 
+    /**
+     * This method is called for count all.
+     */
     private void addRawInput(IntBlock groups) {
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             // TODO remove the check one we don't emit null anymore
@@ -146,7 +160,7 @@ public class CountGroupingAggregatorFunction implements GroupingAggregatorFuncti
     @Override
     public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
         assert channels.size() == intermediateBlockCount();
-        assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+        assert page.getBlockCount() >= blockIndex() + intermediateStateDesc().size();
         state.enableGroupIdTracking(new SeenGroupIds.Empty());
         LongVector count = page.<LongBlock>getBlock(channels.get(0)).asVector();
         BooleanVector seen = page.<BooleanBlock>getBlock(channels.get(1)).asVector();

+ 163 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java

@@ -0,0 +1,163 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.lucene;
+
+import org.apache.lucene.search.LeafCollector;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorable;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.SourceOperator;
+import org.elasticsearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * Source operator that incrementally counts the results in Lucene searches
+ * Returns always one entry that mimics the Count aggregation internal state:
+ * 1. the count as a long (0 if no doc is seen)
+ * 2. a bool flag (seen) that's always true meaning that the group (all items) always exists
+ */
+public class LuceneCountOperator extends LuceneOperator {
+
+    private static final int PAGE_SIZE = 1;
+
+    private int totalHits = 0;
+    private int remainingDocs;
+
+    private final LeafCollector leafCollector;
+
+    public static class Factory implements LuceneOperator.Factory {
+        private final DataPartitioning dataPartitioning;
+        private final int taskConcurrency;
+        private final int limit;
+        private final LuceneSliceQueue sliceQueue;
+
+        public Factory(
+            List<SearchContext> searchContexts,
+            Function<SearchContext, Query> queryFunction,
+            DataPartitioning dataPartitioning,
+            int taskConcurrency,
+            int limit
+        ) {
+            this.limit = limit;
+            this.dataPartitioning = dataPartitioning;
+            var weightFunction = weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES);
+            this.sliceQueue = LuceneSliceQueue.create(searchContexts, weightFunction, dataPartitioning, taskConcurrency);
+            this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
+        }
+
+        @Override
+        public SourceOperator get(DriverContext driverContext) {
+            return new LuceneCountOperator(sliceQueue, limit);
+        }
+
+        @Override
+        public int taskConcurrency() {
+            return taskConcurrency;
+        }
+
+        public int limit() {
+            return limit;
+        }
+
+        @Override
+        public String describe() {
+            return "LuceneCountOperator[dataPartitioning = " + dataPartitioning + ", limit = " + limit + "]";
+        }
+    }
+
+    public LuceneCountOperator(LuceneSliceQueue sliceQueue, int limit) {
+        super(PAGE_SIZE, sliceQueue);
+        this.remainingDocs = limit;
+        this.leafCollector = new LeafCollector() {
+            @Override
+            public void setScorer(Scorable scorer) {}
+
+            @Override
+            public void collect(int doc) {
+                if (remainingDocs > 0) {
+                    remainingDocs--;
+                    totalHits++;
+                }
+            }
+        };
+    }
+
+    @Override
+    public boolean isFinished() {
+        return doneCollecting || remainingDocs == 0;
+    }
+
+    @Override
+    public void finish() {
+        doneCollecting = true;
+    }
+
+    @Override
+    public Page getOutput() {
+        if (isFinished()) {
+            assert remainingDocs <= 0 : remainingDocs;
+            return null;
+        }
+        try {
+            final LuceneScorer scorer = getCurrentOrLoadNextScorer();
+            // no scorer means no more docs
+            if (scorer == null) {
+                remainingDocs = 0;
+            } else {
+                Weight weight = scorer.weight();
+                var leafReaderContext = scorer.leafReaderContext();
+                // see org.apache.lucene.search.TotalHitCountCollector
+                int leafCount = weight == null ? -1 : weight.count(leafReaderContext);
+                if (leafCount != -1) {
+                    // make sure to NOT multi count as the count _shortcut_ (which is segment wide)
+                    // handle doc partitioning where the same leaf can be seen multiple times
+                    // since the count is global, consider it only for the first partition and skip the rest
+                    // SHARD, SEGMENT and the first DOC_ reader in data partitioning contain the first doc (position 0)
+                    if (scorer.position() == 0) {
+                        // check to not count over the desired number of docs/limit
+                        var count = Math.min(leafCount, remainingDocs);
+                        totalHits += count;
+                        remainingDocs -= count;
+                        scorer.markAsDone();
+                    }
+                } else {
+                    // could not apply shortcut, trigger the search
+                    scorer.scoreNextRange(leafCollector, leafReaderContext.reader().getLiveDocs(), remainingDocs);
+                }
+            }
+
+            Page page = null;
+            // emit only one page
+            if (remainingDocs <= 0 && pagesEmitted == 0) {
+                pagesEmitted++;
+                page = new Page(
+                    PAGE_SIZE,
+                    LongBlock.newConstantBlockWith(totalHits, PAGE_SIZE),
+                    BooleanBlock.newConstantBlockWith(true, PAGE_SIZE)
+                );
+            }
+            return page;
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+
+    @Override
+    protected void describe(StringBuilder sb) {
+        sb.append(", remainingDocs=").append(remainingDocs);
+    }
+}

+ 9 - 3
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

@@ -59,9 +59,7 @@ public abstract class LuceneOperator extends SourceOperator {
     }
 
     @Override
-    public void close() {
-
-    }
+    public void close() {}
 
     LuceneScorer getCurrentOrLoadNextScorer() {
         while (currentScorer == null || currentScorer.isDone()) {
@@ -150,6 +148,14 @@ public abstract class LuceneOperator extends SourceOperator {
         SearchContext searchContext() {
             return searchContext;
         }
+
+        Weight weight() {
+            return weight;
+        }
+
+        int position() {
+            return position;
+        }
     }
 
     @Override

+ 155 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneCountOperatorTests.java

@@ -0,0 +1,155 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.lucene;
+
+import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.AnyOperatorTestCase;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.OperatorTestCase;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.core.IOUtils;
+import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.internal.SearchContext;
+import org.junit.After;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class LuceneCountOperatorTests extends AnyOperatorTestCase {
+    private Directory directory = newDirectory();
+    private IndexReader reader;
+
+    @After
+    public void closeIndex() throws IOException {
+        IOUtils.close(reader, directory);
+    }
+
+    @Override
+    protected LuceneCountOperator.Factory simple(BigArrays bigArrays) {
+        return simple(bigArrays, randomFrom(DataPartitioning.values()), between(1, 10_000), 100);
+    }
+
+    private LuceneCountOperator.Factory simple(BigArrays bigArrays, DataPartitioning dataPartitioning, int numDocs, int limit) {
+        int commitEvery = Math.max(1, numDocs / 10);
+        try (
+            RandomIndexWriter writer = new RandomIndexWriter(
+                random(),
+                directory,
+                newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE)
+            )
+        ) {
+            for (int d = 0; d < numDocs; d++) {
+                List<IndexableField> doc = new ArrayList<>();
+                doc.add(new SortedNumericDocValuesField("s", d));
+                writer.addDocument(doc);
+                if (d % commitEvery == 0) {
+                    writer.commit();
+                }
+            }
+            reader = writer.getReader();
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+
+        SearchContext ctx = mockSearchContext(reader);
+        SearchExecutionContext ectx = mock(SearchExecutionContext.class);
+        when(ctx.getSearchExecutionContext()).thenReturn(ectx);
+        when(ectx.getIndexReader()).thenReturn(reader);
+        Function<SearchContext, Query> queryFunction = c -> new MatchAllDocsQuery();
+        return new LuceneCountOperator.Factory(List.of(ctx), queryFunction, dataPartitioning, 1, limit);
+    }
+
+    @Override
+    protected String expectedToStringOfSimple() {
+        assumeFalse("can't support variable maxPageSize", true); // TODO allow testing this
+        return "LuceneCountOperator[shardId=0, maxPageSize=**random**]";
+    }
+
+    @Override
+    protected String expectedDescriptionOfSimple() {
+        assumeFalse("can't support variable maxPageSize", true); // TODO allow testing this
+        return """
+            LuceneCountOperator[dataPartitioning = SHARD, maxPageSize = **random**, limit = 100, sorts = [{"s":{"order":"asc"}}]]""";
+    }
+
+    // TODO tests for the other data partitioning configurations
+
+    public void testShardDataPartitioning() {
+        int size = between(1_000, 20_000);
+        int limit = between(10, size);
+        testCount(size, limit);
+    }
+
+    public void testEmpty() {
+        testCount(0, between(10, 10_000));
+    }
+
+    private void testCount(int size, int limit) {
+        DriverContext ctx = driverContext();
+        LuceneCountOperator.Factory factory = simple(nonBreakingBigArrays(), DataPartitioning.SHARD, size, limit);
+
+        List<Page> results = new ArrayList<>();
+        OperatorTestCase.runDriver(new Driver(ctx, factory.get(ctx), List.of(), new PageConsumerOperator(results::add), () -> {}));
+        OperatorTestCase.assertDriverContext(ctx);
+
+        assertThat(results, hasSize(1));
+        Page page = results.get(0);
+
+        assertThat(page.getPositionCount(), is(1));
+        assertThat(page.getBlockCount(), is(2));
+        LongBlock lb = page.getBlock(0);
+        assertThat(lb.getPositionCount(), is(1));
+        assertThat(lb.getLong(0), is((long) Math.min(size, limit)));
+        BooleanBlock bb = page.getBlock(1);
+        assertThat(bb.getBoolean(1), is(true));
+    }
+
+    /**
+     * Creates a mock search context with the given index reader.
+     * The returned mock search context can be used to test with {@link LuceneOperator}.
+     */
+    public static SearchContext mockSearchContext(IndexReader reader) {
+        try {
+            ContextIndexSearcher searcher = new ContextIndexSearcher(
+                reader,
+                IndexSearcher.getDefaultSimilarity(),
+                IndexSearcher.getDefaultQueryCache(),
+                TrivialQueryCachingPolicy.NEVER,
+                true
+            );
+            SearchContext searchContext = mock(SearchContext.class);
+            when(searchContext.searcher()).thenReturn(searcher);
+            return searchContext;
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+}

+ 25 - 12
x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java

@@ -32,6 +32,7 @@ import static org.elasticsearch.test.MapMatcher.assertMap;
 import static org.elasticsearch.test.MapMatcher.matchesMap;
 import static org.elasticsearch.xpack.esql.CsvAssert.assertData;
 import static org.elasticsearch.xpack.esql.CsvAssert.assertMetadata;
+import static org.elasticsearch.xpack.esql.CsvTestUtils.ExpectedResults;
 import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled;
 import static org.elasticsearch.xpack.esql.CsvTestUtils.loadCsvSpecValues;
 import static org.elasticsearch.xpack.esql.CsvTestsDataLoader.CSV_DATASET_MAP;
@@ -83,6 +84,10 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
         }
     }
 
+    public boolean logResults() {
+        return false;
+    }
+
     public final void test() throws Throwable {
         try {
             assumeTrue("Test " + testName + " is not enabled", isEnabled(testName));
@@ -97,21 +102,29 @@ public abstract class EsqlSpecTestCase extends ESRestTestCase {
         Map<String, Object> answer = runEsql(builder.query(testCase.query).build(), testCase.expectedWarnings);
         var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults);
 
-        assertNotNull(answer.get("columns"));
+        var metadata = answer.get("columns");
+        assertNotNull(metadata);
         @SuppressWarnings("unchecked")
-        var actualColumns = (List<Map<String, String>>) answer.get("columns");
-        assertMetadata(expectedColumnsWithValues, actualColumns, LOGGER);
+        var actualColumns = (List<Map<String, String>>) metadata;
 
-        assertNotNull(answer.get("values"));
+        Logger logger = logResults() ? LOGGER : null;
+        var values = answer.get("values");
+        assertNotNull(values);
         @SuppressWarnings("unchecked")
-        List<List<Object>> actualValues = (List<List<Object>>) answer.get("values");
-        assertData(
-            expectedColumnsWithValues,
-            actualValues,
-            testCase.ignoreOrder,
-            LOGGER,
-            value -> value == null ? "null" : value.toString()
-        );
+        List<List<Object>> actualValues = (List<List<Object>>) values;
+
+        assertResults(expectedColumnsWithValues, actualColumns, actualValues, testCase.ignoreOrder, logger);
+    }
+
+    protected void assertResults(
+        ExpectedResults expected,
+        List<Map<String, String>> actualColumns,
+        List<List<Object>> actualValues,
+        boolean ignoreOrder,
+        Logger logger
+    ) {
+        assertMetadata(expected, actualColumns, logger);
+        assertData(expected, actualValues, testCase.ignoreOrder, logger, value -> value == null ? "null" : value.toString());
     }
 
     private Throwable reworkException(Throwable th) {

+ 1 - 1
x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java

@@ -415,7 +415,7 @@ public final class CsvTestUtils {
         }
     }
 
-    static void logMetaData(List<String> actualColumnNames, List<Type> actualColumnTypes, Logger logger) {
+    public static void logMetaData(List<String> actualColumnNames, List<Type> actualColumnTypes, Logger logger) {
         // header
         StringBuilder sb = new StringBuilder();
         StringBuilder column = new StringBuilder();

+ 36 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -533,3 +533,39 @@ c:l
 ;
 
 
+
+countAllGrouped
+from employees | stats c = count(*) by languages | rename languages as l | sort l DESC ; 
+
+c:l | l:i
+10  |null
+21  |5 
+18  |4
+17  |3
+19  |2
+15  |1
+;
+
+countAllAndOtherStatGrouped
+from employees | stats c = count(*), min = min(emp_no) by languages | sort languages;
+
+c:l | min:i    | languages:i
+15  | 10005    | 1 
+19  | 10001    | 2
+17  | 10006    | 3
+18  | 10003    | 4
+21  | 10002    | 5
+10  | 10020    | null
+;
+
+countAllWithEval
+from employees | rename languages as l | stats min = min(salary) by l | eval x = min + 1 | stats ca = count(*), cx = count(x) by l | sort l; 
+
+ca:l | cx:l | l:i
+1    | 1    | 1 
+1    | 1    | 2
+1    | 1    | 3
+1    | 1    | 4
+1    | 1    | 5
+1    | 1    | null
+;

+ 42 - 1
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java

@@ -448,7 +448,48 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
         assertEquals(0.034d, (double) getValuesList(results).get(0).get(0), 0.001d);
     }
 
-    public void testFromStatsThenEval() {
+    public void testUngroupedCountAll() {
+        EsqlQueryResponse results = run("from test | stats count(*)");
+        logger.info(results);
+        Assert.assertEquals(1, results.columns().size());
+        Assert.assertEquals(1, getValuesList(results).size());
+        assertEquals("count(*)", results.columns().get(0).name());
+        assertEquals("long", results.columns().get(0).type());
+        var values = getValuesList(results).get(0);
+        assertEquals(1, values.size());
+        assertEquals(40, (long) values.get(0));
+    }
+
+    public void testUngroupedCountAllWithFilter() {
+        EsqlQueryResponse results = run("from test | where data > 1 | stats count(*)");
+        logger.info(results);
+        Assert.assertEquals(1, results.columns().size());
+        Assert.assertEquals(1, getValuesList(results).size());
+        assertEquals("count(*)", results.columns().get(0).name());
+        assertEquals("long", results.columns().get(0).type());
+        var values = getValuesList(results).get(0);
+        assertEquals(1, values.size());
+        assertEquals(20, (long) values.get(0));
+    }
+
+    @AwaitsFix(bugUrl = "tracking down a 64b(long) memory leak")
+    public void testGroupedCountAllWithFilter() {
+        EsqlQueryResponse results = run("from test | where data > 1 | stats count(*) by data | sort data");
+        logger.info(results);
+        Assert.assertEquals(2, results.columns().size());
+        Assert.assertEquals(1, getValuesList(results).size());
+        assertEquals("count(*)", results.columns().get(0).name());
+        assertEquals("long", results.columns().get(0).type());
+        assertEquals("data", results.columns().get(1).name());
+        assertEquals("long", results.columns().get(1).type());
+        var values = getValuesList(results).get(0);
+        assertEquals(2, values.size());
+        assertEquals(20, (long) values.get(0));
+        assertEquals(2L, (long) values.get(1));
+    }
+
+    public void testFromStatsEvalWithPragma() {
+        assumeTrue("pragmas only enabled on snapshot builds", Build.current().isSnapshot());
         EsqlQueryResponse results = run("from test | stats avg_count = avg(count) | eval x = avg_count + 7");
         logger.info(results);
         Assert.assertEquals(1, getValuesList(results).size());

+ 5 - 1
x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4

@@ -76,8 +76,12 @@ operatorExpression
 primaryExpression
     : constant                                                                          #constantDefault
     | qualifiedName                                                                     #dereference
+    | functionExpression                                                                #function
     | LP booleanExpression RP                                                           #parenthesizedExpression
-    | identifier LP (booleanExpression (COMMA booleanExpression)*)? RP                  #functionExpression
+    ;
+
+functionExpression
+    : identifier LP (ASTERISK | (booleanExpression (COMMA booleanExpression)*))? RP
     ;
 
 rowCommand

+ 4 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -192,6 +192,10 @@ public class EsqlFunctionRegistry extends FunctionRegistry {
 
     @Override
     protected String normalize(String name) {
+        return normalizeName(name);
+    }
+
+    public static String normalizeName(String name) {
         return name.toLowerCase(Locale.ROOT);
     }
 }

+ 75 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java

@@ -7,15 +7,19 @@
 
 package org.elasticsearch.xpack.esql.optimizer;
 
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
 import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.NotEquals;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
 import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerRules.OptimizerRule;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.Stat;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
 import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
 import org.elasticsearch.xpack.esql.plan.physical.FilterExec;
@@ -23,15 +27,19 @@ import org.elasticsearch.xpack.esql.plan.physical.LimitExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
 import org.elasticsearch.xpack.esql.plan.physical.UnaryExec;
+import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders;
 import org.elasticsearch.xpack.esql.planner.PhysicalVerificationException;
 import org.elasticsearch.xpack.esql.planner.PhysicalVerifier;
 import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery;
 import org.elasticsearch.xpack.ql.common.Failure;
+import org.elasticsearch.xpack.ql.expression.Alias;
 import org.elasticsearch.xpack.ql.expression.Attribute;
+import org.elasticsearch.xpack.ql.expression.AttributeMap;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.Expressions;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.MetadataAttribute;
+import org.elasticsearch.xpack.ql.expression.NamedExpression;
 import org.elasticsearch.xpack.ql.expression.Order;
 import org.elasticsearch.xpack.ql.expression.TypedAttribute;
 import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction;
@@ -48,6 +56,7 @@ import org.elasticsearch.xpack.ql.rule.ParameterizedRuleExecutor;
 import org.elasticsearch.xpack.ql.rule.Rule;
 import org.elasticsearch.xpack.ql.util.Queries;
 import org.elasticsearch.xpack.ql.util.Queries.Clause;
+import org.elasticsearch.xpack.ql.util.StringUtils;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -58,6 +67,9 @@ import java.util.Set;
 import java.util.function.Supplier;
 
 import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType.COUNT;
 import static org.elasticsearch.xpack.ql.expression.predicate.Predicates.splitAnd;
 import static org.elasticsearch.xpack.ql.optimizer.OptimizerRules.TransformDirection.UP;
 
@@ -90,6 +102,7 @@ public class LocalPhysicalPlanOptimizer extends ParameterizedRuleExecutor<Physic
             esSourceRules.add(new PushTopNToSource());
             esSourceRules.add(new PushLimitToSource());
             esSourceRules.add(new PushFiltersToSource());
+            esSourceRules.add(new PushStatsToSource());
         }
 
         // execute the rules multiple times to improve the chances of things being pushed down
@@ -304,6 +317,68 @@ public class LocalPhysicalPlanOptimizer extends ParameterizedRuleExecutor<Physic
         }
     }
 
+    /**
+     * Looks for the case where certain stats exist right before the query and thus can be pushed down.
+     */
+    private static class PushStatsToSource extends OptimizerRule<AggregateExec> {
+
+        @Override
+        protected PhysicalPlan rule(AggregateExec aggregateExec) {
+            PhysicalPlan plan = aggregateExec;
+            if (aggregateExec.child() instanceof EsQueryExec queryExec) {
+                var tuple = pushableStats(aggregateExec);
+
+                // TODO: handle case where some aggs cannot be pushed down by breaking the aggs into two sources (regular + stats) + union
+                // use the stats since the attributes are larger in size (due to seen)
+                if (tuple.v2().size() == aggregateExec.aggregates().size()) {
+                    plan = new EsStatsQueryExec(
+                        aggregateExec.source(),
+                        queryExec.index(),
+                        queryExec.query(),
+                        queryExec.limit(),
+                        tuple.v1(),
+                        tuple.v2()
+                    );
+                }
+            }
+            return plan;
+        }
+
+        private Tuple<List<Attribute>, List<Stat>> pushableStats(AggregateExec aggregate) {
+            AttributeMap<Stat> stats = new AttributeMap<>();
+            Tuple<List<Attribute>, List<Stat>> tuple = new Tuple<>(new ArrayList<Attribute>(), new ArrayList<Stat>());
+
+            if (aggregate.groupings().isEmpty()) {
+                for (NamedExpression agg : aggregate.aggregates()) {
+                    var attribute = agg.toAttribute();
+                    Stat stat = stats.computeIfAbsent(attribute, a -> {
+                        if (agg instanceof Alias as) {
+                            Expression child = as.child();
+                            if (child instanceof Count count) {
+                                var target = count.field();
+                                // TODO: add count over field (has to be field attribute)
+                                if (target.foldable()) {
+                                    return new Stat(StringUtils.WILDCARD, COUNT);
+                                }
+                            }
+                        }
+                        return null;
+                    });
+                    if (stat != null) {
+                        List<Attribute> intermediateAttributes = AbstractPhysicalOperationProviders.intermediateAttributes(
+                            singletonList(agg),
+                            emptyList()
+                        );
+                        tuple.v1().addAll(intermediateAttributes);
+                        tuple.v2().add(stat);
+                    }
+                }
+            }
+
+            return tuple;
+        }
+    }
+
     private static final class EsqlTranslatorHandler extends QlTranslatorHandler {
         @Override
         public Query wrapFunctionQuery(ScalarFunction sf, Expression field, Supplier<Query> querySupplier) {

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseLexer.java

@@ -127,7 +127,7 @@ public class EsqlBaseLexer extends Lexer {
   }
 
 
-  @SuppressWarnings("this-escape") public EsqlBaseLexer(CharStream input) {
+  public EsqlBaseLexer(CharStream input) {
     super(input);
     _interp = new LexerATNSimulator(this,_ATN,_decisionToDFA,_sharedContextCache);
   }

File diff suppressed because it is too large
+ 1 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.interp


File diff suppressed because it is too large
+ 249 - 177
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParser.java


+ 12 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseListener.java

@@ -252,6 +252,18 @@ public class EsqlBaseParserBaseListener implements EsqlBaseParserListener {
    * <p>The default implementation does nothing.</p>
    */
   @Override public void exitDereference(EsqlBaseParser.DereferenceContext ctx) { }
+  /**
+   * {@inheritDoc}
+   *
+   * <p>The default implementation does nothing.</p>
+   */
+  @Override public void enterFunction(EsqlBaseParser.FunctionContext ctx) { }
+  /**
+   * {@inheritDoc}
+   *
+   * <p>The default implementation does nothing.</p>
+   */
+  @Override public void exitFunction(EsqlBaseParser.FunctionContext ctx) { }
   /**
    * {@inheritDoc}
    *

+ 7 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserBaseVisitor.java

@@ -152,6 +152,13 @@ public class EsqlBaseParserBaseVisitor<T> extends AbstractParseTreeVisitor<T> im
    * {@link #visitChildren} on {@code ctx}.</p>
    */
   @Override public T visitDereference(EsqlBaseParser.DereferenceContext ctx) { return visitChildren(ctx); }
+  /**
+   * {@inheritDoc}
+   *
+   * <p>The default implementation returns the result of calling
+   * {@link #visitChildren} on {@code ctx}.</p>
+   */
+  @Override public T visitFunction(EsqlBaseParser.FunctionContext ctx) { return visitChildren(ctx); }
   /**
    * {@inheritDoc}
    *

+ 14 - 4
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserListener.java

@@ -237,6 +237,18 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
    * @param ctx the parse tree
    */
   void exitDereference(EsqlBaseParser.DereferenceContext ctx);
+  /**
+   * Enter a parse tree produced by the {@code function}
+   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * @param ctx the parse tree
+   */
+  void enterFunction(EsqlBaseParser.FunctionContext ctx);
+  /**
+   * Exit a parse tree produced by the {@code function}
+   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * @param ctx the parse tree
+   */
+  void exitFunction(EsqlBaseParser.FunctionContext ctx);
   /**
    * Enter a parse tree produced by the {@code parenthesizedExpression}
    * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
@@ -250,14 +262,12 @@ public interface EsqlBaseParserListener extends ParseTreeListener {
    */
   void exitParenthesizedExpression(EsqlBaseParser.ParenthesizedExpressionContext ctx);
   /**
-   * Enter a parse tree produced by the {@code functionExpression}
-   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * Enter a parse tree produced by {@link EsqlBaseParser#functionExpression}.
    * @param ctx the parse tree
    */
   void enterFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx);
   /**
-   * Exit a parse tree produced by the {@code functionExpression}
-   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * Exit a parse tree produced by {@link EsqlBaseParser#functionExpression}.
    * @param ctx the parse tree
    */
   void exitFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx);

+ 8 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlBaseParserVisitor.java

@@ -145,6 +145,13 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
    * @return the visitor result
    */
   T visitDereference(EsqlBaseParser.DereferenceContext ctx);
+  /**
+   * Visit a parse tree produced by the {@code function}
+   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * @param ctx the parse tree
+   * @return the visitor result
+   */
+  T visitFunction(EsqlBaseParser.FunctionContext ctx);
   /**
    * Visit a parse tree produced by the {@code parenthesizedExpression}
    * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
@@ -153,8 +160,7 @@ public interface EsqlBaseParserVisitor<T> extends ParseTreeVisitor<T> {
    */
   T visitParenthesizedExpression(EsqlBaseParser.ParenthesizedExpressionContext ctx);
   /**
-   * Visit a parse tree produced by the {@code functionExpression}
-   * labeled alternative in {@link EsqlBaseParser#primaryExpression}.
+   * Visit a parse tree produced by {@link EsqlBaseParser#functionExpression}.
    * @param ctx the parse tree
    * @return the visitor result
    */

+ 11 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.evaluator.predicate.operator.comparison.Less
 import org.elasticsearch.xpack.esql.evaluator.predicate.operator.regex.RLike;
 import org.elasticsearch.xpack.esql.evaluator.predicate.operator.regex.WildcardLike;
 import org.elasticsearch.xpack.esql.expression.Order;
+import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod;
@@ -62,6 +63,7 @@ import java.util.Map;
 import java.util.function.BiFunction;
 
 import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.DATE_PERIOD;
 import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.TIME_DURATION;
 import static org.elasticsearch.xpack.ql.parser.ParserUtils.source;
@@ -312,12 +314,15 @@ abstract class ExpressionBuilder extends IdentifierBuilder {
 
     @Override
     public Expression visitFunctionExpression(EsqlBaseParser.FunctionExpressionContext ctx) {
-        return new UnresolvedFunction(
-            source(ctx),
-            visitIdentifier(ctx.identifier()),
-            FunctionResolutionStrategy.DEFAULT,
-            ctx.booleanExpression().stream().map(this::expression).toList()
-        );
+        String name = visitIdentifier(ctx.identifier());
+        List<Expression> args = expressions(ctx.booleanExpression());
+        if ("count".equals(EsqlFunctionRegistry.normalizeName(name))) {
+            // to simplify the registration, handle in the parser the special count cases
+            if (args.isEmpty() || ctx.ASTERISK() != null) {
+                args = singletonList(new Literal(source(ctx), "*", DataTypes.KEYWORD));
+            }
+        }
+        return new UnresolvedFunction(source(ctx), name, FunctionResolutionStrategy.DEFAULT, args);
     }
 
     @Override

+ 128 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsStatsQueryExec.java

@@ -0,0 +1,128 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.plan.physical;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.xpack.ql.expression.Attribute;
+import org.elasticsearch.xpack.ql.expression.Expression;
+import org.elasticsearch.xpack.ql.index.EsIndex;
+import org.elasticsearch.xpack.ql.tree.NodeInfo;
+import org.elasticsearch.xpack.ql.tree.NodeUtils;
+import org.elasticsearch.xpack.ql.tree.Source;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Specialized query class for retrieving statistics about the underlying data and not the actual documents.
+ * For that see {@link EsQueryExec}
+ */
+public class EsStatsQueryExec extends LeafExec implements EstimatesRowSize {
+
+    public enum StatsType {
+        COUNT,
+        MIN,
+        MAX,
+        EXISTS;
+    }
+
+    public record Stat(String name, StatsType type) {};
+
+    private final EsIndex index;
+    private final QueryBuilder query;
+    private final Expression limit;
+    private final List<Attribute> attrs;
+    private final List<Stat> stats;
+
+    public EsStatsQueryExec(
+        Source source,
+        EsIndex index,
+        QueryBuilder query,
+        Expression limit,
+        List<Attribute> attributes,
+        List<Stat> stats
+    ) {
+        super(source);
+        this.index = index;
+        this.query = query;
+        this.limit = limit;
+        this.attrs = attributes;
+        this.stats = stats;
+    }
+
+    @Override
+    protected NodeInfo<EsStatsQueryExec> info() {
+        return NodeInfo.create(this, EsStatsQueryExec::new, index, query, limit, attrs, stats);
+    }
+
+    public EsIndex index() {
+        return index;
+    }
+
+    public QueryBuilder query() {
+        return query;
+    }
+
+    @Override
+    public List<Attribute> output() {
+        return attrs;
+    }
+
+    public Expression limit() {
+        return limit;
+    }
+
+    @Override
+    // TODO - get the estimation outside the plan so it doesn't touch the plan
+    public PhysicalPlan estimateRowSize(State state) {
+        int size;
+        state.add(false, attrs);
+        size = state.consumeAllFields(false);
+        return this;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(index, query, limit, attrs, stats);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+
+        EsStatsQueryExec other = (EsStatsQueryExec) obj;
+        return Objects.equals(index, other.index)
+            && Objects.equals(attrs, other.attrs)
+            && Objects.equals(query, other.query)
+            && Objects.equals(limit, other.limit)
+            && Objects.equals(stats, other.stats);
+    }
+
+    @Override
+    public String nodeString() {
+        return nodeName()
+            + "["
+            + index
+            + "], stats"
+            + stats
+            + "], query["
+            + (query != null ? Strings.toString(query, false, true) : "")
+            + "]"
+            + NodeUtils.limitedToString(attrs)
+            + ", limit["
+            + (limit != null ? limit.toString() : "")
+            + "], ";
+    }
+}

+ 31 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java

@@ -18,6 +18,7 @@ import org.elasticsearch.compute.operator.HashAggregationOperator;
 import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory;
 import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
 import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;
@@ -35,7 +36,9 @@ import java.util.List;
 import java.util.Set;
 import java.util.function.Consumer;
 
-abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
+import static java.util.Collections.emptyList;
+
+public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders {
 
     private final AggregateMapper aggregateMapper = new AggregateMapper();
 
@@ -235,7 +238,30 @@ abstract class AbstractPhysicalOperationProviders implements PhysicalOperationPr
                     if (mode == AggregateExec.Mode.PARTIAL) {
                         aggMode = AggregatorMode.INITIAL;
                         // TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
-                        sourceAttr = List.of(Expressions.attribute(aggregateFunction.field()));
+                        Expression field = aggregateFunction.field();
+                        // Only count can now support literals - all the other aggs should be optimized away
+                        if (field.foldable()) {
+                            if (aggregateFunction instanceof Count count) {
+                                sourceAttr = emptyList();
+                            } else {
+                                throw new EsqlIllegalArgumentException(
+                                    "Does not support yet aggregations over constants - [{}]",
+                                    aggregateFunction.sourceText()
+                                );
+                            }
+                        } else {
+                            Attribute attr = Expressions.attribute(field);
+                            // cannot determine attribute
+                            if (attr == null) {
+                                throw new EsqlIllegalArgumentException(
+                                    "Cannot work with target field [{}] for agg [{}]",
+                                    field.sourceText(),
+                                    aggregateFunction.sourceText()
+                                );
+                            }
+                            sourceAttr = List.of(attr);
+                        }
+
                     } else if (mode == AggregateExec.Mode.FINAL) {
                         aggMode = AggregatorMode.FINAL;
                         if (grouping) {
@@ -253,7 +279,9 @@ abstract class AbstractPhysicalOperationProviders implements PhysicalOperationPr
                     }
 
                     List<Integer> inputChannels = sourceAttr.stream().map(attr -> layout.get(attr.id()).channel()).toList();
-                    assert inputChannels != null && inputChannels.size() > 0 && inputChannels.stream().allMatch(i -> i >= 0);
+                    if (inputChannels.size() > 0) {
+                        assert inputChannels.size() > 0 && inputChannels.stream().allMatch(i -> i >= 0);
+                    }
                     if (aggregateFunction instanceof ToAggregator agg) {
                         consumer.accept(new AggFunctionSupplierContext(agg.supplier(bigArrays, inputChannels), aggMode));
                     } else {

+ 17 - 5
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

@@ -20,6 +20,8 @@ import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
 import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
 import org.elasticsearch.index.mapper.NestedLookup;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.search.NestedHelper;
 import org.elasticsearch.logging.LogManager;
@@ -54,6 +56,10 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
         this.searchContexts = searchContexts;
     }
 
+    public List<SearchContext> searchContexts() {
+        return searchContexts;
+    }
+
     @Override
     public final PhysicalOperation fieldExtractPhysicalOperation(FieldExtractExec fieldExtractExec, PhysicalOperation source) {
         Layout.Builder layout = source.layout.builder();
@@ -85,12 +91,12 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
         return op;
     }
 
-    @Override
-    public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
-        final LuceneOperator.Factory luceneFactory;
-        Function<SearchContext, Query> querySupplier = searchContext -> {
+    public static Function<SearchContext, Query> querySupplier(QueryBuilder queryBuilder) {
+        final QueryBuilder qb = queryBuilder == null ? QueryBuilders.matchAllQuery() : queryBuilder;
+
+        return searchContext -> {
             SearchExecutionContext ctx = searchContext.getSearchExecutionContext();
-            Query query = ctx.toQuery(esQueryExec.query()).query();
+            Query query = ctx.toQuery(qb).query();
             NestedLookup nestedLookup = ctx.nestedLookup();
             if (nestedLookup != NestedLookup.EMPTY) {
                 NestedHelper nestedHelper = new NestedHelper(nestedLookup, ctx::isFieldMapped);
@@ -110,6 +116,12 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
             }
             return query;
         };
+    }
+
+    @Override
+    public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
+        Function<SearchContext, Query> querySupplier = querySupplier(esQueryExec.query());
+        final LuceneOperator.Factory luceneFactory;
 
         List<FieldSort> sorts = esQueryExec.sorts();
         List<SortBuilder<?>> fieldSorts = null;

+ 38 - 17
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.esql.planner;
 
+import org.apache.lucene.search.Query;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.iterable.Iterables;
 import org.elasticsearch.compute.Describable;
@@ -15,6 +16,8 @@ import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.lucene.DataPartitioning;
+import org.elasticsearch.compute.lucene.LuceneCountOperator;
+import org.elasticsearch.compute.lucene.LuceneOperator;
 import org.elasticsearch.compute.operator.ColumnExtractOperator;
 import org.elasticsearch.compute.operator.Driver;
 import org.elasticsearch.compute.operator.DriverContext;
@@ -43,7 +46,7 @@ import org.elasticsearch.compute.operator.topn.TopNOperator;
 import org.elasticsearch.compute.operator.topn.TopNOperator.TopNOperatorFactory;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
 import org.elasticsearch.xpack.esql.enrich.EnrichLookupOperator;
@@ -54,6 +57,7 @@ import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.DissectExec;
 import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
@@ -96,6 +100,7 @@ import java.util.function.Function;
 import java.util.stream.Stream;
 
 import static java.util.stream.Collectors.joining;
+import static org.elasticsearch.compute.lucene.LuceneOperator.NO_LIMIT;
 import static org.elasticsearch.compute.operator.LimitOperator.Factory;
 import static org.elasticsearch.compute.operator.ProjectOperator.ProjectOperatorFactory;
 
@@ -196,6 +201,8 @@ public class LocalExecutionPlanner {
         // source nodes
         else if (node instanceof EsQueryExec esQuery) {
             return planEsQueryNode(esQuery, context);
+        } else if (node instanceof EsStatsQueryExec statsQuery) {
+            return planEsStats(statsQuery, context);
         } else if (node instanceof RowExec row) {
             return planRow(row, context);
         } else if (node instanceof LocalSourceExec localSource) {
@@ -224,19 +231,33 @@ public class LocalExecutionPlanner {
         return physicalOperationProviders.groupingPhysicalOperation(aggregate, source, context);
     }
 
-    private PhysicalOperation planEsQueryNode(EsQueryExec esQuery, LocalExecutionPlannerContext context) {
-        if (esQuery.query() == null) {
-            esQuery = new EsQueryExec(
-                esQuery.source(),
-                esQuery.index(),
-                esQuery.output(),
-                new MatchAllQueryBuilder(),
-                esQuery.limit(),
-                esQuery.sorts(),
-                esQuery.estimatedRowSize()
-            );
+    private PhysicalOperation planEsQueryNode(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) {
+        return physicalOperationProviders.sourcePhysicalOperation(esQueryExec, context);
+    }
+
+    private PhysicalOperation planEsStats(EsStatsQueryExec statsQuery, LocalExecutionPlannerContext context) {
+        if (physicalOperationProviders instanceof EsPhysicalOperationProviders == false) {
+            throw new EsqlIllegalArgumentException("EsStatsQuery should only occur against a Lucene backend");
         }
-        return physicalOperationProviders.sourcePhysicalOperation(esQuery, context);
+        EsPhysicalOperationProviders esProvider = (EsPhysicalOperationProviders) physicalOperationProviders;
+
+        Function<SearchContext, Query> querySupplier = EsPhysicalOperationProviders.querySupplier(statsQuery.query());
+
+        Expression limitExp = statsQuery.limit();
+        int limit = limitExp != null ? (Integer) limitExp.fold() : NO_LIMIT;
+        final LuceneOperator.Factory luceneFactory = new LuceneCountOperator.Factory(
+            esProvider.searchContexts(),
+            querySupplier,
+            context.dataPartitioning(),
+            context.taskConcurrency(),
+            limit
+        );
+
+        Layout.Builder layout = new Layout.Builder();
+        layout.append(statsQuery.outputSet());
+        int instanceCount = Math.max(1, luceneFactory.taskConcurrency());
+        context.driverParallelism(new DriverParallelism(DriverParallelism.Type.DATA_PARALLELISM, instanceCount));
+        return PhysicalOperation.fromSource(luceneFactory, layout.build());
     }
 
     private PhysicalOperation planFieldExtractNode(LocalExecutionPlannerContext context, FieldExtractExec fieldExtractExec) {
@@ -318,11 +339,11 @@ public class LocalExecutionPlanner {
 
     private PhysicalOperation planExchangeSink(ExchangeSinkExec exchangeSink, LocalExecutionPlannerContext context) {
         Objects.requireNonNull(exchangeSinkHandler, "ExchangeSinkHandler wasn't provided");
-        PhysicalOperation source = plan(exchangeSink.child(), context);
+        var child = exchangeSink.child();
+        PhysicalOperation source = plan(child, context);
 
-        Function<Page, Page> transformer = exchangeSink.child() instanceof AggregateExec
-            ? Function.identity()
-            : alignPageToAttributes(exchangeSink.output(), source.layout);
+        boolean isAgg = child instanceof AggregateExec || child instanceof EsStatsQueryExec;
+        Function<Page, Page> transformer = isAgg ? Function.identity() : alignPageToAttributes(exchangeSink.output(), source.layout);
 
         return source.withSink(new ExchangeSinkOperatorFactory(exchangeSinkHandler::createExchangeSink, transformer), source.layout);
     }

+ 91 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java

@@ -14,12 +14,14 @@ import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.RangeQueryBuilder;
 import org.elasticsearch.index.query.RegexpQueryBuilder;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.index.query.TermsQueryBuilder;
 import org.elasticsearch.index.query.WildcardQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.analysis.Analyzer;
@@ -42,6 +44,7 @@ import org.elasticsearch.xpack.esql.plan.physical.EnrichExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.FieldSort;
 import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
 import org.elasticsearch.xpack.esql.plan.physical.EvalExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeExec;
@@ -54,6 +57,7 @@ import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.plan.physical.ProjectExec;
 import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
+import org.elasticsearch.xpack.esql.planner.FilterTests;
 import org.elasticsearch.xpack.esql.planner.Mapper;
 import org.elasticsearch.xpack.esql.planner.PhysicalVerificationException;
 import org.elasticsearch.xpack.esql.planner.PlannerUtils;
@@ -91,6 +95,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping;
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField;
 import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization;
+import static org.elasticsearch.xpack.esql.plan.physical.AggregateExec.Mode.FINAL;
 import static org.elasticsearch.xpack.ql.expression.Expressions.name;
 import static org.elasticsearch.xpack.ql.expression.Expressions.names;
 import static org.elasticsearch.xpack.ql.expression.Order.OrderDirection.ASC;
@@ -103,7 +108,7 @@ import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 
-//@TestLogging(value = "org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer:TRACE", reason = "debug")
+@TestLogging(value = "org.elasticsearch.xpack.esql.optimizer.LocalLogicalPlanOptimizer:TRACE", reason = "debug")
 public class PhysicalPlanOptimizerTests extends ESTestCase {
 
     private static final String PARAM_FORMATTING = "%1$s";
@@ -1844,7 +1849,7 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
         assertThat(limit.limit(), instanceOf(Literal.class));
         assertThat(limit.limit().fold(), equalTo(10000));
         var aggFinal = as(limit.child(), AggregateExec.class);
-        assertThat(aggFinal.getMode(), equalTo(AggregateExec.Mode.FINAL));
+        assertThat(aggFinal.getMode(), equalTo(FINAL));
         var aggPartial = as(aggFinal.child(), AggregateExec.class);
         assertThat(aggPartial.getMode(), equalTo(AggregateExec.Mode.PARTIAL));
         limit = as(aggPartial.child(), LimitExec.class);
@@ -1861,6 +1866,86 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
         assertThat(source.limit().fold(), equalTo(10));
     }
 
+    // optimized doesn't know yet how to push down count over field
+    public void testCountOneFieldWithFilter() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where salary > 1000
+            | stats c = count(salary)
+            """));
+        assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+    }
+
+    // optimized doesn't know yet how to push down count over field
+    public void testCountOneFieldWithFilterAndLimit() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where salary > 1000
+            | limit 10
+            | stats c = count(salary)
+            """));
+        assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+    }
+
+    // optimized doesn't know yet how to break down different multi count
+    public void testCountMultipleFieldsWithFilter() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where salary > 1000 and emp_no > 10010
+            | stats cs = count(salary), ce = count(emp_no)
+            """));
+        assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+    }
+
+    public void testCountAllWithFilter() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where emp_no > 10010
+            | stats c = count()
+            """));
+
+        var limit = as(plan, LimitExec.class);
+        var agg = as(limit.child(), AggregateExec.class);
+        assertThat(agg.getMode(), is(FINAL));
+        assertThat(Expressions.names(agg.aggregates()), contains("c"));
+        var exchange = as(agg.child(), ExchangeExec.class);
+        var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
+        assertThat(esStatsQuery.limit(), is(nullValue()));
+        assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
+        var expected = wrapWithSingleQuery(QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no");
+        assertThat(expected.toString(), is(esStatsQuery.query().toString()));
+    }
+
+    @AwaitsFix(bugUrl = "intermediateAgg does proper reduction but the agg itself does not - the optimizer needs to improve")
+    public void testMultiCountAllWithFilter() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where emp_no > 10010
+            | stats c = count(), call = count(*), c_literal = count(1)
+            """));
+
+        var limit = as(plan, LimitExec.class);
+        var agg = as(limit.child(), AggregateExec.class);
+        assertThat(agg.getMode(), is(FINAL));
+        assertThat(Expressions.names(agg.aggregates()), contains("c", "call", "c_literal"));
+        var exchange = as(agg.child(), ExchangeExec.class);
+        var esStatsQuery = as(exchange.child(), EsStatsQueryExec.class);
+        assertThat(esStatsQuery.limit(), is(nullValue()));
+        assertThat(Expressions.names(esStatsQuery.output()), contains("count", "seen"));
+        var expected = wrapWithSingleQuery(QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no");
+        assertThat(expected.toString(), is(esStatsQuery.query().toString()));
+    }
+
+    // optimized doesn't know yet how to break down different multi count
+    public void testCountFieldsAndAllWithFilter() {
+        var plan = optimizedPlan(physicalPlan("""
+            from test
+            | where emp_no > 10010
+            | stats c = count(), cs = count(salary), ce = count(emp_no)
+            """));
+        assertThat(plan.anyMatch(EsQueryExec.class::isInstance), is(true));
+    }
+
     private static EsQueryExec source(PhysicalPlan plan) {
         if (plan instanceof ExchangeExec exchange) {
             plan = exchange.child();
@@ -1915,4 +2000,8 @@ public class PhysicalPlanOptimizerTests extends ESTestCase {
         assertThat(sv.field(), equalTo(fieldName));
         return sv.next();
     }
+
+    private QueryBuilder wrapWithSingleQuery(QueryBuilder inner, String fieldName) {
+        return FilterTests.singleValueQuery(inner, fieldName);
+    }
 }

+ 6 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java

@@ -19,6 +19,8 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
 import org.elasticsearch.xpack.esql.plan.logical.Dissect;
 import org.elasticsearch.xpack.esql.plan.logical.Grok;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.Stat;
+import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType;
 import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
 import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
 import org.elasticsearch.xpack.esql.type.EsqlDataTypes;
@@ -97,6 +99,10 @@ public class EsqlNodeSubclassTests<T extends B, B extends Node<B>> extends NodeS
                 ),
                 IndexResolution.invalid(randomAlphaOfLength(5))
             );
+
+        } else if (argClass == Stat.class) {
+            // record field
+            return new Stat(randomRealisticUnicodeOfLength(10), randomFrom(StatsType.values()));
         } else if (argClass == Integer.class) {
             return randomInt();
         }

Some files were not shown because too many files changed in this diff