1
0
Iván Cea Fontenla 10 сар өмнө
parent
commit
d90b4c7a9a
35 өөрчлөгдсөн 1660 нэмэгдсэн , 322 устгасан
  1. 5 0
      docs/changelog/114317.yaml
  2. 2 2
      docs/reference/esql/functions/kibana/definition/categorize.json
  3. 2 2
      docs/reference/esql/functions/types/categorize.asciidoc
  4. 0 15
      muted-tests.yml
  5. 105 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java
  6. 24 4
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java
  7. 137 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java
  8. 77 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java
  9. 9 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java
  10. 1 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java
  11. 34 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java
  12. 1 21
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java
  13. 406 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java
  14. 1 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java
  15. 2 0
      x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java
  16. 518 8
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec
  17. 16 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-mv_sample_data.json
  18. 8 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_sample_data.csv
  19. 0 145
      x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java
  20. 4 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  21. 16 60
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java
  22. 26 12
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java
  23. 2 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java
  24. 23 8
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java
  25. 12 5
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java
  26. 31 11
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java
  27. 3 3
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  28. 2 1
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java
  29. 5 14
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java
  30. 1 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java
  31. 78 5
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
  32. 11 5
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java
  33. 61 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java
  34. 13 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java
  35. 24 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java

+ 5 - 0
docs/changelog/114317.yaml

@@ -0,0 +1,5 @@
+pr: 114317
+summary: "ESQL: CATEGORIZE as a `BlockHash`"
+area: ES|QL
+type: enhancement
+issues: []

+ 2 - 2
docs/reference/esql/functions/kibana/definition/categorize.json

@@ -14,7 +14,7 @@
         }
       ],
       "variadic" : false,
-      "returnType" : "integer"
+      "returnType" : "keyword"
     },
     {
       "params" : [
@@ -26,7 +26,7 @@
         }
       ],
       "variadic" : false,
-      "returnType" : "integer"
+      "returnType" : "keyword"
     }
   ],
   "preview" : false,

+ 2 - 2
docs/reference/esql/functions/types/categorize.asciidoc

@@ -5,6 +5,6 @@
 [%header.monospaced.styled,format=dsv,separator=|]
 |===
 field | result
-keyword | integer
-text | integer
+keyword | keyword
+text | keyword
 |===

+ 0 - 15
muted-tests.yml

@@ -193,12 +193,6 @@ tests:
 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
   method: test {p0=indices.split/40_routing_partition_size/more than 1}
   issue: https://github.com/elastic/elasticsearch/issues/113841
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
-  method: test {categorize.Categorize SYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113722
-- class: org.elasticsearch.xpack.esql.qa.mixed.MixedClusterEsqlSpecIT
-  method: test {categorize.Categorize ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/116373
 - class: org.elasticsearch.kibana.KibanaThreadPoolIT
   method: testBlockedThreadPoolsRejectUserRequests
   issue: https://github.com/elastic/elasticsearch/issues/113939
@@ -254,12 +248,6 @@ tests:
 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT
   method: test {p0=search/380_sort_segments_on_timestamp/Test that index segments are NOT sorted on timestamp field when @timestamp field is dynamically added}
   issue: https://github.com/elastic/elasticsearch/issues/116221
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {categorize.Categorize SYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113054
-- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
-  method: test {categorize.Categorize ASYNC}
-  issue: https://github.com/elastic/elasticsearch/issues/113054
 - class: org.elasticsearch.ingest.common.IngestCommonClientYamlTestSuiteIT
   method: test {yaml=ingest/310_reroute_processor/Test remove then add reroute processor with and without lazy rollover}
   issue: https://github.com/elastic/elasticsearch/issues/116158
@@ -272,9 +260,6 @@ tests:
 - class: org.elasticsearch.xpack.deprecation.DeprecationHttpIT
   method: testDeprecatedSettingsReturnWarnings
   issue: https://github.com/elastic/elasticsearch/issues/108628
-- class: org.elasticsearch.xpack.esql.ccq.MultiClusterSpecIT
-  method: test {categorize.Categorize}
-  issue: https://github.com/elastic/elasticsearch/issues/116434
 - class: org.elasticsearch.xpack.apmdata.APMYamlTestSuiteIT
   method: test {yaml=/10_apm/Test template reinstallation}
   issue: https://github.com/elastic/elasticsearch/issues/116445

+ 105 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRefBuilder;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
+import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+
+import java.io.IOException;
+
+/**
+ * Base BlockHash implementation for {@code Categorize} grouping function.
+ */
+public abstract class AbstractCategorizeBlockHash extends BlockHash {
+    // TODO: this should probably also take an emitBatchSize
+    private final int channel;
+    private final boolean outputPartial;
+    protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+    AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
+        super(blockFactory);
+        this.channel = channel;
+        this.outputPartial = outputPartial;
+        this.categorizer = new TokenListCategorizer.CloseableTokenListCategorizer(
+            new CategorizationBytesRefHash(new BytesRefHash(2048, blockFactory.bigArrays())),
+            CategorizationPartOfSpeechDictionary.getInstance(),
+            0.70f
+        );
+    }
+
+    protected int channel() {
+        return channel;
+    }
+
+    @Override
+    public Block[] getKeys() {
+        return new Block[] { outputPartial ? buildIntermediateBlock() : buildFinalBlock() };
+    }
+
+    @Override
+    public IntVector nonEmpty() {
+        return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
+    }
+
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Serializes the intermediate state into a single BytesRef block, or an empty Null block if there are no categories.
+     */
+    private Block buildIntermediateBlock() {
+        if (categorizer.getCategoryCount() == 0) {
+            return blockFactory.newConstantNullBlock(0);
+        }
+        try (BytesStreamOutput out = new BytesStreamOutput()) {
+            // TODO be more careful here.
+            out.writeVInt(categorizer.getCategoryCount());
+            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+                category.writeTo(out);
+            }
+            // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private Block buildFinalBlock() {
+        try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
+            BytesRefBuilder scratch = new BytesRefBuilder();
+            for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+                scratch.copyChars(category.getRegex());
+                result.appendBytesRef(scratch.get());
+                scratch.clear();
+            }
+            return result.build().asBlock();
+        }
+    }
+}

+ 24 - 4
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.util.BytesRefHash;
 import org.elasticsearch.common.util.Int3Hash;
 import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.common.util.LongLongHash;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
@@ -58,9 +59,7 @@ import java.util.List;
  *     leave a big gap, even if we never see {@code null}.
  * </p>
  */
