Browse Source

ESQL: Add boolean support to TOP aggregation (#110718)

- Added a custom implementation of BooleanBucketedSort to keep the top booleans
- Added boolean aggregator to TOP
- Added tests (Boolean aggregator tests, Top tests for boolean, and added boolean fields to CSV cases)
Iván Cea Fontenla 1 year ago
parent
commit
43a3af66e8
17 changed files with 935 additions and 13 deletions
  1. 5 0
      docs/changelog/110718.yaml
  2. 24 0
      docs/reference/esql/functions/kibana/definition/top.json
  3. 1 0
      docs/reference/esql/functions/types/top.asciidoc
  4. 126 0
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunction.java
  5. 45 0
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunctionSupplier.java
  6. 202 0
      x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java
  7. 137 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java
  8. 198 0
      x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BooleanBucketedSort.java
  9. 44 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunctionTests.java
  10. 62 0
      x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BooleanBucketedSortTests.java
  11. 3 3
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
  12. 40 4
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec
  13. 5 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  14. 7 3
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java
  15. 1 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java
  16. 6 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
  17. 29 2
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java

+ 5 - 0
docs/changelog/110718.yaml

@@ -0,0 +1,5 @@
+pr: 110718
+summary: "ESQL: Add boolean support to TOP aggregation"
+area: ES|QL
+type: feature
+issues: []

+ 24 - 0
docs/reference/esql/functions/kibana/definition/top.json

@@ -4,6 +4,30 @@
   "name" : "top",
   "description" : "Collects the top values for a field. Includes repeated values.",
   "signatures" : [
+    {
+      "params" : [
+        {
+          "name" : "field",
+          "type" : "boolean",
+          "optional" : false,
+          "description" : "The field to collect the top values for."
+        },
+        {
+          "name" : "limit",
+          "type" : "integer",
+          "optional" : false,
+          "description" : "The maximum number of values to collect."
+        },
+        {
+          "name" : "order",
+          "type" : "keyword",
+          "optional" : false,
+          "description" : "The order to calculate the top values. Either `asc` or `desc`."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "boolean"
+    },
     {
       "params" : [
         {

+ 1 - 0
docs/reference/esql/functions/types/top.asciidoc

@@ -5,6 +5,7 @@
 [%header.monospaced.styled,format=dsv,separator=|]
 |===
 field | limit | order | result
+boolean | integer | keyword | boolean
 datetime | integer | keyword | datetime
 double | integer | keyword | double
 integer | integer | keyword | integer

+ 126 - 0
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunction.java

@@ -0,0 +1,126 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BooleanVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunction} implementation for {@link TopBooleanAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopBooleanAggregatorFunction implements AggregatorFunction {
+  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+      new IntermediateStateDesc("top", ElementType.BOOLEAN)  );
+
+  private final DriverContext driverContext;
+
+  private final TopBooleanAggregator.SingleState state;
+
+  private final List<Integer> channels;
+
+  private final int limit;
+
+  private final boolean ascending;
+
+  public TopBooleanAggregatorFunction(DriverContext driverContext, List<Integer> channels,
+      TopBooleanAggregator.SingleState state, int limit, boolean ascending) {
+    this.driverContext = driverContext;
+    this.channels = channels;
+    this.state = state;
+    this.limit = limit;
+    this.ascending = ascending;
+  }
+
+  public static TopBooleanAggregatorFunction create(DriverContext driverContext,
+      List<Integer> channels, int limit, boolean ascending) {
+    return new TopBooleanAggregatorFunction(driverContext, channels, TopBooleanAggregator.initSingle(driverContext.bigArrays(), limit, ascending), limit, ascending);
+  }
+
+  public static List<IntermediateStateDesc> intermediateStateDesc() {
+    return INTERMEDIATE_STATE_DESC;
+  }
+
+  @Override
+  public int intermediateBlockCount() {
+    return INTERMEDIATE_STATE_DESC.size();
+  }
+
+  @Override
+  public void addRawInput(Page page) {
+    BooleanBlock block = page.getBlock(channels.get(0));
+    BooleanVector vector = block.asVector();
+    if (vector != null) {
+      addRawVector(vector);
+    } else {
+      addRawBlock(block);
+    }
+  }
+
+  private void addRawVector(BooleanVector vector) {
+    for (int i = 0; i < vector.getPositionCount(); i++) {
+      TopBooleanAggregator.combine(state, vector.getBoolean(i));
+    }
+  }
+
+  private void addRawBlock(BooleanBlock block) {
+    for (int p = 0; p < block.getPositionCount(); p++) {
+      if (block.isNull(p)) {
+        continue;
+      }
+      int start = block.getFirstValueIndex(p);
+      int end = start + block.getValueCount(p);
+      for (int i = start; i < end; i++) {
+        TopBooleanAggregator.combine(state, block.getBoolean(i));
+      }
+    }
+  }
+
+  @Override
+  public void addIntermediateInput(Page page) {
+    assert channels.size() == intermediateBlockCount();
+    assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size();
+    Block topUncast = page.getBlock(channels.get(0));
+    if (topUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanBlock top = (BooleanBlock) topUncast;
+    assert top.getPositionCount() == 1;
+    TopBooleanAggregator.combineIntermediate(state, top);
+  }
+
+  @Override
+  public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+    state.toIntermediate(blocks, offset, driverContext);
+  }
+
+  @Override
+  public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) {
+    blocks[offset] = TopBooleanAggregator.evaluateFinal(state, driverContext);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append(getClass().getSimpleName()).append("[");
+    sb.append("channels=").append(channels);
+    sb.append("]");
+    return sb.toString();
+  }
+
+  @Override
+  public void close() {
+    state.close();
+  }
+}

+ 45 - 0
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunctionSupplier.java

@@ -0,0 +1,45 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.util.List;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link AggregatorFunctionSupplier} implementation for {@link TopBooleanAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopBooleanAggregatorFunctionSupplier implements AggregatorFunctionSupplier {
+  private final List<Integer> channels;
+
+  private final int limit;
+
+  private final boolean ascending;
+
+  public TopBooleanAggregatorFunctionSupplier(List<Integer> channels, int limit,
+      boolean ascending) {
+    this.channels = channels;
+    this.limit = limit;
+    this.ascending = ascending;
+  }
+
+  @Override
+  public TopBooleanAggregatorFunction aggregator(DriverContext driverContext) {
+    return TopBooleanAggregatorFunction.create(driverContext, channels, limit, ascending);
+  }
+
+  @Override
+  public TopBooleanGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) {
+    return TopBooleanGroupingAggregatorFunction.create(channels, driverContext, limit, ascending);
+  }
+
+  @Override
+  public String describe() {
+    return "top of booleans";
+  }
+}

+ 202 - 0
x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/TopBooleanGroupingAggregatorFunction.java

@@ -0,0 +1,202 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.compute.aggregation;
+
+import java.lang.Integer;
+import java.lang.Override;
+import java.lang.String;
+import java.lang.StringBuilder;
+import java.util.List;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.BooleanVector;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+
+/**
+ * {@link GroupingAggregatorFunction} implementation for {@link TopBooleanAggregator}.
+ * This class is generated. Do not edit it.
+ */
+public final class TopBooleanGroupingAggregatorFunction implements GroupingAggregatorFunction {
+  private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
+      new IntermediateStateDesc("top", ElementType.BOOLEAN)  );
+
+  private final TopBooleanAggregator.GroupingState state;
+
+  private final List<Integer> channels;
+
+  private final DriverContext driverContext;
+
+  private final int limit;
+
+  private final boolean ascending;
+
+  public TopBooleanGroupingAggregatorFunction(List<Integer> channels,
+      TopBooleanAggregator.GroupingState state, DriverContext driverContext, int limit,
+      boolean ascending) {
+    this.channels = channels;
+    this.state = state;
+    this.driverContext = driverContext;
+    this.limit = limit;
+    this.ascending = ascending;
+  }
+
+  public static TopBooleanGroupingAggregatorFunction create(List<Integer> channels,
+      DriverContext driverContext, int limit, boolean ascending) {
+    return new TopBooleanGroupingAggregatorFunction(channels, TopBooleanAggregator.initGrouping(driverContext.bigArrays(), limit, ascending), driverContext, limit, ascending);
+  }
+
+  public static List<IntermediateStateDesc> intermediateStateDesc() {
+    return INTERMEDIATE_STATE_DESC;
+  }
+
+  @Override
+  public int intermediateBlockCount() {
+    return INTERMEDIATE_STATE_DESC.size();
+  }
+
+  @Override
+  public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
+      Page page) {
+    BooleanBlock valuesBlock = page.getBlock(channels.get(0));
+    BooleanVector valuesVector = valuesBlock.asVector();
+    if (valuesVector == null) {
+      if (valuesBlock.mayHaveNulls()) {
+        state.enableGroupIdTracking(seenGroupIds);
+      }
+      return new GroupingAggregatorFunction.AddInput() {
+        @Override
+        public void add(int positionOffset, IntBlock groupIds) {
+          addRawInput(positionOffset, groupIds, valuesBlock);
+        }
+
+        @Override
+        public void add(int positionOffset, IntVector groupIds) {
+          addRawInput(positionOffset, groupIds, valuesBlock);
+        }
+      };
+    }
+    return new GroupingAggregatorFunction.AddInput() {
+      @Override
+      public void add(int positionOffset, IntBlock groupIds) {
+        addRawInput(positionOffset, groupIds, valuesVector);
+      }
+
+      @Override
+      public void add(int positionOffset, IntVector groupIds) {
+        addRawInput(positionOffset, groupIds, valuesVector);
+      }
+    };
+  }
+
+  private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) {
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = Math.toIntExact(groups.getInt(groupPosition));
+      if (values.isNull(groupPosition + positionOffset)) {
+        continue;
+      }
+      int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+      int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+      for (int v = valuesStart; v < valuesEnd; v++) {
+        TopBooleanAggregator.combine(state, groupId, values.getBoolean(v));
+      }
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) {
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = Math.toIntExact(groups.getInt(groupPosition));
+      TopBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset));
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) {
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = Math.toIntExact(groups.getInt(g));
+        if (values.isNull(groupPosition + positionOffset)) {
+          continue;
+        }
+        int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
+        int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
+        for (int v = valuesStart; v < valuesEnd; v++) {
+          TopBooleanAggregator.combine(state, groupId, values.getBoolean(v));
+        }
+      }
+    }
+  }
+
+  private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) {
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      if (groups.isNull(groupPosition)) {
+        continue;
+      }
+      int groupStart = groups.getFirstValueIndex(groupPosition);
+      int groupEnd = groupStart + groups.getValueCount(groupPosition);
+      for (int g = groupStart; g < groupEnd; g++) {
+        int groupId = Math.toIntExact(groups.getInt(g));
+        TopBooleanAggregator.combine(state, groupId, values.getBoolean(groupPosition + positionOffset));
+      }
+    }
+  }
+
+  @Override
+  public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    assert channels.size() == intermediateBlockCount();
+    Block topUncast = page.getBlock(channels.get(0));
+    if (topUncast.areAllValuesNull()) {
+      return;
+    }
+    BooleanBlock top = (BooleanBlock) topUncast;
+    for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
+      int groupId = Math.toIntExact(groups.getInt(groupPosition));
+      TopBooleanAggregator.combineIntermediate(state, groupId, top, groupPosition + positionOffset);
+    }
+  }
+
+  @Override
+  public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
+    if (input.getClass() != getClass()) {
+      throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
+    }
+    TopBooleanAggregator.GroupingState inState = ((TopBooleanGroupingAggregatorFunction) input).state;
+    state.enableGroupIdTracking(new SeenGroupIds.Empty());
+    TopBooleanAggregator.combineStates(state, groupId, inState, position);
+  }
+
+  @Override
+  public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    state.toIntermediate(blocks, offset, selected, driverContext);
+  }
+
+  @Override
+  public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
+      DriverContext driverContext) {
+    blocks[offset] = TopBooleanAggregator.evaluateFinal(state, selected, driverContext);
+  }
+
+  @Override
+  public String toString() {
+    StringBuilder sb = new StringBuilder();
+    sb.append(getClass().getSimpleName()).append("[");
+    sb.append("channels=").append(channels);
+    sb.append("]");
+    return sb.toString();
+  }
+
+  @Override
+  public void close() {
+    state.close();
+  }
+}

