ES|QL categorize with multiple groupings (#118173) (#118590)

* ES|QL categorize with multiple groupings.

* Fix VerifierTests

* Close stuff when constructing CategorizePackedValuesBlockHash fails

* CategorizePackedValuesBlockHashTests

* Improve categorize javadocs

* Update docs/changelog/118173.yaml

* Create CategorizePackedValuesBlockHash's delegate page differently

* Double check in BlockHash builder for single categorize

* Reuse blocks array

* More CSV tests

* Remove assumeTrue categorize_v5

* Rename test

* Two more verifier tests

* more CSV tests

* Add JavaDocs/comments

* spotless

* Refactor/unify recategorize

* Better memory accounting

* fix csv test

* randomize CategorizePackedValuesBlockHashTests

* Add TODO

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Jan Kuipers 10 months ago
parent
commit
36d11d3374

+ 5 - 0
docs/changelog/118173.yaml

@@ -0,0 +1,5 @@
+pr: 118173
+summary: ES|QL categorize with multiple groupings
+area: Machine Learning
+type: feature
+issues: []

+ 8 - 5
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java

@@ -180,13 +180,16 @@ public abstract class BlockHash implements Releasable, SeenGroupIds {
         List<GroupSpec> groups,
         AggregatorMode aggregatorMode,
         BlockFactory blockFactory,
-        AnalysisRegistry analysisRegistry
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
     ) {
-        if (groups.size() != 1) {
-            throw new IllegalArgumentException("only a single CATEGORIZE group can used");
+        if (groups.size() == 1) {
+            return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
+        } else {
+            assert groups.get(0).isCategorize();
+            assert groups.subList(1, groups.size()).stream().noneMatch(GroupSpec::isCategorize);
+            return new CategorizePackedValuesBlockHash(groups, blockFactory, aggregatorMode, analysisRegistry, emitBatchSize);
         }
-
-        return new CategorizeBlockHash(blockFactory, groups.get(0).channel, aggregatorMode, analysisRegistry);
     }
 
     /**

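For orientation, a minimal sketch of how the new multi-groupings branch gets exercised. The GroupSpec usage mirrors CategorizePackedValuesBlockHashTests below; blockFactory and analysisRegistry are assumed to be in scope:

    // Sketch: STATS ... BY CATEGORIZE(message), id
    // The CATEGORIZE grouping must be first (channel 0); the remaining
    // groupings follow and must not be CATEGORIZE.
    List<BlockHash.GroupSpec> groups = List.of(
        new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true),  // isCategorize
        new BlockHash.GroupSpec(1, ElementType.INT, false)
    );
    // One group yields a CategorizeBlockHash; more than one now yields a
    // CategorizePackedValuesBlockHash.
    BlockHash hash = BlockHash.buildCategorizeBlockHash(
        groups, AggregatorMode.INITIAL, blockFactory, analysisRegistry, 16 * 1024
    );
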
+ 42 - 37
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHash.java

@@ -44,7 +44,7 @@ import java.util.Map;
 import java.util.Objects;
 
 /**
- * Base BlockHash implementation for {@code Categorize} grouping function.
+ * BlockHash implementation for {@code Categorize} grouping function.
  */
 public class CategorizeBlockHash extends BlockHash {
 
@@ -53,11 +53,9 @@ public class CategorizeBlockHash extends BlockHash {
     );
     private static final int NULL_ORD = 0;
 
-    // TODO: this should probably also take an emitBatchSize
     private final int channel;
     private final AggregatorMode aggregatorMode;
     private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
-
     private final CategorizeEvaluator evaluator;
 
     /**
@@ -95,12 +93,14 @@ public class CategorizeBlockHash extends BlockHash {
         }
     }
 
+    boolean seenNull() {
+        return seenNull;
+    }
+
     @Override
     public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        if (aggregatorMode.isInputPartial() == false) {
-            addInitial(page, addInput);
-        } else {
-            addIntermediate(page, addInput);
+        try (IntBlock block = add(page)) {
+            if (block != null) { // addIntermediate() returns null for empty pages
+                addInput.add(0, block);
+            }
         }
     }
 
@@ -129,50 +129,38 @@ public class CategorizeBlockHash extends BlockHash {
         Releasables.close(evaluator, categorizer);
     }
 
+    private IntBlock add(Page page) {
+        return aggregatorMode.isInputPartial() == false ? addInitial(page) : addIntermediate(page);
+    }
+
     /**
      * Adds initial (raw) input to the state.
      */
-    private void addInitial(Page page, GroupingAggregatorFunction.AddInput addInput) {
-        try (IntBlock result = (IntBlock) evaluator.eval(page.getBlock(channel))) {
-            addInput.add(0, result);
-        }
+    IntBlock addInitial(Page page) {
+        return (IntBlock) evaluator.eval(page.getBlock(channel));
     }
 
     /**
      * Adds intermediate state to the state.
      */
-    private void addIntermediate(Page page, GroupingAggregatorFunction.AddInput addInput) {
+    private IntBlock addIntermediate(Page page) {
         if (page.getPositionCount() == 0) {
-            return;
+            return null;
         }
         BytesRefBlock categorizerState = page.getBlock(channel);
         if (categorizerState.areAllValuesNull()) {
             seenNull = true;
-            try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
-                addInput.add(0, newIds);
-            }
-            return;
-        }
-
-        Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
-        try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
-            int fromId = idMap.containsKey(0) ? 0 : 1;
-            int toId = fromId + idMap.size();
-            for (int i = fromId; i < toId; i++) {
-                newIdsBuilder.appendInt(idMap.get(i));
-            }
-            try (IntBlock newIds = newIdsBuilder.build()) {
-                addInput.add(0, newIds);
-            }
+            return blockFactory.newConstantIntBlockWith(NULL_ORD, 1);
         }
+        return recategorize(categorizerState.getBytesRef(0, new BytesRef()), null).asBlock();
     }
 
     /**
-     * Read intermediate state from a block.
-     *
-     * @return a map from the old category id to the new one. The old ids go from 0 to {@code size - 1}.
+     * Reads the intermediate state from a block and recategorizes the provided IDs.
+     * If no IDs are provided, the IDs are the IDs in the categorizer's state in order.
+     * (So 0...N-1 or 1...N, depending on whether null is present.)
      */
-    private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
+    IntVector recategorize(BytesRef bytes, IntVector ids) {
         Map<Integer, Integer> idMap = new HashMap<>();
         try (StreamInput in = new BytesArray(bytes).streamInput()) {
             if (in.readBoolean()) {
@@ -185,10 +173,22 @@ public class CategorizeBlockHash extends BlockHash {
                 // +1 because the 0 ordinal is reserved for null
                 idMap.put(oldCategoryId + 1, newCategoryId + 1);
             }
-            return idMap;
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
+        try (IntVector.Builder newIdsBuilder = blockFactory.newIntVectorBuilder(idMap.size())) {
+            if (ids == null) {
+                int idOffset = idMap.containsKey(0) ? 0 : 1;
+                for (int i = 0; i < idMap.size(); i++) {
+                    newIdsBuilder.appendInt(idMap.get(i + idOffset));
+                }
+            } else {
+                for (int i = 0; i < ids.getPositionCount(); i++) {
+                    newIdsBuilder.appendInt(idMap.get(ids.getInt(i)));
+                }
+            }
+            return newIdsBuilder.build();
+        }
     }
 
     /**
@@ -198,15 +198,20 @@ public class CategorizeBlockHash extends BlockHash {
         if (categorizer.getCategoryCount() == 0) {
             return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
         }
+        int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
+        // We're returning a block with N positions just because the Page must have all blocks with the same position count!
+        return blockFactory.newConstantBytesRefBlockWith(serializeCategorizer(), positionCount);
+    }
+
+    BytesRef serializeCategorizer() {
+        // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that!
         try (BytesStreamOutput out = new BytesStreamOutput()) {
             out.writeBoolean(seenNull);
             out.writeVInt(categorizer.getCategoryCount());
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 category.writeTo(out);
             }
-            // We're returning a block with N positions just because the Page must have all blocks with the same position count!
-            int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
-            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
+            return out.bytes().toBytesRef();
         } catch (IOException e) {
             throw new RuntimeException(e);
         }

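The hunk above elides the middle of recategorize(). It reads back exactly what serializeCategorizer() writes: a seenNull flag, a vInt category count, then the categories in ID order. A hedged sketch of that elided read path, assuming SerializableTokenListCategory's usual StreamInput constructor and a mergeWireCategory method on the categorizer:

    try (StreamInput in = new BytesArray(bytes).streamInput()) {
        if (in.readBoolean()) {
            seenNull = true;                 // flag written first by serializeCategorizer()
        }
        int categoryCount = in.readVInt();
        for (int oldCategoryId = 0; oldCategoryId < categoryCount; oldCategoryId++) {
            // Merging into the local categorizer may assign a different ID;
            // +1 on both sides because the 0 ordinal is reserved for null.
            int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
            idMap.put(oldCategoryId + 1, newCategoryId + 1);
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
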
+ 170 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHash.java

@@ -0,0 +1,170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.BitArray;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.GroupingAggregatorFunction;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.ReleasableIterator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * BlockHash implementation for {@code Categorize} grouping function as first
+ * grouping expression, followed by one or more other grouping expressions.
+ * <p>
+ * For the first grouping (the {@code Categorize} grouping function), a
+ * {@code CategorizeBlockHash} is used, which outputs integers (category IDs).
+ * Next, a {@code PackedValuesBlockHash} is used on the category IDs and the
+ * other groupings (which are not {@code Categorize}s).
+ */
+public class CategorizePackedValuesBlockHash extends BlockHash {
+
+    private final List<GroupSpec> specs;
+    private final AggregatorMode aggregatorMode;
+    private final Block[] blocks;
+    private final CategorizeBlockHash categorizeBlockHash;
+    private final PackedValuesBlockHash packedValuesBlockHash;
+
+    CategorizePackedValuesBlockHash(
+        List<GroupSpec> specs,
+        BlockFactory blockFactory,
+        AggregatorMode aggregatorMode,
+        AnalysisRegistry analysisRegistry,
+        int emitBatchSize
+    ) {
+        super(blockFactory);
+        this.specs = specs;
+        this.aggregatorMode = aggregatorMode;
+        blocks = new Block[specs.size()];
+
+        List<GroupSpec> delegateSpecs = new ArrayList<>();
+        delegateSpecs.add(new GroupSpec(0, ElementType.INT));
+        for (int i = 1; i < specs.size(); i++) {
+            delegateSpecs.add(new GroupSpec(i, specs.get(i).elementType()));
+        }
+
+        boolean success = false;
+        try {
+            categorizeBlockHash = new CategorizeBlockHash(blockFactory, specs.get(0).channel(), aggregatorMode, analysisRegistry);
+            packedValuesBlockHash = new PackedValuesBlockHash(delegateSpecs, blockFactory, emitBatchSize);
+            success = true;
+        } finally {
+            if (success == false) {
+                close();
+            }
+        }
+    }
+
+    @Override
+    public void add(Page page, GroupingAggregatorFunction.AddInput addInput) {
+        try (IntBlock categories = getCategories(page)) {
+            blocks[0] = categories;
+            for (int i = 1; i < specs.size(); i++) {
+                blocks[i] = page.getBlock(specs.get(i).channel());
+            }
+            packedValuesBlockHash.add(new Page(blocks), addInput);
+        }
+    }
+
+    private IntBlock getCategories(Page page) {
+        if (aggregatorMode.isInputPartial() == false) {
+            return categorizeBlockHash.addInitial(page);
+        } else {
+            BytesRefBlock stateBlock = page.getBlock(0);
+            BytesRef stateBytes = stateBlock.getBytesRef(0, new BytesRef());
+            try (StreamInput in = new BytesArray(stateBytes).streamInput()) {
+                BytesRef categorizerState = in.readBytesRef();
+                try (IntVector ids = IntVector.readFrom(blockFactory, in)) {
+                    return categorizeBlockHash.recategorize(categorizerState, ids).asBlock();
+                }
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    @Override
+    public Block[] getKeys() {
+        Block[] keys = packedValuesBlockHash.getKeys();
+        if (aggregatorMode.isOutputPartial() == false) {
+            // For final output, the keys are the category regexes.
+            try (
+                BytesRefBlock regexes = (BytesRefBlock) categorizeBlockHash.getKeys()[0];
+                BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(keys[0].getPositionCount())
+            ) {
+                IntVector idsVector = (IntVector) keys[0].asVector();
+                int idsOffset = categorizeBlockHash.seenNull() ? 0 : -1;
+                BytesRef scratch = new BytesRef();
+                for (int i = 0; i < idsVector.getPositionCount(); i++) {
+                    int id = idsVector.getInt(i);
+                    if (id == 0) {
+                        builder.appendNull();
+                    } else {
+                        builder.appendBytesRef(regexes.getBytesRef(id + idsOffset, scratch));
+                    }
+                }
+                keys[0].close();
+                keys[0] = builder.build();
+            }
+        } else {
+            // For intermediate output, the keys are the delegate PackedValuesBlockHash's
+            // keys, with the category IDs replaced by the categorizer's internal state
+            // together with the list of category IDs.
+            BytesRef state;
+            // TODO: This BytesStreamOutput is not accounted for by the circuit breaker. Fix that!
+            try (BytesStreamOutput out = new BytesStreamOutput()) {
+                out.writeBytesRef(categorizeBlockHash.serializeCategorizer());
+                ((IntVector) keys[0].asVector()).writeTo(out);
+                state = out.bytes().toBytesRef();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            int positionCount = keys[0].getPositionCount();
+            keys[0].close();
+            keys[0] = blockFactory.newConstantBytesRefBlockWith(state, positionCount);
+        }
+        return keys;
+    }
+
+    @Override
+    public IntVector nonEmpty() {
+        return packedValuesBlockHash.nonEmpty();
+    }
+
+    @Override
+    public BitArray seenGroupIds(BigArrays bigArrays) {
+        return packedValuesBlockHash.seenGroupIds(bigArrays);
+    }
+
+    @Override
+    public final ReleasableIterator<IntBlock> lookup(Page page, ByteSizeValue targetBlockSize) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void close() {
+        Releasables.close(categorizeBlockHash, packedValuesBlockHash);
+    }
+}

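The writer and reader of the intermediate key format sit in two different methods above (getKeys() and getCategories()), so here is the round trip consolidated in one place; a sketch using only calls visible in this file, where stateBytes stands for the constant key block's value:

    // getKeys(), intermediate output: categorizer state + per-position category IDs.
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        out.writeBytesRef(categorizeBlockHash.serializeCategorizer());
        ((IntVector) keys[0].asVector()).writeTo(out);
        BytesRef state = out.bytes().toBytesRef();  // becomes the constant BytesRef key block
    }

    // getCategories(), partial input: unpack and remap into local category IDs.
    try (StreamInput in = new BytesArray(stateBytes).streamInput()) {
        BytesRef categorizerState = in.readBytesRef();
        try (IntVector ids = IntVector.readFrom(blockFactory, in)) {
            IntBlock categories = categorizeBlockHash.recategorize(categorizerState, ids).asBlock();
        }
    }

The idsOffset in the final-output branch is the subtle part: category ID 0 is reserved for null, and the regex block returned by categorizeBlockHash.getKeys() only contains a leading null position when seenNull() is true, so IDs are shifted by 0 in that case and by -1 otherwise.
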
+ 7 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java

@@ -51,7 +51,13 @@ public class HashAggregationOperator implements Operator {
             if (groups.stream().anyMatch(BlockHash.GroupSpec::isCategorize)) {
                 return new HashAggregationOperator(
                     aggregators,
-                    () -> BlockHash.buildCategorizeBlockHash(groups, aggregatorMode, driverContext.blockFactory(), analysisRegistry),
+                    () -> BlockHash.buildCategorizeBlockHash(
+                        groups,
+                        aggregatorMode,
+                        driverContext.blockFactory(),
+                        analysisRegistry,
+                        maxPageSize
+                    ),
                     driverContext
                 );
             }

+ 0 - 3
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

@@ -130,9 +130,6 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
         } finally {
             page.releaseBlocks();
         }
-
-        // TODO: randomize values? May give wrong results
-        // TODO: assert the categorizer state after adding pages.
     }
 
     public void testCategorizeRawMultivalue() {

+ 248 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java

@@ -0,0 +1,248 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation.blockhash;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.analysis.common.CommonAnalysisPlugin;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.MockBigArrays;
+import org.elasticsearch.common.util.PageCacheRecycler;
+import org.elasticsearch.compute.aggregation.AggregatorMode;
+import org.elasticsearch.compute.aggregation.ValuesBytesRefAggregatorFunctionSupplier;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.CannedSourceOperator;
+import org.elasticsearch.compute.operator.Driver;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.HashAggregationOperator;
+import org.elasticsearch.compute.operator.LocalSourceOperator;
+import org.elasticsearch.compute.operator.PageConsumerOperator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.env.Environment;
+import org.elasticsearch.env.TestEnvironment;
+import org.elasticsearch.index.analysis.AnalysisRegistry;
+import org.elasticsearch.indices.analysis.AnalysisModule;
+import org.elasticsearch.plugins.scanners.StablePluginsRegistry;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.compute.operator.OperatorTestCase.runDriver;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+
+public class CategorizePackedValuesBlockHashTests extends BlockHashTestCase {
+
+    private AnalysisRegistry analysisRegistry;
+
+    @Before
+    private void initAnalysisRegistry() throws IOException {
+        analysisRegistry = new AnalysisModule(
+            TestEnvironment.newEnvironment(
+                Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build()
+            ),
+            List.of(new MachineLearning(Settings.EMPTY), new CommonAnalysisPlugin()),
+            new StablePluginsRegistry()
+        ).getAnalysisRegistry();
+    }
+
+    public void testCategorize_withDriver() {
+        BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking();
+        CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST);
+        DriverContext driverContext = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays));
+        boolean withNull = randomBoolean();
+        boolean withMultivalues = randomBoolean();
+
+        List<BlockHash.GroupSpec> groupSpecs = List.of(
+            new BlockHash.GroupSpec(0, ElementType.BYTES_REF, true),
+            new BlockHash.GroupSpec(1, ElementType.INT, false)
+        );
+
+        LocalSourceOperator.BlockSupplier input1 = () -> {
+            try (
+                BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10);
+                IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10)
+            ) {
+                if (withMultivalues) {
+                    messagesBuilder.beginPositionEntry();
+                }
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.1"));
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.2"));
+                if (withMultivalues) {
+                    messagesBuilder.endPositionEntry();
+                }
+                idsBuilder.appendInt(7);
+                if (withMultivalues == false) {
+                    idsBuilder.appendInt(7);
+                }
+
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.3"));
+                messagesBuilder.appendBytesRef(new BytesRef("connection error"));
+                messagesBuilder.appendBytesRef(new BytesRef("connection error"));
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 1.1.4"));
+                idsBuilder.appendInt(42);
+                idsBuilder.appendInt(7);
+                idsBuilder.appendInt(42);
+                idsBuilder.appendInt(7);
+
+                if (withNull) {
+                    messagesBuilder.appendNull();
+                    idsBuilder.appendInt(43);
+                }
+                return new Block[] { messagesBuilder.build(), idsBuilder.build() };
+            }
+        };
+        LocalSourceOperator.BlockSupplier input2 = () -> {
+            try (
+                BytesRefBlock.Builder messagesBuilder = driverContext.blockFactory().newBytesRefBlockBuilder(10);
+                IntBlock.Builder idsBuilder = driverContext.blockFactory().newIntBlockBuilder(10)
+            ) {
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.1"));
+                messagesBuilder.appendBytesRef(new BytesRef("connected to 2.1.2"));
+                messagesBuilder.appendBytesRef(new BytesRef("disconnected"));
+                messagesBuilder.appendBytesRef(new BytesRef("connection error"));
+                idsBuilder.appendInt(111);
+                idsBuilder.appendInt(7);
+                idsBuilder.appendInt(7);
+                idsBuilder.appendInt(42);
+                if (withNull) {
+                    messagesBuilder.appendNull();
+                    idsBuilder.appendNull();
+                }
+                return new Block[] { messagesBuilder.build(), idsBuilder.build() };
+            }
+        };
+
+        List<Page> intermediateOutput = new ArrayList<>();
+
+        Driver driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input1),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    groupSpecs,
+                    AggregatorMode.INITIAL,
+                    List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
+                    16 * 1024,
+                    analysisRegistry
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        driver = new Driver(
+            driverContext,
+            new LocalSourceOperator(input2),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    groupSpecs,
+                    AggregatorMode.INITIAL,
+                    List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(0)).groupingAggregatorFactory(AggregatorMode.INITIAL)),
+                    16 * 1024,
+                    analysisRegistry
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(intermediateOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        List<Page> finalOutput = new ArrayList<>();
+
+        driver = new Driver(
+            driverContext,
+            new CannedSourceOperator(intermediateOutput.iterator()),
+            List.of(
+                new HashAggregationOperator.HashAggregationOperatorFactory(
+                    groupSpecs,
+                    AggregatorMode.FINAL,
+                    List.of(new ValuesBytesRefAggregatorFunctionSupplier(List.of(2)).groupingAggregatorFactory(AggregatorMode.FINAL)),
+                    16 * 1024,
+                    analysisRegistry
+                ).get(driverContext)
+            ),
+            new PageConsumerOperator(finalOutput::add),
+            () -> {}
+        );
+        runDriver(driver);
+
+        assertThat(finalOutput, hasSize(1));
+        assertThat(finalOutput.get(0).getBlockCount(), equalTo(3));
+        BytesRefBlock outputMessages = finalOutput.get(0).getBlock(0);
+        IntBlock outputIds = finalOutput.get(0).getBlock(1);
+        BytesRefBlock outputValues = finalOutput.get(0).getBlock(2);
+        assertThat(outputIds.getPositionCount(), equalTo(outputMessages.getPositionCount()));
+        assertThat(outputValues.getPositionCount(), equalTo(outputMessages.getPositionCount()));
+        Map<String, Map<Integer, Set<String>>> result = new HashMap<>();
+        for (int i = 0; i < outputMessages.getPositionCount(); i++) {
+            BytesRef messageBytesRef = ((BytesRef) BlockUtils.toJavaObject(outputMessages, i));
+            String message = messageBytesRef == null ? null : messageBytesRef.utf8ToString();
+            result.computeIfAbsent(message, key -> new HashMap<>());
+
+            Integer id = (Integer) BlockUtils.toJavaObject(outputIds, i);
+            result.get(message).computeIfAbsent(id, key -> new HashSet<>());
+
+            Object values = BlockUtils.toJavaObject(outputValues, i);
+            if (values == null) {
+                result.get(message).get(id).add(null);
+            } else {
+                if ((values instanceof List) == false) {
+                    values = List.of(values);
+                }
+                for (Object valueObject : (List<?>) values) {
+                    BytesRef value = (BytesRef) valueObject;
+                    result.get(message).get(id).add(value.utf8ToString());
+                }
+            }
+        }
+        Releasables.close(() -> Iterators.map(finalOutput.iterator(), (Page p) -> p::releaseBlocks));
+
+        Map<String, Map<Integer, Set<String>>> expectedResult = Map.of(
+            ".*?connected.+?to.*?",
+            Map.of(
+                7,
+                Set.of("connected to 1.1.1", "connected to 1.1.2", "connected to 1.1.4", "connected to 2.1.2"),
+                42,
+                Set.of("connected to 1.1.3"),
+                111,
+                Set.of("connected to 2.1.1")
+            ),
+            ".*?connection.+?error.*?",
+            Map.of(7, Set.of("connection error"), 42, Set.of("connection error")),
+            ".*?disconnected.*?",
+            Map.of(7, Set.of("disconnected"))
+        );
+        if (withNull) {
+            expectedResult = new HashMap<>(expectedResult);
+            expectedResult.put(null, new HashMap<>());
+            expectedResult.get(null).put(null, new HashSet<>());
+            expectedResult.get(null).get(null).add(null);
+            expectedResult.get(null).put(43, new HashSet<>());
+            expectedResult.get(null).get(43).add(null);
+        }
+        assertThat(result, equalTo(expectedResult));
+    }
+}

+ 169 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

@@ -60,6 +60,19 @@ COUNT():long | VALUES(str):keyword | category:keyword
            1 | [a, b, c]           | .*?disconnected.*?
 ;
 
+limit before stats
+required_capability: categorize_v5
+
+FROM sample_data | SORT message | LIMIT 4
+  | STATS count=COUNT() BY category=CATEGORIZE(message)
+  | SORT category
+;
+
+count:long | category:keyword
+         3 | .*?Connected.+?to.*?
+         1 | .*?Connection.+?error.*?
+;
+
 skips stopwords
 required_capability: categorize_v5
 
@@ -615,3 +628,159 @@ COUNT():long | x:keyword
            3 | [.*?Connection.+?error.*?,.*?Connection.+?error.*?]
            1 | [.*?Disconnected.*?,.*?Disconnected.*?]
 ;
+
+multiple groupings with categorize and ip
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | STATS count=COUNT() BY category=CATEGORIZE(message), client_ip
+  | SORT category, client_ip
+;
+
+count:long | category:keyword         | client_ip:ip
+         1 | .*?Connected.+?to.*?     | 172.21.2.113
+         1 | .*?Connected.+?to.*?     | 172.21.2.162
+         1 | .*?Connected.+?to.*?     | 172.21.3.15
+         3 | .*?Connection.+?error.*? | 172.21.3.15
+         1 | .*?Disconnected.*?       | 172.21.0.5
+;
+
+multiple groupings with categorize and bucketed timestamp
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | STATS count=COUNT() BY category=CATEGORIZE(message), timestamp=BUCKET(@timestamp, 1 HOUR)
+  | SORT category, timestamp
+;
+
+count:long | category:keyword         | timestamp:datetime
+         2 | .*?Connected.+?to.*?     | 2023-10-23T12:00:00.000Z
+         1 | .*?Connected.+?to.*?     | 2023-10-23T13:00:00.000Z
+         3 | .*?Connection.+?error.*? | 2023-10-23T13:00:00.000Z
+         1 | .*?Disconnected.*?       | 2023-10-23T13:00:00.000Z
+;
+
+multiple groupings with categorize and limit before stats
+required_capability: categorize_multiple_groupings
+
+FROM sample_data | SORT message | LIMIT 5
+  | STATS count=COUNT() BY category=CATEGORIZE(message), client_ip
+  | SORT category, client_ip
+;
+
+count:long | category:keyword         | client_ip:ip
+         1 | .*?Connected.+?to.*?     | 172.21.2.113
+         1 | .*?Connected.+?to.*?     | 172.21.2.162
+         1 | .*?Connected.+?to.*?     | 172.21.3.15
+         2 | .*?Connection.+?error.*? | 172.21.3.15
+;
+
+multiple groupings with categorize and nulls
+required_capability: categorize_multiple_groupings
+
+FROM employees
+  | STATS SUM(languages) BY category=CATEGORIZE(job_positions), gender
+  | SORT category DESC, gender ASC
+  | LIMIT 5
+;
+
+SUM(languages):long | category:keyword  | gender:keyword
+                 11 | null              | F
+                 16 | null              | M
+                 14 | .*?Tech.+?Lead.*? | F
+                 23 | .*?Tech.+?Lead.*? | M
+                  9 | .*?Tech.+?Lead.*? | null
+;
+
+multiple groupings with categorize and a field that's always null
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | EVAL nullfield = null
+  | STATS count=COUNT() BY category=CATEGORIZE(nullfield), client_ip
+  | SORT client_ip
+;
+
+count:long | category:keyword | client_ip:ip
+         1 | null             | 172.21.0.5
+         1 | null             | 172.21.2.113
+         1 | null             | 172.21.2.162
+         4 | null             | 172.21.3.15
+;
+
+multiple groupings with categorize and the same text field
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | STATS count=COUNT() BY category=CATEGORIZE(message), message
+  | SORT message
+;
+
+count:long | category:keyword         | message:keyword
+         1 | .*?Connected.+?to.*?     | Connected to 10.1.0.1
+         1 | .*?Connected.+?to.*?     | Connected to 10.1.0.2
+         1 | .*?Connected.+?to.*?     | Connected to 10.1.0.3
+         3 | .*?Connection.+?error.*? | Connection error
+         1 | .*?Disconnected.*?       | Disconnected
+;
+
+multiple additional complex groupings with categorize
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | STATS count=COUNT(), duration=SUM(event_duration) BY category=CATEGORIZE(message), SUBSTRING(message, 1, 7), ip_part=TO_LONG(SUBSTRING(TO_STRING(client_ip), 8, 1)), hour=BUCKET(@timestamp, 1 HOUR)
+  | SORT ip_part, category
+;
+
+count:long | duration:long | category:keyword         | SUBSTRING(message, 1, 7):keyword | ip_part:long | hour:datetime
+         1 | 1232382       | .*?Disconnected.*?       | Disconn                          | 0            | 2023-10-23T13:00:00.000Z
+         2 | 6215122       | .*?Connected.+?to.*?     | Connect                          | 2            | 2023-10-23T12:00:00.000Z
+         1 | 1756467       | .*?Connected.+?to.*?     | Connect                          | 3            | 2023-10-23T13:00:00.000Z
+         3 | 14027356      | .*?Connection.+?error.*? | Connect                          | 3            | 2023-10-23T13:00:00.000Z
+;
+
+multiple groupings with categorize and some constants including null
+required_capability: categorize_multiple_groupings
+
+FROM sample_data
+  | STATS count=MV_COUNT(VALUES(message)) BY category=CATEGORIZE(message), null, constant="constant"
+  | SORT category
+;
+
+count:integer | category:keyword         | null:null | constant:keyword
+            3 | .*?Connected.+?to.*?     | null      | constant
+            1 | .*?Connection.+?error.*? | null      | constant
+            1 | .*?Disconnected.*?       | null      | constant
+;
+
+multiple groupings with categorize and aggregation filters
+required_capability: categorize_multiple_groupings
+
+FROM employees
+  | STATS lang_low=AVG(languages) WHERE salary<=50000, lang_high=AVG(languages) WHERE salary>50000 BY category=CATEGORIZE(job_positions), gender
+  | SORT category, gender
+  | LIMIT 5
+;
+
+lang_low:double | lang_high:double | category:keyword  | gender:keyword
+            2.0 |              5.0 | .*?Accountant.*?  | F
+            3.0 |              2.5 | .*?Accountant.*?  | M
+            5.0 |              2.0 | .*?Accountant.*?  | null
+            3.0 |             3.25 | .*?Architect.*?   | F
+           3.75 |             null | .*?Architect.*?   | M
+;
+
+multiple groupings with categorize on null row
+required_capability: categorize_multiple_groupings
+
+ROW message = null, str = ["a", "b", "c"]
+  | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message), str
+  | SORT str
+;
+
+COUNT():long | VALUES(str):keyword | category:keyword | str:keyword
+           1 | [a, b, c]           | null             | a
+           1 | [a, b, c]           | null             | b
+           1 | [a, b, c]           | null             | c
+;

+ 4 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -407,6 +407,10 @@ public class EsqlCapabilities {
          */
         CATEGORIZE_V5,
 
+        /**
+         * Support for multiple groupings in "CATEGORIZE".
+         */
+        CATEGORIZE_MULTIPLE_GROUPINGS,
         /**
          * QSTR function
          */

+ 6 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java

@@ -325,11 +325,15 @@ public class Verifier {
     private static void checkCategorizeGrouping(Aggregate agg, Set<Failure> failures) {
         // Forbid CATEGORIZE grouping function with other groupings
         if (agg.groupings().size() > 1) {
-            agg.groupings().forEach(g -> {
+            agg.groupings().subList(1, agg.groupings().size()).forEach(g -> {
                 g.forEachDown(
                     Categorize.class,
                     categorize -> failures.add(
-                        fail(categorize, "cannot use CATEGORIZE grouping function [{}] with multiple groupings", categorize.sourceText())
+                        fail(
+                            categorize,
+                            "CATEGORIZE grouping function [{}] can only be in the first grouping expression",
+                            categorize.sourceText()
+                        )
                     )
                 );
             });

+ 2 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Categorize.java

@@ -95,7 +95,8 @@ public class Categorize extends GroupingFunction implements Validatable {
 
     @Override
     public Nullability nullable() {
-        // Both nulls and empty strings result in null values
+        // Null strings and strings that don't produce tokens after analysis lead to null values.
+        // This includes empty strings, whitespace-only strings, (hexa)decimal numbers, and stopwords.
         return Nullability.TRUE;
     }
 

+ 15 - 22
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1894,38 +1894,35 @@ public class VerifierTests extends ESTestCase {
         );
     }
 
-    public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
-        query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
-        query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
+    public void testCategorizeOnlyFirstGrouping() {
+        query("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name)");
+        query("FROM test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
+        query("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no");
+        query("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no");
 
         assertEquals(
-            "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
-            error("from test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no")
+            "1:39: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression",
+            error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)")
         );
         assertEquals(
-            "1:39: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
-            error("FROM test | STATS COUNT(*) BY emp_no, CATEGORIZE(first_name)")
+            "1:55: CATEGORIZE grouping function [CATEGORIZE(last_name)] can only be in the first grouping expression",
+            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)")
         );
         assertEquals(
-            "1:35: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
-            error("FROM test | STATS COUNT(*) BY a = CATEGORIZE(first_name), b = emp_no")
+            "1:55: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression",
+            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)")
         );
         assertEquals(
-            "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings\n"
-                + "line 1:55: cannot use CATEGORIZE grouping function [CATEGORIZE(last_name)] with multiple groupings",
-            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(last_name)")
+            "1:63: CATEGORIZE grouping function [CATEGORIZE(last_name)] can only be in the first grouping expression",
+            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no, CATEGORIZE(last_name)")
         );
         assertEquals(
-            "1:31: cannot use CATEGORIZE grouping function [CATEGORIZE(first_name)] with multiple groupings",
-            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), CATEGORIZE(first_name)")
+            "1:63: CATEGORIZE grouping function [CATEGORIZE(first_name)] can only be in the first grouping expression",
+            error("FROM test | STATS COUNT(*) BY CATEGORIZE(first_name), emp_no, CATEGORIZE(first_name)")
         );
     }
 
     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
 
         assertEquals(
@@ -1939,8 +1936,6 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
         query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY cat = CATEGORIZE(first_name)");
         query("from test | STATS MV_COUNT(CATEGORIZE(first_name)), COUNT(*) BY CATEGORIZE(first_name)");
@@ -1969,8 +1964,6 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeWithFilteredAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
         query("FROM test | STATS COUNT(*) WHERE first_name == \"John\" BY CATEGORIZE(last_name)");
         query("FROM test | STATS COUNT(*) WHERE last_name == \"Doe\" BY CATEGORIZE(last_name)");
 

+ 0 - 5
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -20,7 +20,6 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.TestBlockFactory;
 import org.elasticsearch.xpack.esql.VerificationException;
-import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
 import org.elasticsearch.xpack.esql.analysis.Analyzer;
 import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
 import org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils;
@@ -1212,8 +1211,6 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *   \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
      */
     public void testCombineProjectionWithCategorizeGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
         var plan = plan("""
             from test
             | eval k = first_name, k1 = k
@@ -3949,8 +3946,6 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
      */
     public void testNestedExpressionsInGroupsWithCategorize() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V5.isEnabled());
-
         var plan = optimizedPlan("""
             from test
             | stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))