-public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
-    permits BooleanBlockHash, BytesRefBlockHash, DoubleBlockHash, IntBlockHash, LongBlockHash, BytesRef2BlockHash, BytesRef3BlockHash, //
-    NullBlockHash, PackedValuesBlockHash, BytesRefLongBlockHash, LongLongBlockHash, TimeSeriesBlockHash {
+public abstract class BlockHash implements Releasable, SeenGroupIds {
 
     protected final BlockFactory blockFactory;
 
@@ -107,7 +106,15 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
     @Override
     public abstract BitArray seenGroupIds(BigArrays bigArrays);
 
-    public record GroupSpec(int channel, ElementType elementType) {}
+    /**
+     * @param isCategorize Whether this group is a CATEGORIZE() or not.
+     *                     May be changed in the future when more stateful grouping functions are added.
+     */
+    public record GroupSpec(int channel, ElementType elementType, boolean isCategorize) {
+        public GroupSpec(int channel, ElementType elementType) {
+            this(channel, elementType, false);
+        }
+    }
 
     /**
      * Creates a specialized hash table that maps one or more {@link Block}s to ids.
@@ -159,6 +166,19 @@ public abstract sealed class BlockHash implements Releasable, SeenGroupIds //
         return new PackedValuesBlockHash(groups, blockFactory, emitBatchSize);
     }
 
+    /**
+     * Builds a BlockHash for the Categorize grouping function.
+     */
+    public static BlockHash buildCategorizeBlockHash(List<GroupSpec> groups, AggregatorMode aggregatorMode, BlockFactory blockFactory) {
+        if (groups.size() != 1) {
+            throw new IllegalArgumentException("only a single CATEGORIZE group can used");
+        }
+
+        return aggregatorMode.isInputPartial()
+            ? new CategorizedIntermediateBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial())
+            : new CategorizeRawBlockHash(groups.get(0).channel, blockFactory, aggregatorMode.isOutputPartial());
+    }
+
     /**
      * Creates a specialized hash table that maps a {@link Block} of the given input element type to ids.
      */

+ 137 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.analysis.core.WhitespaceTokenizer;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.CharFilterFactory;
+import org.elasticsearch.index.analysis.CustomAnalyzer;
+import org.elasticsearch.index.analysis.TokenFilterFactory;
+import org.elasticsearch.index.analysis.TokenizerFactory;
+import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
+
+/**
+ * BlockHash implementation for {@code Categorize} grouping function.
+ * <p>
+ *     This implementation expects rows, and can't deserialize intermediate states coming from other nodes.
+ * </p>
+ */
+public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
+    private final CategorizeEvaluator evaluator;
+
+    CategorizeRawBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+        super(blockFactory, channel, outputPartial);
+        CategorizationAnalyzer analyzer = new CategorizationAnalyzer(
+            // TODO: should be the same analyzer as used in Production
+            new CustomAnalyzer(
+                TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
+                new CharFilterFactory[0],
+                new TokenFilterFactory[0]
+            ),
+            true
+        );
+        this.evaluator = new CategorizeEvaluator(analyzer, categorizer, blockFactory);
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel()))) {
+            addInput.add(0, result);
+        }
+    }
+
+    @Override
+    public void close() {
+        evaluator.close();
+    }
+
+    /**
+     * Similar implementation to an Evaluator.
+     */
+    public static final class CategorizeEvaluator implements Releasable {
+        private final CategorizationAnalyzer analyzer;
+
+        private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
+
+        private final BlockFactory blockFactory;
+
+        public CategorizeEvaluator(
+            CategorizationAnalyzer analyzer,
+            TokenListCategorizer.CloseableTokenListCategorizer categorizer,
+            BlockFactory blockFactory
+        ) {
+            this.analyzer = analyzer;
+            this.categorizer = categorizer;
+            this.blockFactory = blockFactory;
+        }
+
+        public Block eval(BytesRefBlock vBlock) {
+            BytesRefVector vVector = vBlock.asVector();
+            if (vVector == null) {
+                return eval(vBlock.getPositionCount(), vBlock);
+            }
+            IntVector vector = eval(vBlock.getPositionCount(), vVector);
+            return vector.asBlock();
+        }
+
+        public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
+            try (IntBlock.Builder result = blockFactory.newIntBlockBuilder(positionCount)) {
+                BytesRef vScratch = new BytesRef();
+                for (int p = 0; p < positionCount; p++) {
+                    if (vBlock.isNull(p)) {
+                        result.appendNull();
+                        continue;
+                    }
+                    int first = vBlock.getFirstValueIndex(p);
+                    int count = vBlock.getValueCount(p);
+                    if (count == 1) {
+                        result.appendInt(process(vBlock.getBytesRef(first, vScratch)));
+                        continue;
+                    }
+                    int end = first + count;
+                    result.beginPositionEntry();
+                    for (int i = first; i < end; i++) {
+                        result.appendInt(process(vBlock.getBytesRef(i, vScratch)));
+                    }
+                    result.endPositionEntry();
+                }
+                return result.build();
+            }
+        }
+
+        public IntVector eval(int positionCount, BytesRefVector vVector) {
+            try (IntVector.FixedBuilder result = blockFactory.newIntVectorFixedBuilder(positionCount)) {
+                BytesRef vScratch = new BytesRef();
+                for (int p = 0; p < positionCount; p++) {
+                    result.appendInt(p, process(vVector.getBytesRef(p, vScratch)));
+                }
+                return result.build();
+            }
+        }
+
+        private int process(BytesRef v) {
+            return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(analyzer, categorizer);
+        }
+    }
+}

+ 77 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * BlockHash implementation for {@code Categorize} grouping function.
+ * <p>
+ *     This implementation expects a single intermediate state in a block, as generated by {@link AbstractCategorizeBlockHash}.
+ * </p>
+ */
+public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHash {
+
+    CategorizedIntermediateBlockHash(int channel, BlockFactory blockFactory, boolean outputPartial) {
+        super(blockFactory, channel, outputPartial);
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        if (page.getPositionCount() == 0) {
+            // No categories
+            return;
+        }
+        BytesRefBlock categorizerState = page.getBlock(channel());
+        Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
+        try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
+            for (int i = 0; i < idMap.size(); i++) {
+                newIdsBuilder.appendInt(idMap.get(i));
+            }
+            try (IntBlock newIds = newIdsBuilder.build()) {
+                addInput.add(0, newIds);
+            }
+        }
+    }
+
+    /**
+     * Read intermediate state from a block.
+     *
+     * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
+     */
+    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+        Map<Integer, Integer> idMap = new HashMap<>();
+        try (StreamInput in = new BytesArray(bytes).streamInput()) {
+            int count = in.readVInt();
+            for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
+                int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
+                idMap.put(oldCategoryId, newCategoryId);
+            }
+            return idMap;
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    @Override
+    public void close() {
+        categorizer.close();
+    }
+}

+ 9 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.compute.Describable;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
 import org.elasticsearch.compute.aggregation.GroupingAggregator;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.aggregation.blockhash.BlockHash;
@@ -39,11 +40,19 @@ public class HashAggregationOperator implements Operator {
 
     public record HashAggregationOperatorFactory(
         List<BlockHash.GroupSpec> groups,
+        AggregatorMode aggregatorMode,
         List<GroupingAggregator.Factory> aggregators,
         int maxPageSize
     ) implements OperatorFactory {
         @Override
         public Operator get(DriverContext driverContext) {
+            if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
+                return new HashAggregationOperator(
+                    aggregators,
+                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory()),
+                    driverContext
+                );
+            }
             return new HashAggregationOperator(
                 aggregators,
                 () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),

+ 1 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

@@ -105,6 +105,7 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
         }
         return new HashAggregationOperator.HashAggregationOperatorFactory(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            mode,
             List.of(supplier.groupingAggregatorFactory(mode)),
             randomPageSize()
         );

+ 34 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTestCase.java

@@ -0,0 +1,34 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.data.MockBlockFactory;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public abstract class BlockHashTestCase extends ESTestCase {
+
+    final CircuitBreaker breaker = newLimitedBreaker(ByteSizeValue.ofGb(1));
+    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
+    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+
+    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
+    private static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
+        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
+        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
+        return breakerService;
+    }
+}

+ 1 - 21
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java

@@ -11,11 +11,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name;
 import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
 
 import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.common.util.MockBigArrays;
-import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BooleanBlock;
@@ -26,7 +22,6 @@ import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LongBlock;
-import org.elasticsearch.compute.data.MockBlockFactory;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.data.OrdinalBytesRefVector;
 import org.elasticsearch.compute.data.Page;
@@ -34,8 +29,6 @@ import org.elasticsearch.compute.data.TestBlockFactory;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.ReleasableIterator;
 import org.elasticsearch.core.Releasables;
-import org.elasticsearch.indices.breaker.CircuitBreakerService;
-import org.elasticsearch.test.ESTestCase;
 import org.junit.After;
 
 import java.util.ArrayList;