+ 137 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/TopBooleanAggregator.java

@@ -0,0 +1,137 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.ann.Aggregator;
+import org.elasticsearch.compute.ann.GroupingAggregator;
+import org.elasticsearch.compute.ann.IntermediateState;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.compute.data.sort.BooleanBucketedSort;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+/**
+ * Aggregates the top N field values for boolean.
+ */
+@Aggregator({ @IntermediateState(name = "top", type = "BOOLEAN_BLOCK") })
+@GroupingAggregator
+class TopBooleanAggregator {
+    public static SingleState initSingle(BigArrays bigArrays, int limit, boolean ascending) {
+        return new SingleState(bigArrays, limit, ascending);
+    }
+
+    public static void combine(SingleState state, boolean v) {
+        state.add(v);
+    }
+
+    public static void combineIntermediate(SingleState state, BooleanBlock values) {
+        int start = values.getFirstValueIndex(0);
+        int end = start + values.getValueCount(0);
+        for (int i = start; i < end; i++) {
+            combine(state, values.getBoolean(i));
+        }
+    }
+
+    public static Block evaluateFinal(SingleState state, DriverContext driverContext) {
+        return state.toBlock(driverContext.blockFactory());
+    }
+
+    public static GroupingState initGrouping(BigArrays bigArrays, int limit, boolean ascending) {
+        return new GroupingState(bigArrays, limit, ascending);
+    }
+
+    public static void combine(GroupingState state, int groupId, boolean v) {
+        state.add(groupId, v);
+    }
+
+    public static void combineIntermediate(GroupingState state, int groupId, BooleanBlock values, int valuesPosition) {
+        int start = values.getFirstValueIndex(valuesPosition);
+        int end = start + values.getValueCount(valuesPosition);
+        for (int i = start; i < end; i++) {
+            combine(state, groupId, values.getBoolean(i));
+        }
+    }
+
+    public static void combineStates(GroupingState current, int groupId, GroupingState state, int statePosition) {
+        current.merge(groupId, state, statePosition);
+    }
+
+    public static Block evaluateFinal(GroupingState state, IntVector selected, DriverContext driverContext) {
+        return state.toBlock(driverContext.blockFactory(), selected);
+    }
+
+    public static class GroupingState implements Releasable {
+        private final BooleanBucketedSort sort;
+
+        private GroupingState(BigArrays bigArrays, int limit, boolean ascending) {
+            this.sort = new BooleanBucketedSort(bigArrays, ascending ? SortOrder.ASC : SortOrder.DESC, limit);
+        }
+
+        public void add(int groupId, boolean value) {
+            sort.collect(value, groupId);
+        }
+
+        public void merge(int groupId, GroupingState other, int otherGroupId) {
+            sort.merge(groupId, other.sort, otherGroupId);
+        }
+
+        void toIntermediate(Block[] blocks, int offset, IntVector selected, DriverContext driverContext) {
+            blocks[offset] = toBlock(driverContext.blockFactory(), selected);
+        }
+
+        Block toBlock(BlockFactory blockFactory, IntVector selected) {
+            return sort.toBlock(blockFactory, selected);
+        }
+
+        void enableGroupIdTracking(SeenGroupIds seen) {
+            // we figure out seen values from nulls on the values block
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(sort);
+        }
+    }
+
+    public static class SingleState implements Releasable {
+        private final GroupingState internalState;
+
+        private SingleState(BigArrays bigArrays, int limit, boolean ascending) {
+            this.internalState = new GroupingState(bigArrays, limit, ascending);
+        }
+
+        public void add(boolean value) {
+            internalState.add(0, value);
+        }
+
+        public void merge(GroupingState other) {
+            internalState.merge(0, other, 0);
+        }
+
+        void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) {
+            blocks[offset] = toBlock(driverContext.blockFactory());
+        }
+
+        Block toBlock(BlockFactory blockFactory) {
+            try (var intValues = blockFactory.newConstantIntVector(0, 1)) {
+                return internalState.toBlock(blockFactory, intValues);
+            }
+        }
+
+        @Override
+        public void close() {
+            Releasables.closeExpectNoException(internalState);
+        }
+    }
+}

