Explorar el Código

Handle nulls in OrdinalsGroupingOperator (#100117)

This change introduces null handling in the OrdinalsGroupingOperator, 
replacing the current behavior which skips null keys. Ordinals are now
incremented by 1, with 0 being used to represent null ordinals.

Closes #100109
Nhat Nguyen hace 2 años
padre
commit
c6f461661f

+ 0 - 59
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/BlockOrdinalsReader.java

@@ -1,59 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.compute.lucene;
-
-import org.apache.lucene.index.SortedSetDocValues;
-import org.elasticsearch.compute.data.IntBlock;
-import org.elasticsearch.compute.data.IntVector;
-
-import java.io.IOException;
-
-public final class BlockOrdinalsReader {
-    private final SortedSetDocValues sortedSetDocValues;
-    private final Thread creationThread;
-
-    public BlockOrdinalsReader(SortedSetDocValues sortedSetDocValues) {
-        this.sortedSetDocValues = sortedSetDocValues;
-        this.creationThread = Thread.currentThread();
-    }
-
-    public IntBlock readOrdinals(IntVector docs) throws IOException {
-        final int positionCount = docs.getPositionCount();
-        IntBlock.Builder builder = IntBlock.newBlockBuilder(positionCount);
-        for (int p = 0; p < positionCount; p++) {
-            int doc = docs.getInt(p);
-            if (false == sortedSetDocValues.advanceExact(doc)) {
-                builder.appendNull();
-                continue;
-            }
-            int count = sortedSetDocValues.docValueCount();
-            // TODO don't come this way if there are a zillion ords on the field
-            if (count == 1) {
-                builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd()));
-                continue;
-            }
-            builder.beginPositionEntry();
-            for (int i = 0; i < count; i++) {
-                builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd()));
-            }
-            builder.endPositionEntry();
-        }
-        return builder.build();
-    }
-
-    public int docID() {
-        return sortedSetDocValues.docID();
-    }
-
-    /**
-     * Checks if the reader can be used to read a range documents starting with the given docID by the current thread.
-     */
-    public static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) {
-        return reader != null && reader.creationThread == Thread.currentThread() && reader.docID() <= startingDocID;
-    }
-}

+ 70 - 11
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/OrdinalsGroupingOperator.java

@@ -27,7 +27,6 @@ import org.elasticsearch.compute.data.DocVector;
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
-import org.elasticsearch.compute.lucene.BlockOrdinalsReader;
 import org.elasticsearch.compute.lucene.ValueSourceInfo;
 import org.elasticsearch.compute.lucene.ValuesSourceReaderOperator;
 import org.elasticsearch.compute.operator.HashAggregationOperator.GroupSpec;
@@ -234,18 +233,31 @@ public class OrdinalsGroupingOperator implements Operator {
         };
         final List<GroupingAggregator> aggregators = createGroupingAggregators();
         try {
+            boolean seenNulls = false;
+            for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
+                if (agg.seenNulls()) {
+                    seenNulls = true;
+                    for (int i = 0; i < aggregators.size(); i++) {
+                        aggregators.get(i).addIntermediateRow(0, agg.aggregators.get(i), 0);
+                    }
+                }
+            }
             for (OrdinalSegmentAggregator agg : ordinalAggregators.values()) {
                 final AggregatedResultIterator it = agg.getResultIterator();
                 if (it.next()) {
                     pq.add(it);
                 }
             }
-            int position = -1;
+            final int startPosition = seenNulls ? 0 : -1;
+            int position = startPosition;
             final BytesRefBuilder lastTerm = new BytesRefBuilder();
             var blockBuilder = BytesRefBlock.newBlockBuilder(1);