@@ -54,14 +47,8 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.startsWith;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
-public class BlockHashTests extends ESTestCase {
-
-    final CircuitBreaker breaker = new MockBigArrays.LimitedBreaker("esql-test-breaker", ByteSizeValue.ofGb(1));
-    final BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, mockBreakerService(breaker));
-    final MockBlockFactory blockFactory = new MockBlockFactory(breaker, bigArrays);
+public class BlockHashTests extends BlockHashTestCase {
 
     @ParametersFactory
     public static List<Object[]> params() {
@@ -1534,13 +1521,6 @@ public class BlockHashTests extends ESTestCase {
         }
     }
 
-    // A breaker service that always returns the given breaker for getBreaker(CircuitBreaker.REQUEST)
-    static CircuitBreakerService mockBreakerService(CircuitBreaker breaker) {
-        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
-        when(breakerService.getBreaker(CircuitBreaker.REQUEST)).thenReturn(breaker);
-        return breakerService;
-    }
-
     IntVector intRange(int startInclusive, int endExclusive) {
         return IntVector.range(startInclusive, endExclusive, TestBlockFactory.getNonBreakingInstance());
     }

+ 406 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

@@ -0,0 +1,406 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.BytesRefVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.LongVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.CannedSourceOperator;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
+import org.elasticsearch.compute.operator.LocalSourceOperator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.core.Releasables;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class CategorizeBlockHashTests extends BlockHashTestCase {
+
+    public void testCategorizeRaw() {
+        final Page page;
+        final int positions = 7;
+        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Disconnected"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
+            page = new Page(builder.build());
+        }
+
+        try (BlockHash hash = new CategorizeRawBlockHash(0, blockFactory, true)) {
+            hash.add(page, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    assertEquals(groupIds.getPositionCount(), positions);
+
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(0, groupIds.getInt(5));
+                    assertEquals(0, groupIds.getInt(6));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+        } finally {
+            page.releaseBlocks();
+        }
+
+        // TODO: randomize and try multiple pages.
+        // TODO: assert the state of the BlockHash after adding pages. Including the categorizer state.
+        // TODO: also test the lookup method and other stuff.
+    }
+
+    public void testCategorizeIntermediate() {
+        Page page1;
+        int positions1 = 7;
+        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
+            builder.appendBytesRef(new BytesRef("Connection error"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
+            page1 = new Page(builder.build());
+        }
+        Page page2;
+        int positions2 = 5;
+        try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions2)) {
+            builder.appendBytesRef(new BytesRef("Disconnected"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.2.0.1"));
+            builder.appendBytesRef(new BytesRef("Disconnected"));
+            builder.appendBytesRef(new BytesRef("Connected to 10.3.0.2"));
+            builder.appendBytesRef(new BytesRef("System shutdown"));
+            page2 = new Page(builder.build());
+        }
+
+        Page intermediatePage1, intermediatePage2;
+
+        // Fill intermediatePages with the intermediate state from the raw hashes
+        try (
+            BlockHash rawHash1 = new CategorizeRawBlockHash(0, blockFactory, true);
+            BlockHash rawHash2 = new CategorizeRawBlockHash(0, blockFactory, true)
+        ) {
+            rawHash1.add(page1, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    assertEquals(groupIds.getPositionCount(), positions1);
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(0, groupIds.getInt(3));
+                    assertEquals(1, groupIds.getInt(4));
+                    assertEquals(0, groupIds.getInt(5));
+                    assertEquals(0, groupIds.getInt(6));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+            intermediatePage1 = new Page(rawHash1.getKeys()[0]);
+
+            rawHash2.add(page2, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    assertEquals(groupIds.getPositionCount(), positions2);
+                    assertEquals(0, groupIds.getInt(0));
+                    assertEquals(1, groupIds.getInt(1));
+                    assertEquals(0, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+            intermediatePage2 = new Page(rawHash2.getKeys()[0]);
+        } finally {
+            page1.releaseBlocks();
+            page2.releaseBlocks();
+        }
+
+        try (BlockHash intermediateHash = new CategorizedIntermediateBlockHash(0, blockFactory, true)) {
+            intermediateHash.add(intermediatePage1, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
+                        .map(groupIds::getInt)
+                        .boxed()
+                        .collect(Collectors.toSet());
+                    assertEquals(values, Set.of(0, 1));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+
+            intermediateHash.add(intermediatePage2, new GroupingAggregatorFunction.AddInput() {
+                @Override
+                public void add(int positionOffset, IntBlock groupIds) {
+                    Set<Integer> values = IntStream.range(0, groupIds.getPositionCount())
+                        .map(groupIds::getInt)
+                        .boxed()
+                        .collect(Collectors.toSet());
+                    // The category IDs {0, 1, 2} should map to groups {0, 2, 3}, because
+                    // 0 matches an existing category (Connected to ...), and the others are new.
+                    assertEquals(values, Set.of(0, 2, 3));
+                }
+
+                @Override
+                public void add(int positionOffset, IntVector groupIds) {
+                    add(positionOffset, groupIds.asBlock());
+                }
+
+                @Override
+                public void close() {
+                    fail("hashes should not close AddInput");
+                }
+            });
+        } finally {
+            intermediatePage1.releaseBlocks();
+            intermediatePage2.releaseBlocks();
+        }
+    }
+
+    public void testCategorize_withDriver() {
+        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
+        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
+        DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+
+        LocalSourceOperator.BlockSupplier input1 = () -> {
+            try (
+                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
+                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
+            ) {
+                textsBuilder.appendBytesRef(new BytesRef("a"));
+                textsBuilder.appendBytesRef(new BytesRef("b"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye jan"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye nik"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye tom"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words hello jan"));
+                textsBuilder.appendBytesRef(new BytesRef("c"));
+                textsBuilder.appendBytesRef(new BytesRef("d"));
+                countsBuilder.appendLong(1);
+                countsBuilder.appendLong(2);
+                countsBuilder.appendLong(800);
+                countsBuilder.appendLong(80);
+                countsBuilder.appendLong(8000);
+                countsBuilder.appendLong(900);
+                countsBuilder.appendLong(30);
+                countsBuilder.appendLong(4);
+                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
+            }
+        };
+        LocalSourceOperator.BlockSupplier input2 = () -> {
+            try (
+                BytesRefVector.Builder textsBuilder = driverContext.blockFactory().newBytesRefVectorBuilder(10);
+                LongVector.Builder countsBuilder = driverContext.blockFactory().newLongVectorBuilder(10)
+            ) {
+                textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words hello nik"));
+                textsBuilder.appendBytesRef(new BytesRef("c"));
+                textsBuilder.appendBytesRef(new BytesRef("words words words goodbye chris"));
+                textsBuilder.appendBytesRef(new BytesRef("d"));
+                textsBuilder.appendBytesRef(new BytesRef("e"));
+                countsBuilder.appendLong(9);
+                countsBuilder.appendLong(90);
+                countsBuilder.appendLong(3);
+                countsBuilder.appendLong(8);
+                countsBuilder.appendLong(40);
+                countsBuilder.appendLong(5);
+                return new Block[] { textsBuilder.build().asBlock(), countsBuilder.build().asBlock() };
+            }
+        };
+
+        List<Page> intermediateOutput = new ArrayList<>();
+
+        Driver driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input1),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(makeGroupSpec()),
+                    AggregatorMode.INITIAL,
+                    List.of(
+                        new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
+                        new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
+                    ),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input2),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(makeGroupSpec()),
+                    AggregatorMode.INITIAL,
+                    List.of(
+                        new SumLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL),
+                        new MaxLongAggregatorFunctionSupplier(List.of(1)).groupingAggregatorFactory(AggregatorMode.INITIAL)
+                    ),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        List<Page> finalOutput = new ArrayList<>();
+
+        driver = new Driver(
+            driverContext,
+            new CannedSourceOperator(intermediateOutput.iterator()),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    List.of(makeGroupSpec()),
+                    AggregatorMode.FINAL,
+                    List.of(
+                        new SumLongAggregatorFunctionSupplier(List.of(1, 2)).groupingAggregatorFactory(AggregatorMode.FINAL),
+                        new MaxLongAggregatorFunctionSupplier(List.of(3, 4)).groupingAggregatorFactory(AggregatorMode.FINAL)
+                    ),
+                    16 * 1024
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(finalOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        assertThat(finalOutput, hasSize(1));
+        assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
+        BytesRefBlock outputTexts = finalOutput.get(0).getBlock(0);
+        LongBlock outputSums = finalOutput.get(0).getBlock(1);
+        LongBlock outputMaxs = finalOutput.get(0).getBlock(2);
+        assertThat(outputSums.getPositionCount(), equalTo(outputTexts.getPositionCount()));
+        assertThat(outputMaxs.getPositionCount(), equalTo(outputTexts.getPositionCount()));
+        Map<String, Long> sums = new HashMap<>();
+        Map<String, Long> maxs = new HashMap<>();
+        for (int i = 0; i < outputTexts.getPositionCount(); i++) {
+            sums.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputSums.getLong(i));
+            maxs.put(outputTexts.getBytesRef(i, new BytesRef()).utf8ToString(), outputMaxs.getLong(i));
+        }
+        assertThat(
+            sums,
+            equalTo(
+                Map.of(
+                    ".*?a.*?",
+                    1L,
+                    ".*?b.*?",
+                    2L,
+                    ".*?c.*?",
+                    33L,
+                    ".*?d.*?",
+                    44L,
+                    ".*?e.*?",
+                    5L,
+                    ".*?words.+?words.+?words.+?goodbye.*?",
+                    8888L,
+                    ".*?words.+?words.+?words.+?hello.*?",
+                    999L
+                )
+            )
+        );
+        assertThat(
+            maxs,
+            equalTo(
+                Map.of(
+                    ".*?a.*?",
+                    1L,
+                    ".*?b.*?",
+                    2L,
+                    ".*?c.*?",
+                    30L,
+                    ".*?d.*?",
+                    40L,
+                    ".*?e.*?",
+                    5L,
+                    ".*?words.+?words.+?words.+?goodbye.*?",
+                    8000L,
+                    ".*?words.+?words.+?words.+?hello.*?",
+                    900L
+                )
+            )
+        );
+        Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
+    }
+
+    private BlockHash.GroupSpec makeGroupSpec() {
+        return new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true);
+    }
+}

+ 1 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java

@@ -54,6 +54,7 @@ public class HashAggregationOperatorTests extends ForkingOperatorTestCase {
 
         return new HashAggregationOperator.HashAggregationOperatorFactory(
             List.of(new BlockHash.GroupSpec(0, ElementType.LONG)),
+            mode,
             List.of(
                 new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode),
                 new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode)

+ 2 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java

@@ -59,6 +59,7 @@ public class CsvTestsDataLoader {
     private static final TestsDataset ALERTS = new TestsDataset("alerts");
     private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs");
     private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data");
+    private static final TestsDataset MV_SAMPLE_DATA = new TestsDataset("mv_sample_data");
     private static final TestsDataset SAMPLE_DATA_STR = SAMPLE_DATA.withIndex("sample_data_str")
         .withTypeMapping(Map.of("client_ip", "keyword"));
     private static final TestsDataset SAMPLE_DATA_TS_LONG = SAMPLE_DATA.withIndex("sample_data_ts_long")
@@ -103,6 +104,7 @@ public class CsvTestsDataLoader {
         Map.entry(LANGUAGES.indexName, LANGUAGES),
         Map.entry(UL_LOGS.indexName, UL_LOGS),
         Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA),
+        Map.entry(MV_SAMPLE_DATA.indexName, MV_SAMPLE_DATA),
         Map.entry(ALERTS.indexName, ALERTS),
         Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR),
         Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG),

+ 518 - 8
x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

@@ -1,14 +1,524 @@
-categorize
-required_capability: categorize
+standard aggs
+required_capability: categorize_v2
 
 FROM sample_data
-  | SORT message ASC
-  | STATS count=COUNT(), values=MV_SORT(VALUES(message)) BY category=CATEGORIZE(message)
+  | STATS count=COUNT(),
+          sum=SUM(event_duration),
+          avg=AVG(event_duration),
+          count_distinct=COUNT_DISTINCT(event_duration)
+       BY category=CATEGORIZE(message)
+  | SORT count DESC, category
+;
+
+count:long | sum:long |     avg:double     | count_distinct:long | category:keyword
+         3 |  7971589 | 2657196.3333333335 |                   3 | .*?Connected.+?to.*?
+         3 | 14027356 | 4675785.333333333  |                   3 | .*?Connection.+?error.*?
+         1 |  1232382 | 1232382.0          |                   1 | .*?Disconnected.*?
+;
+
+values aggs
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS values=MV_SORT(VALUES(message)),
+          top=TOP(event_duration, 2, "DESC")
+       BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+values:keyword                                                        |      top:long      | category:keyword
+[Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | [3450233, 2764889] | .*?Connected.+?to.*?
+[Connection error]                                                    | [8268153, 5033755] | .*?Connection.+?error.*?
+[Disconnected]                                                        |           1232382  | .*?Disconnected.*?
+;
+
+mv
+required_capability: categorize_v2
+
+FROM mv_sample_data
+  | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | SUM(event_duration):long | category:keyword
+           7 |                 23231327 | .*?Banana.*?
+           3 |                  7971589 | .*?Connected.+?to.*?
+           3 |                 14027356 | .*?Connection.+?error.*?
+           1 |                  1232382 | .*?Disconnected.*?
+;
+
+row mv
+required_capability: categorize_v2
+
+ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
+  | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | VALUES(str):keyword | category:keyword
+           2 | [a, b, c]           | .*?connected.+?to.*?
+           1 | [a, b, c]           | .*?disconnected.*?
+;
+
+with multiple indices
+required_capability: categorize_v2
+required_capability: union_types
+
+FROM sample_data*
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+          12 | .*?Connected.+?to.*?
+          12 | .*?Connection.+?error.*?
+           4 | .*?Disconnected.*?
+;
+
+mv with many values
+required_capability: categorize_v2
+
+FROM employees
+  | STATS COUNT() BY category=CATEGORIZE(job_positions)
+  | SORT category
+  | LIMIT 5
+;
+
+COUNT():long | category:keyword
+           18 | .*?Accountant.*?
+           13 | .*?Architect.*?
+           11 | .*?Business.+?Analyst.*?
+           13 | .*?Data.+?Scientist.*?
+           10 | .*?Head.+?Human.+?Resources.*?
+;
+
+# Throws when calling AbstractCategorizeBlockHash.seenGroupIds() - Requires nulls support?
+mv with many values-Ignore
+required_capability: categorize_v2
+
+FROM employees
+  | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
+  | SORT category DESC
+  | LIMIT 3
+;
+
+SUM(languages):integer | category:keyword
+                    43 | .*?Accountant.*?
+                    46 | .*?Architect.*?
+                    35 | .*?Business.+?Analyst.*?
+;
+
+mv via eval
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL message = MV_APPEND(message, "Banana")
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | .*?Banana.*?
+           3 | .*?Connected.+?to.*?
+           3 | .*?Connection.+?error.*?
+           1 | .*?Disconnected.*?
+;
+
+mv via eval const
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL message = ["Banana", "Bread"]
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | .*?Banana.*?
+           7 | .*?Bread.*?
+;
+
+mv via eval const without aliases
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL message = ["Banana", "Bread"]
+  | STATS COUNT() BY CATEGORIZE(message)
+  | SORT `CATEGORIZE(message)`
+;
+
+COUNT():long | CATEGORIZE(message):keyword
+           7 | .*?Banana.*?
+           7 | .*?Bread.*?
+;
+
+mv const in parameter
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
+  | SORT c
+;
+
+COUNT():long | c:keyword
+           7 | .*?Banana.*?
+           7 | .*?Bread.*?
+;
+
+agg alias shadowing
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
+  | SORT c
+;
+
+warning:Line 2:9: Field 'c' shadowed by field at line 2:24
+
+c:keyword
+.*?Banana.*?
+.*?Bread.*?
+;
+
+chained aggregations using categorize
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | STATS COUNT() BY category=CATEGORIZE(category)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           1 | .*?\.\*\?Connected\.\+\?to\.\*\?.*?
+           1 | .*?\.\*\?Connection\.\+\?error\.\*\?.*?
+           1 | .*?\.\*\?Disconnected\.\*\?.*?
+;
+
+stats without aggs
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+category:keyword
+.*?Connected.+?to.*?
+.*?Connection.+?error.*?
+.*?Disconnected.*?
+;
+
+text field
+required_capability: categorize_v2
+
+FROM hosts
+  | STATS COUNT() BY category=CATEGORIZE(host_group)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           2 | .*?DB.+?servers.*?
+           2 | .*?Gateway.+?instances.*?
+           5 | .*?Kubernetes.+?cluster.*?
+;
+
+on TO_UPPER
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           3 | .*?CONNECTED.+?TO.*?
+           3 | .*?CONNECTION.+?ERROR.*?
+           1 | .*?DISCONNECTED.*?
+;
+
+on CONCAT
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           3 | .*?Connected.+?to.+?banana.*?
+           3 | .*?Connection.+?error.+?banana.*?
+           1 | .*?Disconnected.+?banana.*?
+;
+
+on CONCAT with unicode
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           3 | .*?Connected.+?to.+?👍🏽😊.*?
+           3 | .*?Connection.+?error.+?👍🏽😊.*?
+           1 | .*?Disconnected.+?👍🏽😊.*?
+;
+
+on REVERSE(CONCAT())
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           1 | .*?😊👍🏽.+?detcennocsiD.*?
+           3 | .*?😊👍🏽.+?ot.+?detcennoC.*?
+           3 | .*?😊👍🏽.+?rorre.+?noitcennoC.*?
+;
+
+and then TO_LOWER
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | EVAL category=TO_LOWER(category)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           3 | .*?connected.+?to.*?
+           3 | .*?connection.+?error.*?
+           1 | .*?disconnected.*?
+;
+
+# Throws NPE - Requires nulls support
+on const empty string-Ignore
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE("")
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | .*?.*?
+;
+
+# Throws NPE - Requires nulls support
+on const empty string from eval-Ignore
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL x = ""
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | .*?.*?
+;
+
+# Doesn't give the correct results - Requires nulls support
+on null-Ignore
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL x = null
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | null
+;
+
+# Doesn't give the correct results - Requires nulls support
+on null string-Ignore
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL x = null::string
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+           7 | null
+;
+
+filtering out all data
+required_capability: categorize_v2
+
+FROM sample_data
+  | WHERE @timestamp < "2023-10-23T00:00:00Z"
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+COUNT():long | category:keyword
+;
+
+filtering out all data with constant
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | WHERE false
+;
+
+COUNT():long | category:keyword
+;
+
+drop output columns
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS count=COUNT() BY category=CATEGORIZE(message)
+  | EVAL x=1
+  | DROP count, category
+;
+
+x:integer
+1
+1
+1
+;
+
+category value processing
+required_capability: categorize_v2
+
+ROW message = ["connected to a", "connected to b", "disconnected"]
+  | STATS COUNT() BY category=CATEGORIZE(message)
+  | EVAL category = TO_UPPER(category)
   | SORT category
 ;
 
-count:long | values:keyword                                                        | category:integer
-3          | [Connected to 10.1.0.1, Connected to 10.1.0.2, Connected to 10.1.0.3] | 0
-3          | [Connection error]                                                    | 1
-1          | [Disconnected]                                                        | 2
+COUNT():long | category:keyword
+           2 | .*?CONNECTED.+?TO.*?
+           1 | .*?DISCONNECTED.*?
+;
+
+row aliases
+required_capability: categorize_v2
+
+ROW message = "connected to a"
+  | EVAL x = message
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | EVAL y = category
+  | SORT y
+;
+
+COUNT():long | category:keyword         | y:keyword
+           1 | .*?connected.+?to.+?a.*? | .*?connected.+?to.+?a.*?
+;
+
+from aliases
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL x = message
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | EVAL y = category
+  | SORT y
+;
+
+COUNT():long | category:keyword         | y:keyword
+           3 | .*?Connected.+?to.*?     | .*?Connected.+?to.*?
+           3 | .*?Connection.+?error.*? | .*?Connection.+?error.*?
+           1 | .*?Disconnected.*?       | .*?Disconnected.*?
+;
+
+row aliases with keep
+required_capability: categorize_v2
+
+ROW message = "connected to a"
+  | EVAL x = message
+  | KEEP x
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | EVAL y = category
+  | KEEP `COUNT()`, y
+  | SORT y
+;
+
+COUNT():long | y:keyword
+           1 | .*?connected.+?to.+?a.*?
+;
+
+from aliases with keep
+required_capability: categorize_v2
+
+FROM sample_data
+  | EVAL x = message
+  | KEEP x
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | EVAL y = category
+  | KEEP `COUNT()`, y
+  | SORT y
+;
+
+COUNT():long | y:keyword
+           3 | .*?Connected.+?to.*?
+           3 | .*?Connection.+?error.*?
+           1 | .*?Disconnected.*?
+;
+
+row rename
+required_capability: categorize_v2
+
+ROW message = "connected to a"
+  | RENAME message as x
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | RENAME category as y
+  | SORT y
+;
+
+COUNT():long | y:keyword
+           1 | .*?connected.+?to.+?a.*?
+;
+
+from rename
+required_capability: categorize_v2
+
+FROM sample_data
+  | RENAME message as x
+  | STATS COUNT() BY category=CATEGORIZE(x)
+  | RENAME category as y
+  | SORT y
+;
+
+COUNT():long | y:keyword
+           3 | .*?Connected.+?to.*?
+           3 | .*?Connection.+?error.*?
+           1 | .*?Disconnected.*?
+;
+
+row drop
+required_capability: categorize_v2
+
+ROW message = "connected to a"
+  | STATS c = COUNT() BY category=CATEGORIZE(message)
+  | DROP category
+  | SORT c
+;
+
+c:long
+1
+;
+
+from drop
+required_capability: categorize_v2
+
+FROM sample_data
+  | STATS c = COUNT() BY category=CATEGORIZE(message)
+  | DROP category
+  | SORT c
+;
+
+c:long
+1
+3
+3
 ;

+ 16 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-mv_sample_data.json

@@ -0,0 +1,16 @@
+{
+    "properties": {
+        "@timestamp": {
+            "type": "date"
+        },
+        "client_ip": {
+            "type": "ip"
+        },
+        "event_duration": {
+            "type": "long"
+        },
+        "message": {
+            "type": "keyword"
+        }
+    }
+}

+ 8 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_sample_data.csv

@@ -0,0 +1,8 @@
+@timestamp:date         ,client_ip:ip,event_duration:long,message:keyword
+2023-10-23T13:55:01.543Z,172.21.3.15 ,1756467,[Connected to 10.1.0.1, Banana]
+2023-10-23T13:53:55.832Z,172.21.3.15 ,5033755,[Connection error, Banana]
+2023-10-23T13:52:55.015Z,172.21.3.15 ,8268153,[Connection error, Banana]
+2023-10-23T13:51:54.732Z,172.21.3.15 , 725448,[Connection error, Banana]
+2023-10-23T13:33:34.937Z,172.21.0.5  ,1232382,[Disconnected, Banana]
+2023-10-23T12:27:28.948Z,172.21.2.113,2764889,[Connected to 10.1.0.2, Banana]
+2023-10-23T12:15:03.360Z,172.21.2.162,3450233,[Connected to 10.1.0.3, Banana]

+ 0 - 145
x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeEvaluator.java

@@ -1,145 +0,0 @@
-// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
-// or more contributor license agreements. Licensed under the Elastic License
-// 2.0; you may not use this file except in compliance with the Elastic License
-// 2.0.
-package org.elasticsearch.xpack.esql.expression.function.grouping;
-
-import java.lang.IllegalArgumentException;
-import java.lang.Override;
-import java.lang.String;
-import java.util.function.Function;
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.compute.data.Block;
-import org.elasticsearch.compute.data.BytesRefBlock;
-import org.elasticsearch.compute.data.BytesRefVector;
-import org.elasticsearch.compute.data.IntBlock;
-import org.elasticsearch.compute.data.IntVector;
-import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.operator.DriverContext;
-import org.elasticsearch.compute.operator.EvalOperator;
-import org.elasticsearch.compute.operator.Warnings;
-import org.elasticsearch.core.Releasables;
-import org.elasticsearch.xpack.esql.core.tree.Source;
-import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
-import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
-
-/**
- * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Categorize}.
- * This class is generated. Do not edit it.
- */
-public final class CategorizeEvaluator implements EvalOperator.ExpressionEvaluator {
-  private final Source source;
-
-  private final EvalOperator.ExpressionEvaluator v;
-
-  private final CategorizationAnalyzer analyzer;
-
-  private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
-
-  private final DriverContext driverContext;
-
-  private Warnings warnings;
-
-  public CategorizeEvaluator(Source source, EvalOperator.ExpressionEvaluator v,
-      CategorizationAnalyzer analyzer,
-      TokenListCategorizer.CloseableTokenListCategorizer categorizer, DriverContext driverContext) {
-    this.source = source;
-    this.v = v;
-    this.analyzer = analyzer;
-    this.categorizer = categorizer;
-    this.driverContext = driverContext;
-  }
-
-  @Override
-  public Block eval(Page page) {
-    try (BytesRefBlock vBlock = (BytesRefBlock) v.eval(page)) {
-      BytesRefVector vVector = vBlock.asVector();
-      if (vVector == null) {
-        return eval(page.getPositionCount(), vBlock);
-      }
-      return eval(page.getPositionCount(), vVector).asBlock();
-    }
-  }
-
-  public IntBlock eval(int positionCount, BytesRefBlock vBlock) {
-    try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) {
-      BytesRef vScratch = new BytesRef();
-      position: for (int p = 0; p < positionCount; p++) {
-        if (vBlock.isNull(p)) {
-          result.appendNull();
-          continue position;
-        }
-        if (vBlock.getValueCount(p) != 1) {
-          if (vBlock.getValueCount(p) > 1) {
-            warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value"));
-          }
-          result.appendNull();
-          continue position;
-        }
-        result.appendInt(Categorize.process(vBlock.getBytesRef(vBlock.getFirstValueIndex(p), vScratch), this.analyzer, this.categorizer));
-      }
-      return result.build();
-    }
-  }
-
-  public IntVector eval(int positionCount, BytesRefVector vVector) {
-    try(IntVector.FixedBuilder result = driverContext.blockFactory().newIntVectorFixedBuilder(positionCount)) {
-      BytesRef vScratch = new BytesRef();
-      position: for (int p = 0; p < positionCount; p++) {
-        result.appendInt(p, Categorize.process(vVector.getBytesRef(p, vScratch), this.analyzer, this.categorizer));
-      }
-      return result.build();
-    }
-  }
-
-  @Override
-  public String toString() {
-    return "CategorizeEvaluator[" + "v=" + v + "]";
-  }
-
-  @Override
-  public void close() {
-    Releasables.closeExpectNoException(v, analyzer, categorizer);
-  }
-
-  private Warnings warnings() {
-    if (warnings == null) {
-      this.warnings = Warnings.createWarnings(
-              driverContext.warningsMode(),
-              source.source().getLineNumber(),
-              source.source().getColumnNumber(),
-              source.text()
-          );
-    }
-    return warnings;
-  }
-
-  static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
-    private final Source source;
-
-    private final EvalOperator.ExpressionEvaluator.Factory v;
-
-    private final Function<DriverContext, CategorizationAnalyzer> analyzer;
-
-    private final Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer;
-
-    public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory v,
-        Function<DriverContext, CategorizationAnalyzer> analyzer,
-        Function<DriverContext, TokenListCategorizer.CloseableTokenListCategorizer> categorizer) {
-      this.source = source;
-      this.v = v;
-      this.analyzer = analyzer;
-      this.categorizer = categorizer;
-    }
-
-    @Override
-    public CategorizeEvaluator get(DriverContext context) {
-      return new CategorizeEvaluator(source, v.get(context), analyzer.apply(context), categorizer.apply(context), context);
-    }
-
-    @Override
-    public String toString() {
-      return "CategorizeEvaluator[" + "v=" + v + "]";
-    }
-  }
-}

+ 4 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -395,8 +395,11 @@ public class EsqlCapabilities {
 
         /**
          * Supported the text categorization function "CATEGORIZE".
+         * <p>
+         *     This capability was initially named `CATEGORIZE`, and got renamed after the function started correctly returning keywords.
+         * </p>
          */
-        CATEGORIZE(Build.current().isSnapshot()),
+        CATEGORIZE_V2(Build.current().isSnapshot()),
 
         /**
          * QSTR function

+ 16 - 60
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java

@@ -7,20 +7,10 @@
 
 package org.elasticsearch.xpack.esql.expression.function.grouping;
 
-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.core.WhitespaceTokenizer;
-import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.util.BytesRefHash;
-import org.elasticsearch.compute.ann.Evaluator;
-import org.elasticsearch.compute.ann.Fixed;
 import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
-import org.elasticsearch.index.analysis.CharFilterFactory;
-import org.elasticsearch.index.analysis.CustomAnalyzer;
-import org.elasticsearch.index.analysis.TokenFilterFactory;
-import org.elasticsearch.index.analysis.TokenizerFactory;
 import org.elasticsearch.xpack.esql.capabilities.Validatable;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@@ -29,10 +19,6 @@ import org.elasticsearch.xpack.esql.core.type.DataType;
 import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
 import org.elasticsearch.xpack.esql.expression.function.Param;
 import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
-import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
-import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategorizer;
-import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 import java.io.IOException;
 import java.util.List;
@@ -42,16 +28,16 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isStr
 
 /**
  * Categorizes text messages.
- *
- * This implementation is incomplete and comes with the following caveats:
- * - it only works correctly on a single node.
- * - when running on multiple nodes, category IDs of the different nodes are
- *   aggregated, even though the same ID can correspond to a totally different
- *   category
- * - the output consists of category IDs, which should be replaced by category
- *   regexes or keys
- *
- * TODO(jan, nik): fix this
+ * <p>
+ *     This function has no evaluators, as it works like an aggregation (Accumulates values, stores intermediate states, etc).
+ * </p>
+ * <p>
+ *     For the implementation, see:
+ * </p>
+ * <ul>
+ *     <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizedIntermediateBlockHash}</li>
+ *     <li>{@link org.elasticsearch.compute.aggregation.blockhash.CategorizeRawBlockHash}</li>
+ * </ul>
  */
 public class Categorize extends GroupingFunction implements Validatable {
     public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
@@ -62,7 +48,7 @@ public class Categorize extends GroupingFunction implements Validatable {
 
     private final Expression field;
 
-    @FunctionInfo(returnType = { "integer" }, description = "Categorizes text messages.")
+    @FunctionInfo(returnType = "keyword", description = "Categorizes text messages.")
     public Categorize(
         Source source,
         @Param(name = "field", type = { "text", "keyword" }, description = "Expression to categorize") Expression field
@@ -88,43 +74,13 @@ public class Categorize extends GroupingFunction implements Validatable {
 
     @Override
     public boolean foldable() {
-        return field.foldable();
-    }
-
-    @Evaluator
-    static int process(
-        BytesRef v,
-        @Fixed(includeInToString = false, build = true) CategorizationAnalyzer analyzer,
-        @Fixed(includeInToString = false, build = true) TokenListCategorizer.CloseableTokenListCategorizer categorizer
-    ) {
-        String s = v.utf8ToString();
-        try (TokenStream ts = analyzer.tokenStream("text", s)) {
-            return categorizer.computeCategory(ts, s.length(), 1).getId();
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+        // Categorize cannot be currently folded
+        return false;
     }
 
     @Override
     public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
-        return new CategorizeEvaluator.Factory(
-            source(),
-            toEvaluator.apply(field),
-            context -> new CategorizationAnalyzer(
-                // TODO(jan): get the correct analyzer in here, see CategorizationAnalyzerConfig::buildStandardCategorizationAnalyzer
-                new CustomAnalyzer(
-                    TokenizerFactory.newFactory("whitespace", WhitespaceTokenizer::new),
-                    new CharFilterFactory[0],
-                    new TokenFilterFactory[0]
-                ),
-                true
-            ),
-            context -> new TokenListCategorizer.CloseableTokenListCategorizer(
-                new CategorizationBytesRefHash(new BytesRefHash(2048, context.bigArrays())),
-                CategorizationPartOfSpeechDictionary.getInstance(),
-                0.70f
-            )
-        );
+        throw new UnsupportedOperationException("CATEGORIZE is only evaluated during aggregations");
     }
 
     @Override
@@ -134,11 +90,11 @@ public class Categorize extends GroupingFunction implements Validatable {
 
     @Override
     public DataType dataType() {
-        return DataType.INTEGER;
+        return DataType.KEYWORD;
     }
 
     @Override
-    public Expression replaceChildren(List<Expression> newChildren) {
+    public Categorize replaceChildren(List<Expression> newChildren) {
         return new Categorize(source(), newChildren.get(0));
     }
 

+ 26 - 12
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineProjections.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
 import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
 import org.elasticsearch.xpack.esql.plan.logical.Project;
@@ -61,12 +62,15 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
         if (plan instanceof Aggregate a) {
             if (child instanceof Project p) {
                 var groupings = a.groupings();
-                List<Attribute> groupingAttrs = new ArrayList<>(a.groupings().size());
+                List<NamedExpression> groupingAttrs = new ArrayList<>(a.groupings().size());
                 for (Expression grouping : groupings) {
                     if (grouping instanceof Attribute attribute) {
                         groupingAttrs.add(attribute);
+                    } else if (grouping instanceof Alias as && as.child() instanceof Categorize) {
+                        groupingAttrs.add(as);
                     } else {
-                        // After applying ReplaceAggregateNestedExpressionWithEval, groupings can only contain attributes.
+                        // After applying ReplaceAggregateNestedExpressionWithEval,
+                        // groupings (except Categorize) can only contain attributes.
                         throw new EsqlIllegalArgumentException("Expected an Attribute, got {}", grouping);
                     }
                 }
@@ -137,23 +141,33 @@ public final class CombineProjections extends OptimizerRules.OptimizerRule<Unary
     }
 
     private static List<Expression> combineUpperGroupingsAndLowerProjections(
-        List<? extends Attribute> upperGroupings,
+        List<? extends NamedExpression> upperGroupings,
         List<? extends NamedExpression> lowerProjections
     ) {
         // Collect the alias map for resolving the source (f1 = 1, f2 = f1, etc..)
-        AttributeMap<Attribute> aliases = new AttributeMap<>();
+        AttributeMap<Expression> aliases = new AttributeMap<>();
         for (NamedExpression ne : lowerProjections) {
-            // Projections are just aliases for attributes, so casting is safe.
-            aliases.put(ne.toAttribute(), (Attribute) Alias.unwrap(ne));
+            // record the alias
+            aliases.put(ne.toAttribute(), Alias.unwrap(ne));
         }
-
         // Replace any matching attribute directly with the aliased attribute from the projection.
-        AttributeSet replaced = new AttributeSet();
-        for (Attribute attr : upperGroupings) {
-            // All substitutions happen before; groupings must be attributes at this point.
-            replaced.add(aliases.resolve(attr, attr));
+        AttributeSet seen = new AttributeSet();
+        List<Expression> replaced = new ArrayList<>();
+        for (NamedExpression ne : upperGroupings) {
+            // Duplicated attributes are ignored.
+            if (ne instanceof Attribute attribute) {
+                var newExpression = aliases.resolve(attribute, attribute);
+                if (newExpression instanceof Attribute newAttribute && seen.add(newAttribute) == false) {
+                    // Already seen, skip
+                    continue;
+                }
+                replaced.add(newExpression);
+            } else {
+                // For grouping functions, this will replace nested properties too
+                replaced.add(ne.transformUp(Attribute.class, a -> aliases.resolve(a, a)));
+            }
         }
-        return new ArrayList<>(replaced);
+        return replaced;
     }
 
     /**

+ 2 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.Literal;
 import org.elasticsearch.xpack.esql.core.expression.Nullability;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In;
 
 public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression> {
@@ -42,6 +43,7 @@ public class FoldNull extends OptimizerRules.OptimizerExpressionRule<Expression>
             }
         } else if (e instanceof Alias == false
             && e.nullable() == Nullability.TRUE
+            && e instanceof Categorize == false
             && Expressions.anyMatch(e.children(), Expressions::isNull)) {
                 return Literal.of(e, null);
             }

+ 23 - 8
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceAggregateNestedExpressionWithEval.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression;
 import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.core.util.Holder;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction;
 import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
 import org.elasticsearch.xpack.esql.plan.logical.Eval;
@@ -46,15 +47,29 @@ public final class ReplaceAggregateNestedExpressionWithEval extends OptimizerRul
         // start with the groupings since the aggs might duplicate it
         for (int i = 0, s = newGroupings.size(); i < s; i++) {
             Expression g = newGroupings.get(i);
-            // move the alias into an eval and replace it with its attribute
+            // Move the alias into an eval and replace it with its attribute.
+            // Exception: Categorize is internal to the aggregation and remains in the groupings. We move its child expression into an eval.
             if (g instanceof Alias as) {
-                groupingChanged = true;
-                var attr = as.toAttribute();
-                evals.add(as);
-                evalNames.put(as.name(), attr);
-                newGroupings.set(i, attr);
-                if (as.child() instanceof GroupingFunction gf) {
-                    groupingAttributes.put(gf, attr);
+                if (as.child() instanceof Categorize cat) {
+                    if (cat.field() instanceof Attribute == false) {
+                        groupingChanged = true;
+                        var fieldAs = new Alias(as.source(), as.name(), cat.field(), null, true);
+                        var fieldAttr = fieldAs.toAttribute();
+                        evals.add(fieldAs);
+                        evalNames.put(fieldAs.name(), fieldAttr);
+                        Categorize replacement = cat.replaceChildren(List.of(fieldAttr));
+                        newGroupings.set(i, as.replaceChild(replacement));
+                        groupingAttributes.put(cat, fieldAttr);
+                    }
+                } else {
+                    groupingChanged = true;
+                    var attr = as.toAttribute();
+                    evals.add(as);
+                    evalNames.put(as.name(), attr);
+                    newGroupings.set(i, attr);
+                    if (as.child() instanceof GroupingFunction gf) {
+                        groupingAttributes.put(gf, attr);
+                    }
                 }
             }
         }

+ 12 - 5
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions;
 import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
 import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute;
 import org.elasticsearch.xpack.esql.core.expression.TypedAttribute;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
@@ -58,11 +59,17 @@ public class InsertFieldExtraction extends Rule<PhysicalPlan, PhysicalPlan> {
              * make sure the fields are loaded for the standard hash aggregator.
              */
             if (p instanceof AggregateExec agg && agg.groupings().size() == 1) {
-                var leaves = new LinkedList<>();
-                // TODO: this seems out of place
-                agg.aggregates().stream().filter(a -> agg.groupings().contains(a) == false).forEach(a -> leaves.addAll(a.collectLeaves()));
-                var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
-                missing.removeAll(Expressions.references(remove));
+                // CATEGORIZE requires the standard hash aggregator as well.
+                if (agg.groupings().get(0).anyMatch(e -> e instanceof Categorize) == false) {
+                    var leaves = new LinkedList<>();
+                    // TODO: this seems out of place
+                    agg.aggregates()
+                        .stream()
+                        .filter(a -> agg.groupings().contains(a) == false)
+                        .forEach(a -> leaves.addAll(a.collectLeaves()));
+                    var remove = agg.groupings().stream().filter(g -> leaves.contains(g) == false).toList();
+                    missing.removeAll(Expressions.references(remove));
+                }
             }
 
             // add extractor

+ 31 - 11
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java

@@ -29,6 +29,7 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
 import org.elasticsearch.xpack.esql.evaluator.EvalMapper;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
 import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
@@ -52,6 +53,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
         PhysicalOperation source,
         LocalExecutionPlannerContext context
     ) {
+        // The layout this operation will produce.
         Layout.Builder layout = new Layout.Builder();
         Operator.OperatorFactory operatorFactory = null;
         AggregatorMode aggregatorMode = aggregateExec.getMode();
@@ -95,12 +97,17 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
             List<GroupingAggregator.Factory> aggregatorFactories = new ArrayList<>();
             List<GroupSpec> groupSpecs = new ArrayList<>(aggregateExec.groupings().size());
             for (Expression group : aggregateExec.groupings()) {
-                var groupAttribute = Expressions.attribute(group);
-                if (groupAttribute == null) {
+                Attribute groupAttribute = Expressions.attribute(group);
+                // In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different.
+                Attribute sourceGroupAttribute = (aggregatorMode.isInputPartial() == false
+                    && group instanceof Alias as
+                    && as.child() instanceof Categorize categorize) ? Expressions.attribute(categorize.field()) : groupAttribute;
+                if (sourceGroupAttribute == null) {
                     throw new EsqlIllegalArgumentException("Unexpected non-named expression[{}] as grouping in [{}]", group, aggregateExec);
                 }
-                Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), groupAttribute.dataType());
-                groupAttributeLayout.nameIds().add(groupAttribute.id());
+                Layout.ChannelSet groupAttributeLayout = new Layout.ChannelSet(new HashSet<>(), sourceGroupAttribute.dataType());
+                groupAttributeLayout.nameIds()
+                    .add(group instanceof Alias as && as.child() instanceof Categorize ? groupAttribute.id() : sourceGroupAttribute.id());
 
                 /*
                  * Check for aliasing in aggregates which occurs in two cases (due to combining project + stats):
@@ -119,7 +126,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
                             // check if there's any alias used in grouping - no need for the final reduction since the intermediate data
                             // is in the output form
                             // if the group points to an alias declared in the aggregate, use the alias child as source
-                            else if (aggregatorMode == AggregatorMode.INITIAL || aggregatorMode == AggregatorMode.INTERMEDIATE) {
+                            else if (aggregatorMode.isOutputPartial()) {
                                 if (groupAttribute.semanticEquals(a.toAttribute())) {
                                     groupAttribute = attr;
                                     break;
@@ -129,8 +136,8 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
                     }
                 }
                 layout.append(groupAttributeLayout);
-                Layout.ChannelAndType groupInput = source.layout.get(groupAttribute.id());
-                groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), groupAttribute));
+                Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id());
+                groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group));
             }
 
             if (aggregatorMode == AggregatorMode.FINAL) {
@@ -164,6 +171,7 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
             } else {
                 operatorFactory = new HashAggregationOperatorFactory(
                     groupSpecs.stream().map(GroupSpec::toHashGroupSpec).toList(),
+                    aggregatorMode,
                     aggregatorFactories,
                     context.pageSize(aggregateExec.estimatedRowSize())
                 );
@@ -178,10 +186,14 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
     /***
      * Creates a standard layout for intermediate aggregations, typically used across exchanges.
      * Puts the group first, followed by each aggregation.
-     *
-     * It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
+     * <p>
+     *     It's similar to the code above (groupingPhysicalOperation) but ignores the factory creation.
+     * </p>
      */
     public static List<Attribute> intermediateAttributes(List<? extends NamedExpression> aggregates, List<? extends Expression> groupings) {
+        // TODO: This should take CATEGORIZE into account:
+        // it currently works because the CATEGORIZE intermediate state is just 1 block with the same type as the function return,
+        // so the attribute generated here is the expected one
         var aggregateMapper = new AggregateMapper();
 
         List<Attribute> attrs = new ArrayList<>();
@@ -304,12 +316,20 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOper
         throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
     }
 
-    private record GroupSpec(Integer channel, Attribute attribute) {
+    /**
+     * The input configuration of this group.
+     *
+     * @param channel The source channel of this group
+     * @param attribute The attribute, source of this group
+     * @param expression The expression being used to group
+     */
+    private record GroupSpec(Integer channel, Attribute attribute, Expression expression) {
         BlockHash.GroupSpec toHashGroupSpec() {
             if (channel == null) {
                 throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead");
             }
-            return new BlockHash.GroupSpec(channel, elementType());
+
+            return new BlockHash.GroupSpec(channel, elementType(), Alias.unwrap(expression) instanceof Categorize);
         }
 
         ElementType elementType() {

+ 3 - 3
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1821,7 +1821,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
         query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
@@ -1850,7 +1850,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
 
@@ -1865,7 +1865,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
 
         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
 

+ 2 - 1
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

@@ -111,7 +111,8 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa
                         testCase.getExpectedTypeError(),
                         null,
                         null,
-                        null
+                        null,
+                        testCase.canBuildEvaluator()
                     );
                 }));
             }

+ 5 - 14
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java

@@ -229,7 +229,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
                         oc.getExpectedTypeError(),
                         null,
                         null,
-                        null
+                        null,
+                        oc.canBuildEvaluator()
                     );
                 }));
 
@@ -260,7 +261,8 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
                                 oc.getExpectedTypeError(),
                                 null,
                                 null,
-                                null
+                                null,
+                                oc.canBuildEvaluator()
                             );
                         }));
                     }