+ 198 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/data/sort/BooleanBucketedSort.java

@@ -0,0 +1,198 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.common.util.IntArray;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.sort.SortOrder;
+
+import java.util.stream.IntStream;
+
+/**
+ * Aggregates the top N boolean values per bucket.
+ * This class collects by just keeping the count of true and false values.
+ */
+public class BooleanBucketedSort implements Releasable {
+
+    private final BigArrays bigArrays;
+    private final SortOrder order;
+    private final int bucketSize;
+    /**
+     * An array containing all the values on all buckets. The structure is as follows:
+     * <p>
+     *     For each bucket, there are 2 values: The first keeps the count of true values, and the second the count of false values.
+     * </p>
+     */
+    private IntArray values;
+
+    public BooleanBucketedSort(BigArrays bigArrays, SortOrder order, int bucketSize) {
+        this.bigArrays = bigArrays;
+        this.order = order;
+        this.bucketSize = bucketSize;
+
+        boolean success = false;
+        try {
+            values = bigArrays.newIntArray(0, true);
+            success = true;
+        } finally {
+            if (success == false) {
+                close();
+            }
+        }
+    }
+
+    /**
+     * Collects a {@code value} into a {@code bucket}.
+     * <p>
+     *     It may or may not be inserted in the heap, depending on if it is better than the current root.
+     * </p>
+     */
+    public void collect(boolean value, int bucket) {
+        long rootIndex = (long) bucket * 2;
+
+        long requiredSize = rootIndex + 2;
+        if (values.size() < requiredSize) {
+            grow(requiredSize);
+        }
+
+        if (value) {
+            values.increment(rootIndex + 1, 1);
+        } else {
+            values.increment(rootIndex, 1);
+        }
+    }
+
+    /**
+     * The order of the sort.
+     */
+    public SortOrder getOrder() {
+        return order;
+    }
+
+    /**
+     * The number of values to store per bucket.
+     */
+    public int getBucketSize() {
+        return bucketSize;
+    }
+
+    /**
+     * Merge the values from {@code other}'s {@code otherGroupId} into {@code groupId}.
+     */
+    public void merge(int groupId, BooleanBucketedSort other, int otherGroupId) {
+        long otherRootIndex = (long) otherGroupId * 2;
+
+        if (other.values.size() < otherRootIndex + 2) {
+            return;
+        }
+
+        int falseValues = other.values.get(otherRootIndex);
+        int trueValues = other.values.get(otherRootIndex + 1);
+
+        if (falseValues + trueValues == 0) {
+            return;
+        }
+
+        long rootIndex = (long) groupId * 2;
+
+        long requiredSize = rootIndex + 2;
+        if (values.size() < requiredSize) {
+            grow(requiredSize);
+        }
+
+        values.increment(rootIndex, falseValues);
+        values.increment(rootIndex + 1, trueValues);
+    }
+
+    /**
+     * Creates a block with the values from the {@code selected} groups.
+     */
+    public Block toBlock(BlockFactory blockFactory, IntVector selected) {
+        // Check if the selected groups are all empty, to avoid allocating extra memory
+        if (bucketSize == 0 || IntStream.range(0, selected.getPositionCount()).map(selected::getInt).noneMatch(bucket -> {
+            long rootIndex = (long) bucket * 2;
+
+            if (values.size() < rootIndex + 2) {
+                return false;
+            }
+
+            var size = values.get(rootIndex) + values.get(rootIndex + 1);
+            return size > 0;
+        })) {
+            return blockFactory.newConstantNullBlock(selected.getPositionCount());
+        }
+
+        try (var builder = blockFactory.newBooleanBlockBuilder(selected.getPositionCount())) {
+            for (int s = 0; s < selected.getPositionCount(); s++) {
+                int bucket = selected.getInt(s);
+
+                long rootIndex = (long) bucket * 2;
+
+                if (values.size() < rootIndex + 2) {
+                    builder.appendNull();
+                    continue;
+                }
+
+                int falseValues = values.get(rootIndex);
+                int trueValues = values.get(rootIndex + 1);
+                long totalValues = (long) falseValues + trueValues;
+
+                if (totalValues == 0) {
+                    builder.appendNull();
+                    continue;
+                }
+
+                if (totalValues == 1) {
+                    builder.appendBoolean(trueValues > 0);
+                    continue;
+                }
+
+                builder.beginPositionEntry();
+                if (order == SortOrder.ASC) {
+                    int falseValuesToAdd = Math.min(falseValues, bucketSize);
+                    int trueValuesToAdd = Math.min(trueValues, bucketSize - falseValuesToAdd);
+                    for (int i = 0; i < falseValuesToAdd; i++) {
+                        builder.appendBoolean(false);
+                    }
+                    for (int i = 0; i < trueValuesToAdd; i++) {
+                        builder.appendBoolean(true);
+                    }
+                } else {
+                    int trueValuesToAdd = Math.min(trueValues, bucketSize);
+                    int falseValuesToAdd = Math.min(falseValues, bucketSize - trueValuesToAdd);
+                    for (int i = 0; i < trueValuesToAdd; i++) {
+                        builder.appendBoolean(true);
+                    }
+                    for (int i = 0; i < falseValuesToAdd; i++) {
+                        builder.appendBoolean(false);
+                    }
+                }
+                builder.endPositionEntry();
+            }
+            return builder.build();
+        }
+    }
+
+    /**
+     * Allocate storage for more buckets and store the "next gather offset"
+     * for those new buckets.
+     */
+    private void grow(long minSize) {
+        values = bigArrays.grow(values, minSize);
+    }
+
+    @Override
+    public final void close() {
+        Releasables.close(values);
+    }
+}