+            if (seenNulls) {
+                blockBuilder.appendNull();
+            }
             while (pq.size() > 0) {
                 final AggregatedResultIterator top = pq.top();
-                if (position == -1 || lastTerm.get().equals(top.currentTerm) == false) {
+                if (position == startPosition || lastTerm.get().equals(top.currentTerm) == false) {
                     position++;
                     lastTerm.copyBytes(top.currentTerm);
                     blockBuilder.appendBytesRef(top.currentTerm);
@@ -338,11 +350,8 @@ public class OrdinalsGroupingOperator implements Operator {
                 if (BlockOrdinalsReader.canReuse(currentReader, docs.getInt(0)) == false) {
                     currentReader = new BlockOrdinalsReader(withOrdinals.ordinalsValues(leafReaderContext));
                 }
-                final IntBlock ordinals = currentReader.readOrdinals(docs);
+                final IntBlock ordinals = currentReader.readOrdinalsAdded1(docs);
                 for (int p = 0; p < ordinals.getPositionCount(); p++) {
-                    if (ordinals.isNull(p)) {
-                        continue;
-                    }
                     int start = ordinals.getFirstValueIndex(p);
                     int end = start + ordinals.getValueCount(p);
                     for (int i = start; i < end; i++) {
@@ -350,8 +359,8 @@ public class OrdinalsGroupingOperator implements Operator {
                         visitedOrds.set(ord);
                     }
                 }
-                for (GroupingAggregator aggregator : aggregators) {
-                    aggregator.prepareProcessPage(this, page).add(0, ordinals);
+                for (GroupingAggregatorFunction.AddInput addInput : prepared) {
+                    addInput.add(0, ordinals);
                 }
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
@@ -362,6 +371,10 @@ public class OrdinalsGroupingOperator implements Operator {
             return new AggregatedResultIterator(aggregators, visitedOrds, withOrdinals.ordinalsValues(leafReaderContext));
         }
 
+        boolean seenNulls() {
+            return visitedOrds.get(0); // bit 0 tracks ordinal 0, which readOrdinalsAdded1 emits for docs without a value
+        }
+
         @Override
         public BitArray seenGroupIds(BigArrays bigArrays) {
             BitArray seen = new BitArray(0, bigArrays);
@@ -377,7 +390,7 @@ public class OrdinalsGroupingOperator implements Operator {
 
     private static class AggregatedResultIterator {
         private BytesRef currentTerm;
-        private long currentOrd = -1;
+        private long currentOrd = 0;
         private final List<GroupingAggregator> aggregators;
         private final BitArray ords;
         private final SortedSetDocValues dv;
@@ -395,8 +408,9 @@ public class OrdinalsGroupingOperator implements Operator {
 
         boolean next() throws IOException {
             currentOrd = ords.nextSetBit(currentOrd + 1);
+            assert currentOrd > 0 : currentOrd;
             if (currentOrd < Long.MAX_VALUE) {
-                currentTerm = dv.lookupOrd(currentOrd);
+                currentTerm = dv.lookupOrd(currentOrd - 1);
                 return true;
             } else {
                 currentTerm = null;
@@ -448,4 +462,49 @@ public class OrdinalsGroupingOperator implements Operator {
             Releasables.close(extractor, aggregator);
         }
     }
+
+    static final class BlockOrdinalsReader { // reads doc-value ordinals shifted by one, so 0 can stand for "no value"
+        private final SortedSetDocValues sortedSetDocValues;
+        private final Thread creationThread; // thread that built this reader; compared in canReuse
+
+        BlockOrdinalsReader(SortedSetDocValues sortedSetDocValues) {
+            this.sortedSetDocValues = sortedSetDocValues;
+            this.creationThread = Thread.currentThread();
+        }
+
+        IntBlock readOrdinalsAdded1(IntVector docs) throws IOException { // one position per doc, every ordinal + 1
+            final int positionCount = docs.getPositionCount();
+            IntBlock.Builder builder = IntBlock.newBlockBuilder(positionCount);
+            for (int p = 0; p < positionCount; p++) {
+                int doc = docs.getInt(p);
+                if (false == sortedSetDocValues.advanceExact(doc)) { // doc has no value for this field
+                    builder.appendInt(0); // 0 is the ordinal reserved for the null key
+                    continue;
+                }
+                int count = sortedSetDocValues.docValueCount();
+                // TODO don't come this way if there are a zillion ords on the field
+                if (count == 1) { // single value: append directly, no position entry
+                    builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1)); // toIntExact throws on int overflow
+                    continue;
+                }
+                builder.beginPositionEntry(); // multiple values: all shifted ordinals go into one position
+                for (int i = 0; i < count; i++) {
+                    builder.appendInt(Math.toIntExact(sortedSetDocValues.nextOrd() + 1));
+                }
+                builder.endPositionEntry();
+            }
+            return builder.build();
+        }
+
+        int docID() {
+            return sortedSetDocValues.docID();
+        }
+
+        /**
+         * Checks if the reader can be used to read a range of documents starting with the given docID by the current thread.
+         */
+        static boolean canReuse(BlockOrdinalsReader reader, int startingDocID) {
+            return reader != null && reader.creationThread == Thread.currentThread() && reader.docID() <= startingDocID;
+        }
+    }
 }

+ 9 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -569,3 +569,12 @@ ca:l | cx:l | l:i
 1    | 1    | 5
 1    | 1    | null
 ;
+
+aggsWithoutStats
+from employees | stats by gender | sort gender;
+
+gender:keyword
+F
+M
+null
+;

+ 4 - 24
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.esql.action;
 import org.elasticsearch.Build;
 import org.elasticsearch.action.admin.indices.alias.IndicesAliasesRequest;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
-import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.support.WriteRequest;
@@ -34,6 +33,7 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
@@ -265,7 +265,7 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
             EsqlQueryResponse results = run("from test | stats avg = avg(" + field + ") by color");
             logger.info(results);
             Assert.assertEquals(2, results.columns().size());
-            Assert.assertEquals(4, getValuesList(results).size());
+            Assert.assertEquals(5, getValuesList(results).size());
 
             // assert column metadata
             assertEquals("avg", results.columns().get(0).name());
@@ -276,6 +276,7 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
 
             }
             List<Group> expectedGroups = List.of(
+                new Group(null, 120.0),
                 new Group("blue", 42.0),
                 new Group("green", 44.0),
                 new Group("red", 43.0),
@@ -283,18 +284,10 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
             );
             List<Group> actualGroups = getValuesList(results).stream()
                 .map(l -> new Group((String) l.get(1), (Double) l.get(0)))
-                .sorted(comparing(c -> c.color))
+                .sorted(Comparator.comparing(c -> c.color, Comparator.nullsFirst(String::compareTo)))
                 .toList();
             assertThat(actualGroups, equalTo(expectedGroups));
         }
-        for (int i = 0; i < 5; i++) {
-            client().prepareBulk()
-                .add(new DeleteRequest("test").id("no_color_" + i))
-                .add(new DeleteRequest("test").id("no_count_red_" + i))
-                .add(new DeleteRequest("test").id("no_count_yellow_" + i))
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-                .get();
-        }
     }
 
     public void testFromStatsMultipleAggs() {
@@ -562,11 +555,6 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
         assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("data", "long"))));
         assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("data_d", "double"))));
         assertThat(results.columns(), hasItem(equalTo(new ColumnInfo("time", "long"))));
-
-        // restore index to original pre-test state
-        client().prepareBulk().add(new DeleteRequest("test").id("no_count")).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
-        results = run("from test");
-        Assert.assertEquals(40, getValuesList(results).size());
     }
 
     public void testMultiConditionalWhere() {
@@ -963,9 +951,6 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
     }
 
     public void testTopNPushedToLucene() {
-        BulkRequestBuilder bulkDelete = client().prepareBulk();
-        bulkDelete.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-
         for (int i = 5; i < 11; i++) {
             var yellowDocId = "yellow_" + i;
             var yellowNullCountDocId = "yellow_null_count_" + i;
@@ -979,11 +964,6 @@ public class EsqlActionIT extends AbstractEsqlIntegTestCase {
             if (randomBoolean()) {
                 client().admin().indices().prepareRefresh("test").get();
             }
-
-            // build the cleanup request now, as well, not to miss anything ;-)
-            bulkDelete.add(new DeleteRequest("test").id(yellowDocId))
-                .add(new DeleteRequest("test").id(yellowNullCountDocId))
-                .add(new DeleteRequest("test").id(yellowNullDataDocId));
         }
         client().admin().indices().prepareRefresh("test").get();