@@ -648,18 +650,7 @@ public abstract class AbstractFunctionTestCase extends ESTestCase {
                 return typedData.withData(tryRandomizeBytesRefOffset(typedData.data()));
             }).toList();
 
-            return new TestCaseSupplier.TestCase(
-                newData,
-                testCase.evaluatorToString(),
-                testCase.expectedType(),
-                testCase.getMatcher(),
-                testCase.getExpectedWarnings(),
-                testCase.getExpectedBuildEvaluatorWarnings(),
-                testCase.getExpectedTypeError(),
-                testCase.foldingExceptionClass(),
-                testCase.foldingExceptionMessage(),
-                testCase.extra()
-            );
+            return testCase.withData(newData);
         })).toList();
     }
 

+ 1 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java

@@ -345,6 +345,7 @@ public abstract class AbstractScalarFunctionTestCase extends AbstractFunctionTes
             return;
         }
         assertFalse("expected resolved", expression.typeResolved().unresolved());
+        assumeTrue("Can't build evaluator", testCase.canBuildEvaluator());
         Expression nullOptimized = new FoldNull().rule(expression);
         assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType()));
         assertTrue(nullOptimized.foldable());

+ 78 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java

@@ -1431,6 +1431,34 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
             Class<? extends Throwable> foldingExceptionClass,
             String foldingExceptionMessage,
             Object extra