+ 44 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/TopBooleanAggregatorFunctionTests.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.aggregation;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BlockUtils;
+import org.elasticsearch.compute.operator.SequenceBooleanBlockSourceOperator;
+import org.elasticsearch.compute.operator.SourceOperator;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.contains;
+
+public class TopBooleanAggregatorFunctionTests extends AggregatorFunctionTestCase {
+    private static final int LIMIT = 100;
+
+    @Override
+    protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
+        return new SequenceBooleanBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToObj(l -> randomBoolean()).toList());
+    }
+
+    @Override
+    protected AggregatorFunctionSupplier aggregatorFunction(List<Integer> inputChannels) {
+        return new TopBooleanAggregatorFunctionSupplier(inputChannels, LIMIT, true);
+    }
+
+    @Override
+    protected String expectedDescriptionOfAggregator() {
+        return "top of booleans";
+    }
+
+    @Override
+    public void assertSimpleOutput(List<Block> input, Block result) {
+        Object[] values = input.stream().flatMap(b -> allBooleans(b)).sorted().limit(LIMIT).toArray(Object[]::new);
+        assertThat((List<?>) BlockUtils.toJavaObject(result, 0), contains(values));
+    }
+}

+ 62 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/sort/BooleanBucketedSortTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute.data.sort;
+
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.BlockFactory;
+import org.elasticsearch.compute.data.BooleanBlock;
+import org.elasticsearch.compute.data.ElementType;
+import org.elasticsearch.compute.data.IntVector;
+import org.elasticsearch.search.sort.SortOrder;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public class BooleanBucketedSortTests extends BucketedSortTestCase<BooleanBucketedSort> {
+    @Override
+    protected BooleanBucketedSort build(SortOrder sortOrder, int bucketSize) {
+        return new BooleanBucketedSort(bigArrays(), sortOrder, bucketSize);
+    }
+
+    @Override
+    protected Object expectedValue(double v) {
+        return toBoolean(v);
+    }
+
+    @Override
+    protected double randomValue() {
+        return randomBoolean() ? 0d : 1d;
+    }
+
+    @Override
+    protected void collect(BooleanBucketedSort sort, double value, int bucket) {
+        sort.collect(toBoolean(value), bucket);
+    }
+
+    @Override
+    protected void merge(BooleanBucketedSort sort, int groupId, BooleanBucketedSort other, int otherGroupId) {
+        sort.merge(groupId, other, otherGroupId);
+    }
+
+    @Override
+    protected Block toBlock(BooleanBucketedSort sort, BlockFactory blockFactory, IntVector selected) {
+        return sort.toBlock(blockFactory, selected);
+    }
+
+    @Override
+    protected void assertBlockTypeAndValues(Block block, Object... values) {
+        assertThat(block.elementType(), equalTo(ElementType.BOOLEAN));
+        var typedBlock = (BooleanBlock) block;
+        for (int i = 0; i < values.length; i++) {
+            assertThat(typedBlock.getBoolean(i), equalTo(values[i]));
+        }
+    }
+
+    private boolean toBoolean(double value) {
+        return value > 0;
+    }
+}

