Browse Source

ESQL: Add nulls support to Categorize (#117655) (#117716)

Handle nulls and empty strings (which resolve to null) in the Categorize grouping function.

Also, implement `seenGroupIds()`, which would fail some queries with nulls otherwise.
Iván Cea Fontenla 10 months ago
parent
commit
74cf2c63aa

+ 5 - 0
docs/changelog/117655.yaml

@@ -0,0 +1,5 @@
+pr: 117655
+summary: Add nulls support to Categorize
+area: ES|QL
+type: enhancement
+issues: []

+ 32 - 5
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/AbstractCategorizeBlockHash.java

@@ -13,8 +13,10 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
 import org.elasticsearch.common.util.BytesRefHash;
+import org.elasticsearch.compute.aggregation.SeenGroupIds;
 import org.elasticsearch.compute.data.Block;
 import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
@@ -31,11 +33,21 @@ import java.io.IOException;
  * Base BlockHash implementation for {@code Categorize} grouping function.
  */
 public abstract class AbstractCategorizeBlockHash extends BlockHash {
+    protected static final int NULL_ORD = 0;
+
     // TODO: this should probably also take an emitBatchSize
     private final int channel;
     private final boolean outputPartial;
     protected final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
 
+    /**
+     * Store whether we've seen any {@code null} values.
+     * <p>
+     *     Null gets the {@link #NULL_ORD} ord.
+     * </p>
+     */
+    protected boolean seenNull = false;
+
     AbstractCategorizeBlockHash(BlockFactory blockFactory, int channel, boolean outputPartial) {
         super(blockFactory);
         this.channel = channel;
@@ -58,12 +70,12 @@ public abstract class AbstractCategorizeBlockHash extends BlockHash {
 
     @Override
     public IntVector nonEmpty() {
-        return IntVector.range(0, categorizer.getCategoryCount(), blockFactory);
+        return IntVector.range(seenNull ? 0 : 1, categorizer.getCategoryCount() + 1, blockFactory);
     }
 
     @Override
     public BitArray seenGroupIds(BigArrays bigArrays) {
-        throw new UnsupportedOperationException();
+        return new SeenGroupIds.Range(seenNull ? 0 : 1, Math.toIntExact(categorizer.getCategoryCount() + 1)).seenGroupIds(bigArrays);
     }
 
     @Override
@@ -76,24 +88,39 @@ public abstract class AbstractCategorizeBlockHash extends BlockHash {
      */
     private Block buildIntermediateBlock() {
         if (categorizer.getCategoryCount() == 0) {
-            return blockFactory.newConstantNullBlock(0);
+            return blockFactory.newConstantNullBlock(seenNull ? 1 : 0);
         }
         try (BytesStreamOutput out = new BytesStreamOutput()) {
             // TODO be more careful here.
+            out.writeBoolean(seenNull);
             out.writeVInt(categorizer.getCategoryCount());
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 category.writeTo(out);
             }
             // We're returning a block with N positions just because the Page must have all blocks with the same position count!
-            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), categorizer.getCategoryCount());
+            int positionCount = categorizer.getCategoryCount() + (seenNull ? 1 : 0);
+            return blockFactory.newConstantBytesRefBlockWith(out.bytes().toBytesRef(), positionCount);
         } catch (IOException e) {
             throw new RuntimeException(e);
         }
     }
 
     private Block buildFinalBlock() {
+        BytesRefBuilder scratch = new BytesRefBuilder();
+
+        if (seenNull) {
+            try (BytesRefBlock.Builder result = blockFactory.newBytesRefBlockBuilder(categorizer.getCategoryCount())) {
+                result.appendNull();
+                for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
+                    scratch.copyChars(category.getRegex());
+                    result.appendBytesRef(scratch.get());
+                    scratch.clear();
+                }
+                return result.build();
+            }
+        }
+
         try (BytesRefVector.Builder result = blockFactory.newBytesRefVectorBuilder(categorizer.getCategoryCount())) {
-            BytesRefBuilder scratch = new BytesRefBuilder();
             for (SerializableTokenListCategory category : categorizer.toCategoriesById()) {
                 scratch.copyChars(category.getRegex());
                 result.appendBytesRef(scratch.get());

+ 9 - 3
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeRawBlockHash.java

@@ -64,7 +64,7 @@ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
     /**
      * Similar implementation to an Evaluator.
      */
-    public static final class CategorizeEvaluator implements Releasable {
+    public final class CategorizeEvaluator implements Releasable {
         private final CategorizationAnalyzer analyzer;
 
         private final TokenListCategorizer.CloseableTokenListCategorizer categorizer;
@@ -95,7 +95,8 @@ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
                 BytesRef vScratch = new BytesRef();
                 for (int p = 0; p < positionCount; p++) {
                     if (vBlock.isNull(p)) {
-                        result.appendNull();
+                        seenNull = true;
+                        result.appendInt(NULL_ORD);
                         continue;
                     }
                     int first = vBlock.getFirstValueIndex(p);
@@ -126,7 +127,12 @@ public class CategorizeRawBlockHash extends AbstractCategorizeBlockHash {
         }
 
         private int process(BytesRef v) {
-            return categorizer.computeCategory(v.utf8ToString(), analyzer).getId();
+            var category = categorizer.computeCategory(v.utf8ToString(), analyzer);
+            if (category == null) {
+                seenNull = true;
+                return NULL_ORD;
+            }
+            return category.getId() + 1;
         }
 
         @Override

+ 17 - 2
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/CategorizedIntermediateBlockHash.java

@@ -40,9 +40,19 @@ public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHas
             return;
         }
         BytesRefBlock categorizerState = page.getBlock(channel());
+        if (categorizerState.areAllValuesNull()) {
+            seenNull = true;
+            try (var newIds = blockFactory.newConstantIntVector(NULL_ORD, 1)) {
+                addInput.add(0, newIds);
+            }
+            return;
+        }
+
         Map<Integer, Integer> idMap = readIntermediate(categorizerState.getBytesRef(0, new BytesRef()));
         try (IntBlock.Builder newIdsBuilder = blockFactory.newIntBlockBuilder(idMap.size())) {
-            for (int i = 0; i < idMap.size(); i++) {
+            int fromId = idMap.containsKey(0) ? 0 : 1;
+            int toId = fromId + idMap.size();
+            for (int i = fromId; i < toId; i++) {
                 newIdsBuilder.appendInt(idMap.get(i));
             }
             try (IntBlock newIds = newIdsBuilder.build()) {
@@ -59,10 +69,15 @@ public class CategorizedIntermediateBlockHash extends AbstractCategorizeBlockHas
     private Map<Integer, Integer> readIntermediate(BytesRef bytes) {
         Map<Integer, Integer> idMap = new HashMap<>();
         try (StreamInput in = new BytesArray(bytes).streamInput()) {
+            if (in.readBoolean()) {
+                seenNull = true;
+                idMap.put(NULL_ORD, NULL_ORD);
+            }
             int count = in.readVInt();
             for (int oldCategoryId = 0; oldCategoryId < count; oldCategoryId++) {
                 int newCategoryId = categorizer.mergeWireCategory(new SerializableTokenListCategory(in)).getId();
-                idMap.put(oldCategoryId, newCategoryId);
+                // +1 because the 0 ordinal is reserved for null
+                idMap.put(oldCategoryId + 1, newCategoryId + 1);
             }
             return idMap;
         } catch (IOException e) {

+ 49 - 23
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java

@@ -52,7 +52,8 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
 
     public void testCategorizeRaw() {
         final Page page;
-        final int positions = 7;
+        boolean withNull = randomBoolean();
+        final int positions = 7 + (withNull ? 1 : 0);
         try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions)) {
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
             builder.appendBytesRef(new BytesRef("Connection error"));
@@ -61,6 +62,13 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
             builder.appendBytesRef(new BytesRef("Disconnected"));
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.2"));
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
+            if (withNull) {
+                if (randomBoolean()) {
+                    builder.appendNull();
+                } else {
+                    builder.appendBytesRef(new BytesRef(""));
+                }
+            }
             page = new Page(builder.build());
         }
 
@@ -70,13 +78,16 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
                 public void add(int positionOffset, IntBlock groupIds) {
                     assertEquals(groupIds.getPositionCount(), positions);
 
-                    assertEquals(0, groupIds.getInt(0));
-                    assertEquals(1, groupIds.getInt(1));
-                    assertEquals(1, groupIds.getInt(2));
-                    assertEquals(1, groupIds.getInt(3));
-                    assertEquals(2, groupIds.getInt(4));
-                    assertEquals(0, groupIds.getInt(5));
-                    assertEquals(0, groupIds.getInt(6));
+                    assertEquals(1, groupIds.getInt(0));
+                    assertEquals(2, groupIds.getInt(1));
+                    assertEquals(2, groupIds.getInt(2));
+                    assertEquals(2, groupIds.getInt(3));
+                    assertEquals(3, groupIds.getInt(4));
+                    assertEquals(1, groupIds.getInt(5));
+                    assertEquals(1, groupIds.getInt(6));
+                    if (withNull) {
+                        assertEquals(0, groupIds.getInt(7));
+                    }
                 }
 
                 @Override
@@ -100,7 +111,8 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
 
     public void testCategorizeIntermediate() {
         Page page1;
-        int positions1 = 7;
+        boolean withNull = randomBoolean();
+        int positions1 = 7 + (withNull ? 1 : 0);
         try (BytesRefBlock.Builder builder = blockFactory.newBytesRefBlockBuilder(positions1)) {
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.1"));
             builder.appendBytesRef(new BytesRef("Connection error"));
@@ -109,6 +121,13 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
             builder.appendBytesRef(new BytesRef("Connection error"));
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.3"));
             builder.appendBytesRef(new BytesRef("Connected to 10.1.0.4"));
+            if (withNull) {
+                if (randomBoolean()) {
+                    builder.appendNull();
+                } else {
+                    builder.appendBytesRef(new BytesRef(""));
+                }
+            }
             page1 = new Page(builder.build());
         }
         Page page2;
@@ -133,13 +152,16 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
                     assertEquals(groupIds.getPositionCount(), positions1);
-                    assertEquals(0, groupIds.getInt(0));
-                    assertEquals(1, groupIds.getInt(1));
-                    assertEquals(1, groupIds.getInt(2));
-                    assertEquals(0, groupIds.getInt(3));
-                    assertEquals(1, groupIds.getInt(4));
-                    assertEquals(0, groupIds.getInt(5));
-                    assertEquals(0, groupIds.getInt(6));
+                    assertEquals(1, groupIds.getInt(0));
+                    assertEquals(2, groupIds.getInt(1));
+                    assertEquals(2, groupIds.getInt(2));
+                    assertEquals(1, groupIds.getInt(3));
+                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(1, groupIds.getInt(5));
+                    assertEquals(1, groupIds.getInt(6));
+                    if (withNull) {
+                        assertEquals(0, groupIds.getInt(7));
+                    }
                 }
 
                 @Override
@@ -158,11 +180,11 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
                 @Override
                 public void add(int positionOffset, IntBlock groupIds) {
                     assertEquals(groupIds.getPositionCount(), positions2);
-                    assertEquals(0, groupIds.getInt(0));
-                    assertEquals(1, groupIds.getInt(1));
-                    assertEquals(0, groupIds.getInt(2));
-                    assertEquals(1, groupIds.getInt(3));
-                    assertEquals(2, groupIds.getInt(4));
+                    assertEquals(1, groupIds.getInt(0));
+                    assertEquals(2, groupIds.getInt(1));
+                    assertEquals(1, groupIds.getInt(2));
+                    assertEquals(2, groupIds.getInt(3));
+                    assertEquals(3, groupIds.getInt(4));
                 }
 
                 @Override
@@ -189,7 +211,11 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
                         .map(groupIds::getInt)
                         .boxed()
                         .collect(Collectors.toSet());
-                    assertEquals(values, Set.of(0, 1));
+                    if (withNull) {
+                        assertEquals(Set.of(0, 1, 2), values);
+                    } else {
+                        assertEquals(Set.of(1, 2), values);
+                    }
                 }
 
                 @Override
@@ -212,7 +238,7 @@ public class CategorizeBlockHashTests extends BlockHashTestCase {
                         .collect(Collectors.toSet());
                    // The category IDs {0, 1, 2} should map to groups {1, 3, 4}, because
                    // 0 matches an existing category (Connected to ...), the others are new,
                    // and every group id is shifted by 1 since ord 0 is reserved for null.
-                    assertEquals(values, Set.of(0, 2, 3));
+                    assertEquals(Set.of(1, 3, 4), values);
                 }
 
                 @Override

+ 66 - 56
x-pack/plugin/esql/qa/testFixtures/src/main/resources/categorize.csv-spec

@@ -1,5 +1,5 @@
 standard aggs
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS count=COUNT(),
@@ -17,7 +17,7 @@ count:long | sum:long |     avg:double     | count_distinct:long | category:keyw
 ;
 
 values aggs
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS values=MV_SORT(VALUES(message)),
@@ -33,7 +33,7 @@ values:keyword                                                        |      top
 ;
 
 mv
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM mv_sample_data
   | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(message)
@@ -48,7 +48,7 @@ COUNT():long | SUM(event_duration):long | category:keyword
 ;
 
 row mv
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = ["connected to a", "connected to b", "disconnected"], str = ["a", "b", "c"]
   | STATS COUNT(), VALUES(str) BY category=CATEGORIZE(message)
@@ -61,7 +61,7 @@ COUNT():long | VALUES(str):keyword | category:keyword
 ;
 
 with multiple indices
-required_capability: categorize_v2
+required_capability: categorize_v3
 required_capability: union_types
 
 FROM sample_data*
@@ -76,7 +76,7 @@ COUNT():long | category:keyword
 ;
 
 mv with many values
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM employees
   | STATS COUNT() BY category=CATEGORIZE(job_positions)
@@ -92,24 +92,37 @@ COUNT():long | category:keyword
            10 | .*?Head.+?Human.+?Resources.*?
 ;
 
-# Throws when calling AbstractCategorizeBlockHash.seenGroupIds() - Requires nulls support?
-mv with many values-Ignore
-required_capability: categorize_v2
+mv with many values and SUM
+required_capability: categorize_v3
 
 FROM employees
   | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
-  | SORT category DESC
+  | SORT category
   | LIMIT 3
 ;
 
-SUM(languages):integer | category:keyword
-                    43 | .*?Accountant.*?
-                    46 | .*?Architect.*?
-                    35 | .*?Business.+?Analyst.*?
+SUM(languages):long | category:keyword
+                 43 | .*?Accountant.*?
+                 46 | .*?Architect.*?
+                 35 | .*?Business.+?Analyst.*?
+;
+
+mv with many values and nulls and SUM
+required_capability: categorize_v3
+
+FROM employees
+  | STATS SUM(languages) BY category=CATEGORIZE(job_positions)
+  | SORT category DESC
+  | LIMIT 2
+;
+
+SUM(languages):long | category:keyword
+                 27 | null
+                 46 | .*?Tech.+?Lead.*?
 ;
 
 mv via eval
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL message = MV_APPEND(message, "Banana")
@@ -125,7 +138,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL message = ["Banana", "Bread"]
@@ -139,7 +152,7 @@ COUNT():long | category:keyword
 ;
 
 mv via eval const without aliases
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL message = ["Banana", "Bread"]
@@ -153,7 +166,7 @@ COUNT():long | CATEGORIZE(message):keyword
 ;
 
 mv const in parameter
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -166,7 +179,7 @@ COUNT():long | c:keyword
 ;
 
 agg alias shadowing
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS c = COUNT() BY c = CATEGORIZE(["Banana", "Bread"])
@@ -181,7 +194,7 @@ c:keyword
 ;
 
 chained aggregations using categorize
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -196,7 +209,7 @@ COUNT():long | category:keyword
 ;
 
 stats without aggs
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS BY category=CATEGORIZE(message)
@@ -210,7 +223,7 @@ category:keyword
 ;
 
 text field
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM hosts
   | STATS COUNT() BY category=CATEGORIZE(host_group)
@@ -221,10 +234,11 @@ COUNT():long | category:keyword
            2 | .*?DB.+?servers.*?
            2 | .*?Gateway.+?instances.*?
            5 | .*?Kubernetes.+?cluster.*?
+           1 | null
 ;
 
 on TO_UPPER
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(TO_UPPER(message))
@@ -238,7 +252,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " banana"))
@@ -252,7 +266,7 @@ COUNT():long | category:keyword
 ;
 
 on CONCAT with unicode
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(CONCAT(message, " 👍🏽😊"))
@@ -266,7 +280,7 @@ COUNT():long | category:keyword
 ;
 
 on REVERSE(CONCAT())
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(REVERSE(CONCAT(message, " 👍🏽😊")))
@@ -280,7 +294,7 @@ COUNT():long | category:keyword
 ;
 
 and then TO_LOWER
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -294,9 +308,8 @@ COUNT():long | category:keyword
            1 | .*?disconnected.*?
 ;
 
-# Throws NPE - Requires nulls support
-on const empty string-Ignore
-required_capability: categorize_v2
+on const empty string
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE("")
@@ -304,12 +317,11 @@ FROM sample_data
 ;
 
 COUNT():long | category:keyword
-           7 | .*?.*?
+           7 | null
 ;
 
-# Throws NPE - Requires nulls support
-on const empty string from eval-Ignore
-required_capability: categorize_v2
+on const empty string from eval
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL x = ""
@@ -318,26 +330,24 @@ FROM sample_data
 ;
 
 COUNT():long | category:keyword
-           7 | .*?.*?
+           7 | null
 ;
 
-# Doesn't give the correct results - Requires nulls support
-on null-Ignore
-required_capability: categorize_v2
+on null
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL x = null
-  | STATS COUNT() BY category=CATEGORIZE(x)
+  | STATS COUNT(), SUM(event_duration) BY category=CATEGORIZE(x)
   | SORT category
 ;
 
-COUNT():long | category:keyword
-           7 | null
+COUNT():long | SUM(event_duration):long | category:keyword
+           7 |                 23231327 |  null
 ;
 
-# Doesn't give the correct results - Requires nulls support
-on null string-Ignore
-required_capability: categorize_v2
+on null string
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL x = null::string
@@ -350,7 +360,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | WHERE @timestamp < "2023-10-23T00:00:00Z"
@@ -362,7 +372,7 @@ COUNT():long | category:keyword
 ;
 
 filtering out all data with constant
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -373,7 +383,7 @@ COUNT():long | category:keyword
 ;
 
 drop output columns
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS count=COUNT() BY category=CATEGORIZE(message)
@@ -388,7 +398,7 @@ x:integer
 ;
 
 category value processing
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = ["connected to a", "connected to b", "disconnected"]
   | STATS COUNT() BY category=CATEGORIZE(message)
@@ -402,7 +412,7 @@ COUNT():long | category:keyword
 ;
 
 row aliases
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = "connected to a"
   | EVAL x = message
@@ -416,7 +426,7 @@ COUNT():long | category:keyword         | y:keyword
 ;
 
 from aliases
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL x = message
@@ -432,7 +442,7 @@ COUNT():long | category:keyword         | y:keyword
 ;
 
 row aliases with keep
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = "connected to a"
   | EVAL x = message
@@ -448,7 +458,7 @@ COUNT():long | y:keyword
 ;
 
 from aliases with keep
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | EVAL x = message
@@ -466,7 +476,7 @@ COUNT():long | y:keyword
 ;
 
 row rename
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = "connected to a"
   | RENAME message as x
@@ -480,7 +490,7 @@ COUNT():long | y:keyword
 ;
 
 from rename
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | RENAME message as x
@@ -496,7 +506,7 @@ COUNT():long | y:keyword
 ;
 
 row drop
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 ROW message = "connected to a"
   | STATS c = COUNT() BY category=CATEGORIZE(message)
@@ -509,7 +519,7 @@ c:long
 ;
 
 from drop
-required_capability: categorize_v2
+required_capability: categorize_v3
 
 FROM sample_data
   | STATS c = COUNT() BY category=CATEGORIZE(message)

+ 1 - 4
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -395,11 +395,8 @@ public class EsqlCapabilities {
 
         /**
          * Supported the text categorization function "CATEGORIZE".
-         * <p>
-         *     This capability was initially named `CATEGORIZE`, and got renamed after the function started correctly returning keywords.
-         * </p>
          */
-        CATEGORIZE_V2(Build.current().isSnapshot()),
+        CATEGORIZE_V3(Build.current().isSnapshot()),
 
         /**
          * QSTR function

+ 3 - 3
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -1821,7 +1821,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeSingleGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(first_name)");
         query("from test | STATS COUNT(*) BY cat = CATEGORIZE(first_name)");
@@ -1850,7 +1850,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeNestedGrouping() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled());
 
         query("from test | STATS COUNT(*) BY CATEGORIZE(LENGTH(first_name)::string)");
 
@@ -1865,7 +1865,7 @@ public class VerifierTests extends ESTestCase {
     }
 
     public void testCategorizeWithinAggregations() {
-        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V2.isEnabled());
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled());
 
         query("from test | STATS MV_COUNT(cat), COUNT(*) BY cat = CATEGORIZE(first_name)");
 

+ 5 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java

@@ -20,6 +20,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.TestBlockFactory;
 import org.elasticsearch.xpack.esql.VerificationException;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
 import org.elasticsearch.xpack.esql.analysis.Analyzer;
 import org.elasticsearch.xpack.esql.analysis.AnalyzerContext;
 import org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils;
@@ -1211,6 +1212,8 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *   \_EsRelation[test][_meta_field{f}#23, emp_no{f}#17, first_name{f}#18, ..]
      */
     public void testCombineProjectionWithCategorizeGrouping() {
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled());
+
         var plan = plan("""
             from test
             | eval k = first_name, k1 = k
@@ -3946,6 +3949,8 @@ public class LogicalPlanOptimizerTests extends ESTestCase {
      *     \_EsRelation[test][_meta_field{f}#14, emp_no{f}#8, first_name{f}#9, ge..]
      */
     public void testNestedExpressionsInGroupsWithCategorize() {
+        assumeTrue("requires Categorize capability", EsqlCapabilities.Cap.CATEGORIZE_V3.isEnabled());
+
         var plan = optimizedPlan("""
             from test
             | stats c = count(salary) by CATEGORIZE(CONCAT(first_name, "abc"))

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/categorization/TokenListCategorizer.java

@@ -115,6 +115,7 @@ public class TokenListCategorizer implements Accountable {
         cacheRamUsage(0);
     }
 
+    @Nullable
     public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
         try (TokenStream ts = analyzer.tokenStream("text", s)) {
             return computeCategory(ts, s.length(), 1);
@@ -123,6 +124,7 @@ public class TokenListCategorizer implements Accountable {
         }
     }
 
+    @Nullable
     public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
         assert partOfSpeechDictionary != null
             : "This version of computeCategory should only be used when a part-of-speech dictionary is available";