+        ) {
+            this(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type))
+            );
+        }
+
+        TestCase(
+            List<TypedData> data,
+            Matcher<String> evaluatorToString,
+            DataType expectedType,
+            Matcher<?> matcher,
+            String[] expectedWarnings,
+            String[] expectedBuildEvaluatorWarnings,
+            String expectedTypeError,
+            Class<? extends Throwable> foldingExceptionClass,
+            String foldingExceptionMessage,
+            Object extra,
+            boolean canBuildEvaluator
         ) {
             this.source = Source.EMPTY;
             this.data = data;
@@ -1442,10 +1470,10 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
             this.expectedWarnings = expectedWarnings;
             this.expectedBuildEvaluatorWarnings = expectedBuildEvaluatorWarnings;
             this.expectedTypeError = expectedTypeError;
-            this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type));
             this.foldingExceptionClass = foldingExceptionClass;
             this.foldingExceptionMessage = foldingExceptionMessage;
             this.extra = extra;
+            this.canBuildEvaluator = canBuildEvaluator;
         }
 
         public Source getSource() {
@@ -1520,6 +1548,25 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
             return extra;
         }
 
+        /**
+         * Build a new {@link TestCase} with new {@link #data}.
+         */
+        public TestCase withData(List<TestCaseSupplier.TypedData> data) {
+            return new TestCase(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                canBuildEvaluator
+            );
+        }
+
         /**
          * Build a new {@link TestCase} with new {@link #extra()}.
          */