+ 3 - 3
x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec

@@ -110,7 +110,7 @@ double tau()
 "keyword|text to_upper(str:keyword|text)"
 "version to_ver(field:keyword|text|version)"
 "version to_version(field:keyword|text|version)"
-"double|integer|long|date top(field:double|integer|long|date, limit:integer, order:keyword)"
+"boolean|double|integer|long|date top(field:boolean|double|integer|long|date, limit:integer, order:keyword)"
 "keyword|text trim(string:keyword|text)"
 "boolean|date|double|integer|ip|keyword|long|text|version values(field:boolean|date|double|integer|ip|keyword|long|text|version)"
 "double weighted_avg(number:double|integer|long, weight:double|integer|long)"
@@ -230,7 +230,7 @@ to_unsigned_lo|field                               |"boolean|date|keyword|text|d
 to_upper      |str                                 |"keyword|text"                                                                                                                    |String expression. If `null`, the function returns `null`.
 to_ver        |field                               |"keyword|text|version"                                                                                                            |Input value. The input can be a single- or multi-valued column or an expression.
 to_version    |field                               |"keyword|text|version"                                                                                                            |Input value. The input can be a single- or multi-valued column or an expression.
-top      |[field, limit, order]               |["double|integer|long|date", integer, keyword]                                                                                    |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
+top           |[field, limit, order]               |["boolean|double|integer|long|date", integer, keyword]                                                                            |[The field to collect the top values for.,The maximum number of values to collect.,The order to calculate the top values. Either `asc` or `desc`.]
 trim          |string                              |"keyword|text"                                                                                                                    |String expression. If `null`, the function returns `null`.
 values        |field                               |"boolean|date|double|integer|ip|keyword|long|text|version"                                                                        |[""]
 weighted_avg  |[number, weight]                    |["double|integer|long", "double|integer|long"]                                                                                    |[A numeric value., A numeric weight.]
