Browse Source

ESQL: Preserve the projection of an aggregation (#99936)

The internal AggregateOperators are focused creating groups and
 performing the aggregations using their intermediate states. However
 the logical projection is not enforced leading to errors downstream.
 Currently this is handled through OutputOperator which forces an
 alignment between the layout and the attributes however this is not
 only fragile but also insufficient.

This commit fixes a couple of things:
- introduces a ProjectOperator in front of the AggregateOperator.
 This is much easier, at least at the moment as it keeps the agg logic
 untouched.
- improves the ProjectOperator by moving away from the BitSet to a List
 that handles not just the selection but also ordering

Fix #99782
Costin Leau 2 years ago
parent
commit
7fdba1a283

+ 1 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java

@@ -90,6 +90,7 @@ public class AggregationOperator implements Operator {
         }
         finished = true;
         int[] aggBlockCounts = aggregators.stream().mapToInt(Aggregator::evaluateBlockCount).toArray();
+        // TODO: look into allocating the blocks lazily
         Block[] blocks = new Block[Arrays.stream(aggBlockCounts).sum()];
         int offset = 0;
         for (int i = 0; i < aggregators.size(); i++) {

+ 41 - 21
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ProjectOperator.java

@@ -14,65 +14,85 @@ import org.elasticsearch.core.Releasables;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.BitSet;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 public class ProjectOperator extends AbstractPageMappingOperator {
 
-    private final BitSet bs;
+    private final Set<Integer> pagesUsed;
+    private final int[] projection;
     private Block[] blocks;
 
-    public record ProjectOperatorFactory(BitSet mask) implements OperatorFactory {
+    public record ProjectOperatorFactory(List<Integer> projection) implements OperatorFactory {
 
         @Override
         public Operator get(DriverContext driverContext) {
-            return new ProjectOperator(mask);
+            return new ProjectOperator(projection);
         }
 
         @Override
         public String describe() {
-            return "ProjectOperator[mask = " + mask + "]";
+            return "ProjectOperator[projection = " + projection + "]";
         }
     }
 
     /**
-     * Creates a project that applies the given mask (as a bitset).
+     * Creates an operator that applies the given projection, encoded as an integer list where
+     * the ordinal indicates the output order and the value, the backing channel that to be used.
+     * Given the input {a,b,c,d}, project {a,d,a} is encoded as {0,3,0}.
      *
-     * @param mask bitset mask for enabling/disabling blocks / columns inside a Page
+     * @param projection list of blocks to keep and their order.
      */
-    public ProjectOperator(BitSet mask) {
-        this.bs = mask;
+    public ProjectOperator(List<Integer> projection) {
+        this.pagesUsed = new HashSet<>(projection);
+        this.projection = projection.stream().mapToInt(Integer::intValue).toArray();
     }
 
     @Override
     protected Page process(Page page) {
-        if (page.getBlockCount() == 0) {
+        var blockCount = page.getBlockCount();
+        if (blockCount == 0) {
             return page;
         }
         if (blocks == null) {
-            blocks = new Block[bs.cardinality()];
+            blocks = new Block[projection.length];
         }
 
         Arrays.fill(blocks, null);
         int b = 0;
-        int positionCount = page.getPositionCount();
+        for (int source : projection) {
+            if (source >= blockCount) {
+                throw new IllegalArgumentException(
+                    "Cannot project block with index [" + source + "] from a page with size [" + blockCount + "]"
+                );
+            }
+            var block = page.getBlock(source);
+            blocks[b++] = block;
+        }
+        // iterate the blocks to see which one isn't used
         List<Releasable> blocksToRelease = new ArrayList<>();
-        for (int i = 0; i < page.getBlockCount(); i++) {
-            var block = page.getBlock(i);
-            if (bs.get(i)) {
-                assertNotReleasing(blocksToRelease, block);
-                blocks[b++] = block;
-            } else {
-                blocksToRelease.add(block);
+
+        for (int i = 0; i < blockCount; i++) {
+            boolean used = false;
+            var current = page.getBlock(i);
+            for (int j = 0; j < blocks.length; j++) {
+                if (current == blocks[j]) {
+                    used = true;
+                    break;
+                }
+            }
+            if (used == false) {
+                blocksToRelease.add(current);
             }
         }
         Releasables.close(blocksToRelease);
-        return new Page(positionCount, blocks);
+        return new Page(page.getPositionCount(), blocks);
     }
 
     @Override
     public String toString() {
-        return "ProjectOperator[mask = " + bs + ']';
+        return "ProjectOperator[projection = " + Arrays.toString(projection) + ']';
     }
 
     static void assertNotReleasing(List<Releasable> toRelease, Block toKeep) {

+ 11 - 20
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/ProjectOperatorTests.java

@@ -22,12 +22,13 @@ import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.junit.After;
 import org.junit.Before;
 
-import java.util.BitSet;
+import java.util.Arrays;
 import java.util.List;
 import java.util.stream.LongStream;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -50,7 +51,7 @@ public class ProjectOperatorTests extends OperatorTestCase {
 
     public void testProjectionOnEmptyPage() {
         var page = new Page(0);
-        var projection = new ProjectOperator(randomMask(randomIntBetween(2, 10)));
+        var projection = new ProjectOperator(randomProjection(10));
         projection.addInput(page);
         assertEquals(page, projection.getOutput());
     }
@@ -63,30 +64,22 @@ public class ProjectOperatorTests extends OperatorTestCase {
         }
 
         var page = new Page(size, blocks);
-        var mask = randomMask(size);
+        var randomProjection = randomProjection(size);
 
-        var projection = new ProjectOperator(mask);
+        var projection = new ProjectOperator(randomProjection);
         projection.addInput(page);
         var out = projection.getOutput();
-        assertEquals(mask.cardinality(), out.getBlockCount());
+        assertThat(randomProjection.size(), lessThanOrEqualTo(out.getBlockCount()));
 
-        int lastSetIndex = -1;
         for (int i = 0; i < out.getBlockCount(); i++) {
             var block = out.<IntBlock>getBlock(i);
-            var shouldBeSetInMask = block.getInt(0);
-            assertTrue(mask.get(shouldBeSetInMask));
-            lastSetIndex = mask.nextSetBit(lastSetIndex + 1);
-            assertEquals(shouldBeSetInMask, lastSetIndex);
+            assertEquals(block, page.getBlock(randomProjection.get(i)));
             block.close();
         }
     }
 
-    private BitSet randomMask(int size) {
-        var mask = new BitSet(size);
-        for (int i = 0; i < size; i++) {
-            mask.set(i, randomBoolean());
-        }
-        return mask;
+    private List<Integer> randomProjection(int size) {
+        return randomList(size, () -> randomIntBetween(0, size - 1));
     }
 
     @Override
@@ -96,14 +89,12 @@ public class ProjectOperatorTests extends OperatorTestCase {
 
     @Override
     protected Operator.OperatorFactory simple(BigArrays bigArrays) {
-        BitSet mask = new BitSet();
-        mask.set(1, true);
-        return new ProjectOperator.ProjectOperatorFactory(mask);
+        return new ProjectOperator.ProjectOperatorFactory(Arrays.asList(1));
     }
 
     @Override
     protected String expectedDescriptionOfSimple() {
-        return "ProjectOperator[mask = {1}]";
+        return "ProjectOperator[projection = [1]]";
     }
 
     @Override

+ 8 - 6
x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java

@@ -441,21 +441,23 @@ public final class CsvTestUtils {
         logger.info(sb.toString());
     }
 
-    static void logData(List<List<Object>> values, Logger logger) {
-        for (List<Object> list : values) {
-            logger.info(rowAsString(list));
+    static void logData(Iterator<Iterator<Object>> values, Logger logger) {
+        while (values.hasNext()) {
+            var val = values.next();
+            logger.info(rowAsString(val));
         }
     }
 
-    private static String rowAsString(List<Object> list) {
+    private static String rowAsString(Iterator<Object> iterator) {
         StringBuilder sb = new StringBuilder();
         StringBuilder column = new StringBuilder();
-        for (int i = 0; i < list.size(); i++) {
+        for (int i = 0; iterator.hasNext(); i++) {
             column.setLength(0);
             if (i > 0) {
                 sb.append(" | ");
             }
-            sb.append(trimOrPad(column.append(list.get(i))));
+            var next = iterator.next();
+            sb.append(trimOrPad(column.append(next)));
         }
         return sb.toString();
     }

+ 32 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

@@ -413,6 +413,36 @@ c:long | languages:integer | still_hired:boolean
      4 |              null | true
 ;
 
+byUnmentionedIntAndBooleanFollowedByProjection
+from employees | stats c = count(gender) by languages, still_hired | where languages > 3 | sort languages | keep languages;
+
+languages:integer 
+                4 
+                4
+                5
+                5 
+;
+
+byTwoGroupReturnedInDifferentOrder
+from employees | stats c = count(emp_no) by gender, languages | rename languages as l, gender as g | where l > 3 | keep g, l | sort g, l;
+
+g:keyword  | l:integer
+ F         | 4        
+ F         | 5
+ M         | 4        
+ M         | 5
+ null      | 4
+ null      | 5
+;
+
+repetitiveAggregation
+from employees | stats m1 = max(salary), m2 = min(salary), m3 = min(salary), m4 = max(salary);
+
+m1:i | m2:i | m3:i | m4:i
+74999| 25324| 25324| 74999
+;
+
+
 byDateAndKeywordAndInt
 from employees | eval d = date_trunc(1 year, hire_date) | stats c = count(emp_no) by d, gender, languages | sort c desc, d, languages desc, gender desc | limit 10;
 
@@ -501,3 +531,5 @@ from employees | limit 10 | eval x = 1 | stats c = count(x);
 c:l
 10
 ;
+
+

+ 3 - 2
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java

@@ -213,13 +213,14 @@ public class EsqlQueryResponse extends ActionResponse implements ChunkedToXConte
                  */
                 int count = block.getValueCount(p);
                 int start = block.getFirstValueIndex(p);
+                String dataType = dataTypes.get(b);
                 if (count == 1) {
-                    return valueAt(dataTypes.get(b), block, start, scratch);
+                    return valueAt(dataType, block, start, scratch);
                 }
                 List<Object> thisResult = new ArrayList<>(count);
                 int end = count + start;
                 for (int i = start; i < end; i++) {
-                    thisResult.add(valueAt(dataTypes.get(b), block, i, scratch));
+                    thisResult.add(valueAt(dataType, block, i, scratch));
                 }
                 return thisResult;
             }))

+ 8 - 5
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java

@@ -66,7 +66,6 @@ import org.elasticsearch.xpack.ql.expression.NamedExpression;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.BitSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -245,10 +244,14 @@ public class EnrichLookupService {
     }
 
     private static Operator droppingBlockOperator(int totalBlocks, int droppingPosition) {
-        BitSet bitSet = new BitSet(totalBlocks);
-        bitSet.set(0, totalBlocks);
-        bitSet.clear(droppingPosition);
-        return new ProjectOperator(bitSet);
+        var size = totalBlocks - 1;
+        var projection = new ArrayList<Integer>(size);
+        for (int i = 0; i < totalBlocks; i++) {
+            if (i != droppingPosition) {
+                projection.add(i);
+            }
+        }
+        return new ProjectOperator(projection);
     }
 
     private class TransportHandler implements TransportRequestHandler<LookupRequest> {

+ 15 - 19
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java

@@ -85,7 +85,6 @@ import org.elasticsearch.xpack.ql.util.Holder;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.BitSet;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -152,6 +151,12 @@ public class LocalExecutionPlanner {
             blockFactory
         );
 
+        // workaround for https://github.com/elastic/elasticsearch/issues/99782
+        node = node.transformUp(
+            AggregateExec.class,
+            a -> a.getMode() == AggregateExec.Mode.FINAL ? new ProjectExec(a.source(), a, Expressions.asAttributes(a.aggregates())) : a
+        );
+
         PhysicalOperation physicalOperation = plan(node, context);
 
         context.addDriverFactory(
@@ -506,9 +511,14 @@ public class LocalExecutionPlanner {
 
     private PhysicalOperation planProject(ProjectExec project, LocalExecutionPlannerContext context) {
         var source = plan(project.child(), context);
+        List<? extends NamedExpression> projections = project.projections();
+        List<Integer> projectionList = new ArrayList<>(projections.size());
 
+        Layout.Builder layout = new Layout.Builder();
         Map<Integer, Layout.ChannelSet> inputChannelToOutputIds = new HashMap<>();
-        for (NamedExpression ne : project.projections()) {
+        for (int index = 0, size = projections.size(); index < size; index++) {
+            NamedExpression ne = projections.get(index);
+
             NameId inputId;
             if (ne instanceof Alias a) {
                 inputId = ((NamedExpression) a.child()).id();
@@ -524,26 +534,12 @@ public class LocalExecutionPlanner {
                 throw new IllegalArgumentException("type mismatch for aliases");
             }
             channelSet.nameIds().add(ne.id());
-        }
 
-        BitSet mask = new BitSet();
-        Layout.Builder layout = new Layout.Builder();
-
-        for (int inChannel = 0; inChannel < source.layout.numberOfChannels(); inChannel++) {
-            Layout.ChannelSet outputSet = inputChannelToOutputIds.get(inChannel);
-
-            if (outputSet != null) {
-                mask.set(inChannel);
-                layout.append(outputSet);
-            }
+            layout.append(channelSet);
+            projectionList.add(input.channel());
         }
 
-        if (mask.cardinality() == source.layout.numberOfChannels()) {
-            // all columns are retained, project operator is not needed but the layout needs to be updated
-            return source.with(layout.build());
-        } else {
-            return source.with(new ProjectOperatorFactory(mask), layout.build());
-        }
+        return source.with(new ProjectOperatorFactory(projectionList), layout.build());
     }
 
     private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) {