@@ -1534,7 +1581,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 
@@ -1549,7 +1597,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 
@@ -1568,7 +1617,8 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
                 expectedTypeError,
                 foldingExceptionClass,
                 foldingExceptionMessage,
-                extra
+                extra,
+                canBuildEvaluator
             );
         }
 
@@ -1592,7 +1642,30 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
                 expectedTypeError,
                 clazz,
                 message,
-                extra
+                extra,
+                canBuildEvaluator
+            );
+        }
+
+        /**
+         * Build a new {@link TestCase} that can't build an evaluator.
+         * <p>
+         *     Useful for special cases that can't be executed, but should still be considered.
+         * </p>
+         */
+        public TestCase withoutEvaluator() {
+            return new TestCase(
+                data,
+                evaluatorToString,
+                expectedType,
+                matcher,
+                expectedWarnings,
+                expectedBuildEvaluatorWarnings,
+                expectedTypeError,
+                foldingExceptionClass,
+                foldingExceptionMessage,
+                extra,
+                false
             );
         }
 

+ 11 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/grouping/CategorizeTests.java

@@ -23,6 +23,12 @@ import java.util.function.Supplier;
 
 import static org.hamcrest.Matchers.equalTo;
 
+/**
+ * Dummy test implementation for Categorize. Used just to generate documentation.
+ * <p>
+ *     Most test cases are currently skipped as this function can't build an evaluator.
+ * </p>
+ */
 public class CategorizeTests extends AbstractScalarFunctionTestCase {
     public CategorizeTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
         this.testCase = testCaseSupplier.get();
@@ -37,11 +43,11 @@ public class CategorizeTests extends AbstractScalarFunctionTestCase {
                     "text with " + dataType.typeName(),
                     List.of(dataType),
                     () -> new TestCaseSupplier.TestCase(
-                        List.of(new TestCaseSupplier.TypedData(new BytesRef("blah blah blah"), dataType, "f")),
-                        "CategorizeEvaluator[v=Attribute[channel=0]]",
-                        DataType.INTEGER,
-                        equalTo(0)
-                    )
+                        List.of(new TestCaseSupplier.TypedData(new BytesRef(""), dataType, "field")),
+                        "",
+                        DataType.KEYWORD,
+                        equalTo(new BytesRef(""))
+                    ).withoutEvaluator()
                 )
             );
         }