@@ -473,7 +473,7 @@ to_unsigned_lo|unsigned_long
 to_upper      |"keyword|text"                                                                                                              |false                       |false           |false
 to_ver        |version                                                                                                                     |false                       |false           |false
 to_version    |version                                                                                                                     |false                       |false           |false
-top      |"double|integer|long|date"                                                                                                  |[false, false, false]       |false           |true
+top           |"boolean|double|integer|long|date"                                                                                          |[false, false, false]       |false           |true
 trim          |"keyword|text"                                                                                                              |false                       |false           |false
 values        |"boolean|date|double|integer|ip|keyword|long|text|version"                                                                  |false                       |false           |true
 weighted_avg  |"double"                                                                                                                    |[false, false]              |false           |true

+ 40 - 4
x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top.csv-spec

@@ -106,8 +106,8 @@ FROM employees
     long = TOP(salary_change.long, 1, "asc")
 ;
 
-date:date | double:double | integer:integer | long:long
-1985-02-18T00:00:00.000Z | -9.81 | 25324 | -9
+date:date                | double:double | integer:integer | long:long
+1985-02-18T00:00:00.000Z | -9.81         | 25324           | -9
 ;
 
 topAllTypesMax
@@ -120,8 +120,8 @@ FROM employees
     long = TOP(salary_change.long, 1, "desc")
 ;
 
