
ESQL: Test partially filtered aggs (#114510) (#114654)

Tests for partially filtered aggs. They reuse the existing aggs tests and
add junk rows that are then filtered away. That way we don't have to add
new assertions to each test class - we can just reuse the existing
assertions.
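
In concrete terms, every subclass now inherits a test shaped like the sketch below. This is a condensed, commented copy of the testSomeFiltered method added in this commit; it assumes the usual AggregatorFunctionTestCase helpers (simpleWithMode, simpleInput, driverContext, drive, assertSimpleOutput) are in scope.

    public void testSomeFiltered() {
        // Wrap the aggregator so it only consumes rows whose trailing boolean column is true.
        Operator.OperatorFactory factory = simpleWithMode(
            AggregatorMode.SINGLE,
            agg -> new FilteredAggregatorFunctionSupplier(agg, AddGarbageRowsSourceOperator.filterFactory())
        );
        DriverContext driverContext = driverContext();

        // Build clean test data and keep a deep copy for the existing assertions.
        List<Page> input = CannedSourceOperator.collectPages(simpleInput(driverContext.blockFactory(), 10));
        List<Page> origInput = BlockTestUtils.deepCopyOf(input, TestBlockFactory.getNonBreakingInstance());

        // Sprinkle junk rows in; each page also gains a boolean "keep" column at the end.
        input = CannedSourceOperator.collectPages(new AddGarbageRowsSourceOperator(new CannedSourceOperator(input.iterator())));

        // The filter drops the junk, so the results must match the clean input.
        List<Page> results = drive(factory.get(driverContext), input.iterator(), driverContext);
        assertThat(results, hasSize(1));
        assertSimpleOutput(origInput, results);
    }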
Nik Everett, 1 year ago · commit 0e2f832516

+ 17 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/AggregatorFunctionTestCase.java

@@ -22,6 +22,7 @@ import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.data.TestBlockFactory;
+import org.elasticsearch.compute.operator.AddGarbageRowsSourceOperator;
 import org.elasticsearch.compute.operator.AggregationOperator;
 import org.elasticsearch.compute.operator.CannedSourceOperator;
 import org.elasticsearch.compute.operator.Driver;
@@ -203,6 +204,22 @@ public abstract class AggregatorFunctionTestCase extends ForkingOperatorTestCase
         assertSimpleOutput(origInput, results);
     }
 
+    public void testSomeFiltered() {
+        Operator.OperatorFactory factory = simpleWithMode(
+            AggregatorMode.SINGLE,
+            agg -> new FilteredAggregatorFunctionSupplier(agg, AddGarbageRowsSourceOperator.filterFactory())
+        );
+        DriverContext driverContext = driverContext();
+        // Build the test data
+        List<Page> input = CannedSourceOperator.collectPages(simpleInput(driverContext.blockFactory(), 10));
+        List<Page> origInput = BlockTestUtils.deepCopyOf(input, TestBlockFactory.getNonBreakingInstance());
+        // Sprinkle garbage into it
+        input = CannedSourceOperator.collectPages(new AddGarbageRowsSourceOperator(new CannedSourceOperator(input.iterator())));
+        List<Page> results = drive(factory.get(driverContext), input.iterator(), driverContext);
+        assertThat(results, hasSize(1));
+        assertSimpleOutput(origInput, results);
+    }
+
     // Returns an intermediate state that is equivalent to what the local execution planner will emit
     // if it determines that certain shards have no relevant data.
     List<Page> nullIntermediateState(BlockFactory blockFactory) {

+ 5 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/FilteredAggregatorFunctionTests.java

@@ -103,4 +103,9 @@ public class FilteredAggregatorFunctionTests extends AggregatorFunctionTestCase
     public void testAllFiltered() {
         assumeFalse("can't double filter. tests already filter.", true);
     }
+
+    @Override
+    public void testSomeFiltered() {
+        assumeFalse("can't double filter. tests already filter.", true);
+    }
 }

+ 26 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java

@@ -26,6 +26,7 @@ import org.elasticsearch.compute.data.LongBlock;
 import org.elasticsearch.compute.data.LongVector;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.data.TestBlockFactory;
+import org.elasticsearch.compute.operator.AddGarbageRowsSourceOperator;
 import org.elasticsearch.compute.operator.CannedSourceOperator;
 import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.ForkingOperatorTestCase;
@@ -53,6 +54,7 @@ import static java.util.stream.IntStream.range;
 import static org.elasticsearch.compute.data.BlockTestUtils.append;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.in;
 
 /**
  * Shared tests for testing grouped aggregations.
@@ -160,11 +162,17 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
 
     @Override
     protected final void assertSimpleOutput(List<Page> input, List<Page> results) {
+        assertSimpleOutput(input, results, true);
+    }
+
+    private void assertSimpleOutput(List<Page> input, List<Page> results, boolean assertGroupCount) {
         SeenGroups seenGroups = seenGroups(input);
 
         assertThat(results, hasSize(1));
         assertThat(results.get(0).getBlockCount(), equalTo(2));
-        assertThat(results.get(0).getPositionCount(), equalTo(seenGroups.size()));
+        if (assertGroupCount) {
+            assertThat(results.get(0).getPositionCount(), equalTo(seenGroups.size()));
+        }
 
         Block groups = results.get(0).getBlock(0);
         Block result = results.get(0).getBlock(1);
@@ -394,6 +402,23 @@ public abstract class GroupingAggregatorFunctionTestCase extends ForkingOperator
         assertSimpleOutput(origInput, results);
     }
 
+    public void testSomeFiltered() {
+        Operator.OperatorFactory factory = simpleWithMode(
+            AggregatorMode.SINGLE,
+            agg -> new FilteredAggregatorFunctionSupplier(agg, AddGarbageRowsSourceOperator.filterFactory())
+        );
+        DriverContext driverContext = driverContext();
+        // Build the test data
+        List<Page> input = CannedSourceOperator.collectPages(simpleInput(driverContext.blockFactory(), 10));
+        List<Page> origInput = BlockTestUtils.deepCopyOf(input, TestBlockFactory.getNonBreakingInstance());
+        // Sprinkle garbage into it
+        input = CannedSourceOperator.collectPages(new AddGarbageRowsSourceOperator(new CannedSourceOperator(input.iterator())));
+        List<Page> results = drive(factory.get(driverContext), input.iterator(), driverContext);
+        assertThat(results, hasSize(1));
+
+        assertSimpleOutput(origInput, results, false);
+    }
+
     /**
      * Asserts that the output from an empty input is a {@link Block} containing
      * only {@code null}. Override for {@code count} style aggregations that

+ 133 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AddGarbageRowsSourceOperator.java

@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.operator;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BytesRefBlock;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.LongBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.test.ESTestCase;
+
+/**
+ * A {@link SourceOperator} that inserts random garbage into data from another
+ * {@link SourceOperator}. It also inserts an extra channel at the end of the page
+ * containing a {@code boolean} column. If it is {@code true} then the data came
+ * from the original operator. If it's {@code false} then the data is random
+ * garbage inserted by this operator.
+ */
+public class AddGarbageRowsSourceOperator extends SourceOperator {
+    public static EvalOperator.ExpressionEvaluator.Factory filterFactory() {
+        /*
+         * Grabs the filter from the last block. That's where we put it.
+         */
+        return ctx -> new EvalOperator.ExpressionEvaluator() {
+            @Override
+            public Block eval(Page page) {
+                Block block = page.getBlock(page.getBlockCount() - 1);
+                block.incRef();
+                return block;
+            }
+
+            @Override
+            public void close() {}
+        };
+    }
+
+    private final SourceOperator next;
+
+    public AddGarbageRowsSourceOperator(SourceOperator next) {
+        this.next = next;
+    }
+
+    @Override
+    public void finish() {
+        next.finish();
+    }
+
+    @Override
+    public boolean isFinished() {
+        return next.isFinished();
+    }
+
+    @Override
+    public Page getOutput() {
+        Page page = next.getOutput();
+        if (page == null) {
+            return null;
+        }
+        Block.Builder[] newBlocks = new Block.Builder[page.getBlockCount() + 1];
+        try {
+            for (int b = 0; b < page.getBlockCount(); b++) {
+                Block block = page.getBlock(b);
+                newBlocks[b] = block.elementType().newBlockBuilder(page.getPositionCount(), block.blockFactory());
+            }
+            newBlocks[page.getBlockCount()] = page.getBlock(0).blockFactory().newBooleanBlockBuilder(page.getPositionCount());
+
+            for (int p = 0; p < page.getPositionCount(); p++) {
+                if (ESTestCase.randomBoolean()) {
+                    insertGarbageRows(newBlocks, page);
+                }
+                copyPosition(newBlocks, page, p);
+                if (ESTestCase.randomBoolean()) {
+                    insertGarbageRows(newBlocks, page);
+                }
+            }
+
+            return new Page(Block.Builder.buildAll(newBlocks));
+        } finally {
+            Releasables.close(Releasables.wrap(newBlocks), page::releaseBlocks);
+        }
+    }
+
+    private void copyPosition(Block.Builder[] newBlocks, Page page, int p) {
+        for (int b = 0; b < page.getBlockCount(); b++) {
+            Block block = page.getBlock(b);
+            newBlocks[b].copyFrom(block, p, p + 1);
+        }
+        signalKeep(newBlocks, true);
+    }
+
+    private void insertGarbageRows(Block.Builder[] newBlocks, Page page) {
+        int count = ESTestCase.between(1, 5);
+        for (int c = 0; c < count; c++) {
+            insertGarbageRow(newBlocks, page);
+        }
+    }
+
+    private void insertGarbageRow(Block.Builder[] newBlocks, Page page) {
+        for (int b = 0; b < page.getBlockCount(); b++) {
+            Block block = page.getBlock(b);
+            switch (block.elementType()) {
+                case BOOLEAN -> ((BooleanBlock.Builder) newBlocks[b]).appendBoolean(ESTestCase.randomBoolean());
+                case BYTES_REF -> ((BytesRefBlock.Builder) newBlocks[b]).appendBytesRef(new BytesRef(ESTestCase.randomAlphaOfLength(5)));
+                case COMPOSITE, DOC, UNKNOWN -> throw new UnsupportedOperationException();
+                case INT -> ((IntBlock.Builder) newBlocks[b]).appendInt(ESTestCase.randomInt());
+                case LONG -> ((LongBlock.Builder) newBlocks[b]).appendLong(ESTestCase.randomLong());
+                case NULL -> newBlocks[b].appendNull();
+                case DOUBLE -> ((DoubleBlock.Builder) newBlocks[b]).appendDouble(ESTestCase.randomDouble());
+                case FLOAT -> ((FloatBlock.Builder) newBlocks[b]).appendFloat(ESTestCase.randomFloat());
+            }
+        }
+        signalKeep(newBlocks, false);
+    }
+
+    private void signalKeep(Block.Builder[] newBlocks, boolean shouldKeep) {
+        ((BooleanBlock.Builder) newBlocks[newBlocks.length - 1]).appendBoolean(shouldKeep);
+    }
+
+    @Override
+    public void close() {
+        next.close();
+    }
+}
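
To make the contract from the class Javadoc concrete: a rough, hypothetical spot-check (not part of the commit) that reads the trailing boolean channel could look like the sketch below, where garbageSource is an AddGarbageRowsSourceOperator wrapping some other source and the Block/Page accessors used (getFirstValueIndex, getBoolean, releaseBlocks) are assumed to be the standard ones in the compute module.

    Page out = garbageSource.getOutput();
    if (out != null) {
        // The last block is the synthetic "keep" column appended by the operator.
        BooleanBlock keep = out.getBlock(out.getBlockCount() - 1);
        for (int p = 0; p < out.getPositionCount(); p++) {
            boolean fromOriginal = keep.getBoolean(keep.getFirstValueIndex(p));
            // true  -> position copied from the wrapped source
            // false -> junk position inserted by AddGarbageRowsSourceOperator
        }
        out.releaseBlocks();
    }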