+ 61 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Values;
 import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToDouble;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToInteger;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
@@ -1203,6 +1204,33 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         assertThat(Expressions.names(agg.groupings()), contains("first_name"));
     }
 
+    /**
+     * Expects
+     * Limit[1000[INTEGER]]
+     * \_Aggregate[STANDARD,[CATEGORIZE(first_name{f}#18) AS cat],[SUM(salary{f}#22,true[BOOLEAN]) AS s, cat{r}#10]]
+     *   \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
+     */
+    public void testCombineProjectionWithCategorizeGrouping() {
+        var plan = plan("""
+            from test
+            | eval k = first_name, k1 = k
+            | stats s = sum(salary) by cat = CATEGORIZE(k)
+            | keep s, cat
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        assertThat(agg.child(), instanceOf(EsRelation.class));
+
+        assertThat(Expressions.names(agg.aggregates()), contains("s", "cat"));
+        assertThat(Expressions.names(agg.groupings()), contains("cat"));
+
+        var categorizeAlias = as(agg.groupings().get(0), Alias.class);
+        var categorize = as(categorizeAlias.child(), Categorize.class);
+        var categorizeField = as(categorize.field(), FieldAttribute.class);
+        assertThat(categorizeField.name(), is("first_name"));
+    }
+
     /**
      * Expects
      * Limit[1000[INTEGER]]
@@ -3909,6 +3937,39 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
         assertThat(eval.fields().get(0).name(), is("emp_no % 2"));
     }
 
+    /**
+     * Expects
+     * Limit[1000[INTEGER]]
+     * \_Aggregate[STANDARD,[CATEGORIZE(CATEGORIZE(CONCAT(first_name, "abc")){r$}#18) AS CATEGORIZE(CONCAT(first_name, "abc"))],[CO
+     * UNT(salary{f}#13,true[BOOLEAN]) AS c, CATEGORIZE(CONCAT(first_name, "abc")){r}#3]]
+     *   \_Eval[[CONCAT(first_name{f}#9,[61 62 63][KEYWORD]) AS CATEGORIZE(CONCAT(first_name, "abc"))]]
+     *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
+     */
+    public void testNestedExpressionsInGroupsWithCategorize() {
+        var plan = optimizedPlan("""
+            from test
+            | stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))
+            """);
+
+        var limit = as(plan, Limit.class);
+        var agg = as(limit.child(), Aggregate.class);
+        var groupings = agg.groupings();
+        var categorizeAlias = as(groupings.get(0), Alias.class);
+        var categorize = as(categorizeAlias.child(), Categorize.class);
+        var aggs = agg.aggregates();
+        assertThat(aggs.get(1), is(categorizeAlias.toAttribute()));
+
+        var eval = as(agg.child(), Eval.class);
+        assertThat(eval.fields(), hasSize(1));
+        var evalFieldAlias = as(eval.fields().get(0), Alias.class);
+        var evalField = as(evalFieldAlias.child(), Concat.class);
+
+        assertThat(evalFieldAlias.name(), is("CATEGORIZE(CONCAT(first_name, \"abc\"))"));
+        assertThat(categorize.field(), is(evalFieldAlias.toAttribute()));
+        assertThat(evalField.source().text(), is("CONCAT(first_name, \"abc\")"));
+        assertThat(categorizeAlias.source(), is(evalFieldAlias.source()));
+    }
+
     /**
      * Expects
      * Limit[1000[INTEGER]]

+ 13 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java

@@ -28,6 +28,8 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.Min;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Percentile;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.SpatialCentroid;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.Sum;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket;
+import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize;
 import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToString;
 import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateExtract;
 import org.elasticsearch.xpack.esql.expression.function.scalar.date.DateFormat;
@@ -267,6 +269,17 @@ public class FoldNullTests extends ESTestCase {
         }
     }
 
+    public void testNullBucketGetsFolded() {
+        FoldNull foldNull = new FoldNull();
+        assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL)));
+    }
+
+    public void testNullCategorizeGroupingNotFolded() {
+        FoldNull foldNull = new FoldNull();
+        Categorize categorize = new Categorize(EMPTY, NULL);
+        assertEquals(categorize, foldNull.rule(categorize));
+    }
+
     private void assertNullLiteral(Expression expression) {
         assertEquals(Literal.class, expression.getClass());
         assertNull(expression.fold());

+ 24 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java

@@ -19,6 +19,7 @@ import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory.TokenAndWeight;
+import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -83,6 +84,8 @@ public class TokenListCategorizer implements Accountable {
     @Nullable
     private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
 
+    private final List<TokenListCategory> categoriesById;
+
     /**
      * Categories stored in such a way that the most common are accessed first.
      * This is implemented as an {@link ArrayList} with bespoke ordering rather
@@ -108,9 +111,18 @@ public class TokenListCategorizer implements Accountable {
         this.lowerThreshold = threshold;
         this.upperThreshold = (1.0f + threshold) / 2.0f;
         this.categoriesByNumMatches = new ArrayList<>();
+        this.categoriesById = new ArrayList<>();
         cacheRamUsage(0);
     }
 
+    public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
+        try (TokenStream ts = analyzer.tokenStream("text", s)) {
+            return computeCategory(ts, s.length(), 1);
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
         assert partOfSpeechDictionary != null
             : "This version of computeCategory should only be used when a part-of-speech dictionary is available";
@@ -301,6 +313,7 @@ public class TokenListCategorizer implements Accountable {
             maxUnfilteredStringLen,
             numDocs
         );
+        categoriesById.add(newCategory);
         categoriesByNumMatches.add(newCategory);
         cacheRamUsage(newCategory.ramBytesUsed());
         return repositionCategory(newCategory, newIndex);
@@ -412,6 +425,17 @@ public class TokenListCategorizer implements Accountable {
         }
     }
 
+    public List<SerializableTokenListCategory> toCategories(int size) {
+        return categoriesByNumMatches.stream()
+            .limit(size)
+            .map(category -> new SerializableTokenListCategory(category, bytesRefHash))
+            .toList();
+    }
+
+    public List<SerializableTokenListCategory> toCategoriesById() {
+        return categoriesById.stream().map(category -> new SerializableTokenListCategory(category, bytesRefHash)).toList();
+    }
+
     public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
         return categoriesByNumMatches.stream()
             .limit(size)