-date:date | double:double | integer:integer | long:long
-1999-04-30T00:00:00.000Z | 14.74 | 74999 | 14
+date:date                | double:double | integer:integer | long:long
+1999-04-30T00:00:00.000Z | 14.74         | 74999           | 14
 ;
 
 topAscDesc
@@ -154,3 +154,39 @@ FROM employees
 integer:integer
 [5, 5]
 ;
+
+topBooleans
+required_capability: agg_top
+required_capability: agg_top_boolean_support
+FROM employees
+| eval x = salary is not null
+| where emp_no > 10050
+| STATS
+    top_asc = TOP(still_hired, 2, "asc"),
+    min = TOP(still_hired, 1, "asc"),
+    top_desc = TOP(still_hired, 2, "desc"),
+    max = TOP(still_hired, 1, "desc"),
+    a = TOP(salary is not null, 2, "asc"),
+    b = TOP(x, 2, "asc"),
+    c = TOP(case(salary is null, true, false), 2, "asc"),
+    d = TOP(is_rehired, 2, "asc")
+;
+
+top_asc:boolean    | min:boolean | top_desc:boolean | max:boolean | a:boolean    | b:boolean    | c:boolean      | d:boolean
+[false, false]     | false       | [true, true]     | true        | [true, true] | [true, true] | [false, false] | [false, false]
+;
+
+topBooleansRow
+required_capability: agg_top
+required_capability: agg_top_boolean_support
+ROW constant = true, mv = [true, false]
+| STATS
+    constant_asc = TOP(constant, 2, "asc"),
+    constant_desc = TOP(constant, 2, "desc"),
+    mv_asc = TOP(mv, 2, "asc"),
+    mv_desc = TOP(mv, 2, "desc")
+;
+
+constant_asc:boolean | constant_desc:boolean | mv_asc:boolean | mv_desc:boolean
+true                 | true                  | [false, true]  | [true, false]
+;

+ 5 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -54,6 +54,11 @@ public class EsqlCapabilities {
          */
         AGG_MAX_MIN_BOOLEAN_SUPPORT,
 
+        /**
+         * Support for booleans in {@code TOP} aggregation.
+         */
+        AGG_TOP_BOOLEAN_SUPPORT,
+
         /**
          * Optimization for ST_CENTROID changed some results in cartesian data. #108713
          */

+ 7 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

@@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.BytesRefs;
 import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
+import org.elasticsearch.compute.aggregation.TopBooleanAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopDoubleAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopIntAggregatorFunctionSupplier;
 import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier;
@@ -46,7 +47,7 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
     private static final String ORDER_DESC = "DESC";
 
     @FunctionInfo(
-        returnType = { "double", "integer", "long", "date" },
+        returnType = { "boolean", "double", "integer", "long", "date" },
         description = "Collects the top values for a field. Includes repeated values.",
         isAggregation = true,
         examples = @Example(file = "stats_top", tag = "top")
@@ -55,7 +56,7 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
         Source source,
         @Param(
             name = "field",
-            type = { "double", "integer", "long", "date" },
+            type = { "boolean", "double", "integer", "long", "date" },
             description = "The field to collect the top values for."
         ) Expression field,
         @Param(name = "limit", type = { "integer" }, description = "The maximum number of values to collect.") Expression limit,
@@ -120,7 +121,7 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
 
         var typeResolution = isType(
             field(),
-            dt -> dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
+            dt -> dt == DataType.BOOLEAN || dt == DataType.DATETIME || dt.isNumeric() && dt != DataType.UNSIGNED_LONG,
             sourceText(),
             FIRST,
             "numeric except unsigned_long or counter types"
@@ -176,6 +177,9 @@ public class Top extends AggregateFunction implements ToAggregator, SurrogateExp
         if (type == DataType.DOUBLE) {
             return new TopDoubleAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
         }
+        if (type == DataType.BOOLEAN) {
+            return new TopBooleanAggregatorFunctionSupplier(inputChannels, limitValue(), orderValue());
+        }
         throw EsqlIllegalArgumentException.illegalDataType(type);
     }
 

+ 1 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java

@@ -157,7 +157,7 @@ final class AggregateMapper {
             // TODO can't we figure this out from the function itself?
             types = List.of("Int", "Long", "Double", "Boolean", "BytesRef");
         } else if (Top.class.isAssignableFrom(clazz)) {
-            types = List.of("Int", "Long", "Double");
+            types = List.of("Boolean", "Int", "Long", "Double");
         } else if (Rate.class.isAssignableFrom(clazz)) {
             types = List.of("Int", "Long", "Double");
         } else if (FromPartial.class.isAssignableFrom(clazz) || ToPartial.class.isAssignableFrom(clazz)) {

+ 6 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java

@@ -991,6 +991,12 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
         return cases;
     }
 
+    /**
+     * Generate cases for {@link DataType#BOOLEAN}.
+     * <p>
+     *     For multi-row parameters, see {@link MultiRowTestCaseSupplier#booleanCases}.
+     * </p>
+     */
     public static List<TypedDataSupplier> booleanCases() {
         return List.of(
             new TypedDataSupplier("<true>", () -> true, DataType.BOOLEAN),

+ 29 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java

@@ -42,7 +42,8 @@ public class TopTests extends AbstractAggregationTestCase {
                     MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
                     MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
                     MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true),
-                    MultiRowTestCaseSupplier.dateCases(1, 1000)
+                    MultiRowTestCaseSupplier.dateCases(1, 1000),
+                    MultiRowTestCaseSupplier.booleanCases(1, 1000)
                 )
                     .flatMap(List::stream)
                     .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order))
@@ -53,6 +54,19 @@ public class TopTests extends AbstractAggregationTestCase {
         suppliers.addAll(
             List.of(
                 // Surrogates
+                new TestCaseSupplier(
+                    List.of(DataType.BOOLEAN, DataType.INTEGER, DataType.KEYWORD),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(true, true, false), DataType.BOOLEAN, "field"),
+                            new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral()
+                        ),
+                        "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]",
+                        DataType.BOOLEAN,
+                        equalTo(true)
+                    )
+                ),
                 new TestCaseSupplier(
                     List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD),
                     () -> new TestCaseSupplier.TestCase(
@@ -107,6 +121,19 @@ public class TopTests extends AbstractAggregationTestCase {
                 ),
 
                 // Folding
+                new TestCaseSupplier(
+                    List.of(DataType.BOOLEAN, DataType.INTEGER, DataType.KEYWORD),
+                    () -> new TestCaseSupplier.TestCase(
+                        List.of(
+                            TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field"),
+                            new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(),
+                            new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral()
+                        ),
+                        "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]",
+                        DataType.BOOLEAN,
+                        equalTo(true)
+                    )
+                ),
                 new TestCaseSupplier(
                     List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD),
                     () -> new TestCaseSupplier.TestCase(
@@ -222,7 +249,7 @@ public class TopTests extends AbstractAggregationTestCase {
         TestCaseSupplier.TypedDataSupplier limitCaseSupplier,
         String order
     ) {
-        return new TestCaseSupplier(List.of(fieldSupplier.type(), DataType.INTEGER, DataType.KEYWORD), () -> {
+        return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type(), DataType.INTEGER, DataType.KEYWORD), () -> {
             var fieldTypedData = fieldSupplier.get();
             var limitTypedData = limitCaseSupplier.get().forceLiteral();
             var limit = (int) limitTypedData.getValue();