
Support ordinals grouping for rate aggregation (#106735)

Add support for ordinal grouping in the rate aggregation function.

Relates #106703
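
Before this change, combineStates in the generated rate aggregators threw UnsupportedOperationException; it now delegates to a per-group combineState that merges the other state's samples into the current group. Both sides keep their samples sorted by timestamp descending, so the merge is a single two-way pass. A minimal standalone sketch of that pass (hypothetical names, not the Elasticsearch classes):

    import java.util.Arrays;

    // Sketch only: merges two timestamp arrays that are each sorted
    // descending, preserving that order, as mergeState does below for
    // paired timestamp/value arrays.
    final class MergeSketch {
        static long[] mergeDescending(long[] a, long[] b) {
            long[] out = new long[a.length + b.length];
            int i = 0, j = 0, k = 0;
            while (i < a.length && j < b.length) {
                out[k++] = a[i] > b[j] ? a[i++] : b[j++];
            }
            // Copy whichever side still has a tail; at most one of these runs.
            while (i < a.length) out[k++] = a[i++];
            while (j < b.length) out[k++] = b[j++];
            return out;
        }

        public static void main(String[] args) {
            long[] merged = mergeDescending(new long[] { 90, 50, 10 }, new long[] { 70, 30 });
            System.out.println(Arrays.toString(merged)); // [90, 70, 50, 30, 10]
        }
    }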
Nhat Nguyen, 1 year ago
parent
commit
0b3382cd24

+ 50 - 4
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleAggregator.java

@@ -24,6 +24,8 @@ import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
+import java.util.Arrays;
+
 /**
  * A rate grouping aggregation definition for double.
  * This class is generated. Edit `X-RateAggregator.java.st` instead.
@@ -59,10 +61,10 @@ public class RateDoubleAggregator {
     public static void combineStates(
         DoubleRateGroupingState current,
         int currentGroupId, // make the stylecheck happy
-        DoubleRateGroupingState state,
-        int statePosition
+        DoubleRateGroupingState otherState,
+        int otherGroupId
     ) {
-        throw new UnsupportedOperationException("ordinals grouping is not supported yet");
+        current.combineState(currentGroupId, otherState, otherGroupId);
     }
 
     public static Block evaluateFinal(DoubleRateGroupingState state, IntVector selected, DriverContext driverContext) {
@@ -163,6 +165,7 @@ public class RateDoubleAggregator {
             if (state == null) {
                 adjustBreaker(DoubleRateState.bytesUsed(valueCount));
                 state = new DoubleRateState(valueCount);
+                state.reset = reset;
                 states.set(groupId, state);
                 // TODO: add bulk_copy to Block
                 for (int i = 0; i < valueCount; i++) {
@@ -172,11 +175,11 @@ public class RateDoubleAggregator {
             } else {
                 adjustBreaker(DoubleRateState.bytesUsed(state.entries() + valueCount));
                 var newState = new DoubleRateState(state.entries() + valueCount);
+                newState.reset = state.reset + reset;
                 states.set(groupId, newState);
                 merge(state, newState, firstIndex, valueCount, timestamps, values);
                 adjustBreaker(-DoubleRateState.bytesUsed(state.entries())); // old state
             }
-            state.reset += reset;
         }
 
         void merge(DoubleRateState curr, DoubleRateState dst, int firstIndex, int rightCount, LongBlock timestamps, DoubleBlock values) {
@@ -208,6 +211,49 @@ public class RateDoubleAggregator {
             }
         }
 
+        void combineState(int groupId, DoubleRateGroupingState otherState, int otherGroupId) {
+            var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null;
+            if (other == null) {
+                return;
+            }
+            ensureCapacity(groupId);
+            var curr = states.get(groupId);
+            if (curr == null) {
+                var len = other.entries();
+                adjustBreaker(DoubleRateState.bytesUsed(len));
+                curr = new DoubleRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len));
+                curr.reset = other.reset;
+                states.set(groupId, curr);
+            } else {
+                states.set(groupId, mergeState(curr, other));
+            }
+        }
+
+        DoubleRateState mergeState(DoubleRateState s1, DoubleRateState s2) {
+            var newLen = s1.entries() + s2.entries();
+            adjustBreaker(DoubleRateState.bytesUsed(newLen));
+            var dst = new DoubleRateState(newLen);
+            dst.reset = s1.reset + s2.reset;
+            int i = 0, j = 0, k = 0;
+            while (i < s1.entries() && j < s2.entries()) {
+                if (s1.timestamps[i] > s2.timestamps[j]) {
+                    dst.timestamps[k] = s1.timestamps[i];
+                    dst.values[k] = s1.values[i];
+                    ++i;
+                } else {
+                    dst.timestamps[k] = s2.timestamps[j];
+                    dst.values[k] = s2.values[j];
+                    ++j;
+                }
+                ++k;
+            }
+            System.arraycopy(s1.timestamps, i, dst.timestamps, k, s1.entries() - i);
+            System.arraycopy(s1.values, i, dst.values, k, s1.entries() - i);
+            System.arraycopy(s2.timestamps, j, dst.timestamps, k, s2.entries() - j);
+            System.arraycopy(s2.values, j, dst.values, k, s2.entries() - j);
+            return dst;
+        }
+
         @Override
         public long ramBytesUsed() {
             return states.ramBytesUsed() + stateBytes;

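Note on the breaker accounting visible above: the grow paths reserve bytes for the new, larger buffer before allocating it, and release the old buffer's bytes only after it is no longer referenced, so tripping the circuit breaker can never leave untracked memory behind. A minimal sketch of the pattern, with a hypothetical adjustBreaker callback standing in for the real CircuitBreaker plumbing:

    import java.util.Arrays;
    import java.util.function.LongConsumer;

    // Sketch only: the real aggregators track bytes via
    // DoubleRateState.bytesUsed(entries) rather than raw array lengths.
    final class BreakerSketch {
        static long[] growTracked(long[] old, int extra, LongConsumer adjustBreaker) {
            adjustBreaker.accept((long) (old.length + extra) * Long.BYTES); // reserve new buffer first
            long[] grown = Arrays.copyOf(old, old.length + extra);
            adjustBreaker.accept(-(long) old.length * Long.BYTES);          // release old buffer after
            return grown;
        }
    }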
+ 50 - 4
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntAggregator.java

@@ -25,6 +25,8 @@ import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
+import java.util.Arrays;
+
 /**
  * A rate grouping aggregation definition for int.
  * This class is generated. Edit `X-RateAggregator.java.st` instead.
@@ -60,10 +62,10 @@ public class RateIntAggregator {
     public static void combineStates(
         IntRateGroupingState current,
         int currentGroupId, // make the stylecheck happy
-        IntRateGroupingState state,
-        int statePosition
+        IntRateGroupingState otherState,
+        int otherGroupId
     ) {
-        throw new UnsupportedOperationException("ordinals grouping is not supported yet");
+        current.combineState(currentGroupId, otherState, otherGroupId);
     }
 
     public static Block evaluateFinal(IntRateGroupingState state, IntVector selected, DriverContext driverContext) {
@@ -164,6 +166,7 @@ public class RateIntAggregator {
             if (state == null) {
                 adjustBreaker(IntRateState.bytesUsed(valueCount));
                 state = new IntRateState(valueCount);
+                state.reset = reset;
                 states.set(groupId, state);
                 // TODO: add bulk_copy to Block
                 for (int i = 0; i < valueCount; i++) {
@@ -173,11 +176,11 @@ public class RateIntAggregator {
             } else {
                 adjustBreaker(IntRateState.bytesUsed(state.entries() + valueCount));
                 var newState = new IntRateState(state.entries() + valueCount);
+                newState.reset = state.reset + reset;
                 states.set(groupId, newState);
                 merge(state, newState, firstIndex, valueCount, timestamps, values);
                 adjustBreaker(-IntRateState.bytesUsed(state.entries())); // old state
             }
-            state.reset += reset;
         }
 
         void merge(IntRateState curr, IntRateState dst, int firstIndex, int rightCount, LongBlock timestamps, IntBlock values) {
@@ -209,6 +212,49 @@ public class RateIntAggregator {
             }
         }
 
+        void combineState(int groupId, IntRateGroupingState otherState, int otherGroupId) {
+            var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null;
+            if (other == null) {
+                return;
+            }
+            ensureCapacity(groupId);
+            var curr = states.get(groupId);
+            if (curr == null) {
+                var len = other.entries();
+                adjustBreaker(IntRateState.bytesUsed(len));
+                curr = new IntRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len));
+                curr.reset = other.reset;
+                states.set(groupId, curr);
+            } else {
+                states.set(groupId, mergeState(curr, other));
+            }
+        }
+
+        IntRateState mergeState(IntRateState s1, IntRateState s2) {
+            var newLen = s1.entries() + s2.entries();
+            adjustBreaker(IntRateState.bytesUsed(newLen));
+            var dst = new IntRateState(newLen);
+            dst.reset = s1.reset + s2.reset;
+            int i = 0, j = 0, k = 0;
+            while (i < s1.entries() && j < s2.entries()) {
+                if (s1.timestamps[i] > s2.timestamps[j]) {
+                    dst.timestamps[k] = s1.timestamps[i];
+                    dst.values[k] = s1.values[i];
+                    ++i;
+                } else {
+                    dst.timestamps[k] = s2.timestamps[j];
+                    dst.values[k] = s2.values[j];
+                    ++j;
+                }
+                ++k;
+            }
+            System.arraycopy(s1.timestamps, i, dst.timestamps, k, s1.entries() - i);
+            System.arraycopy(s1.values, i, dst.values, k, s1.entries() - i);
+            System.arraycopy(s2.timestamps, j, dst.timestamps, k, s2.entries() - j);
+            System.arraycopy(s2.values, j, dst.values, k, s2.entries() - j);
+            return dst;
+        }
+
         @Override
         public long ramBytesUsed() {
             return states.ramBytesUsed() + stateBytes;

+ 50 - 4
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongAggregator.java

@@ -24,6 +24,8 @@ import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
+import java.util.Arrays;
+
 /**
  * A rate grouping aggregation definition for long.
  * This class is generated. Edit `X-RateAggregator.java.st` instead.
@@ -59,10 +61,10 @@ public class RateLongAggregator {
     public static void combineStates(
         LongRateGroupingState current,
         int currentGroupId, // make the stylecheck happy
-        LongRateGroupingState state,
-        int statePosition
+        LongRateGroupingState otherState,
+        int otherGroupId
     ) {
-        throw new UnsupportedOperationException("ordinals grouping is not supported yet");
+        current.combineState(currentGroupId, otherState, otherGroupId);
     }
 
     public static Block evaluateFinal(LongRateGroupingState state, IntVector selected, DriverContext driverContext) {
@@ -163,6 +165,7 @@ public class RateLongAggregator {
             if (state == null) {
                 adjustBreaker(LongRateState.bytesUsed(valueCount));
                 state = new LongRateState(valueCount);
+                state.reset = reset;
                 states.set(groupId, state);
                 // TODO: add bulk_copy to Block
                 for (int i = 0; i < valueCount; i++) {
@@ -172,11 +175,11 @@ public class RateLongAggregator {
             } else {
                 adjustBreaker(LongRateState.bytesUsed(state.entries() + valueCount));
                 var newState = new LongRateState(state.entries() + valueCount);
+                newState.reset = state.reset + reset;
                 states.set(groupId, newState);
                 merge(state, newState, firstIndex, valueCount, timestamps, values);
                 adjustBreaker(-LongRateState.bytesUsed(state.entries())); // old state
             }
-            state.reset += reset;
         }
 
         void merge(LongRateState curr, LongRateState dst, int firstIndex, int rightCount, LongBlock timestamps, LongBlock values) {
@@ -208,6 +211,49 @@ public class RateLongAggregator {
             }
         }
 
+        void combineState(int groupId, LongRateGroupingState otherState, int otherGroupId) {
+            var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null;
+            if (other == null) {
+                return;
+            }
+            ensureCapacity(groupId);
+            var curr = states.get(groupId);
+            if (curr == null) {
+                var len = other.entries();
+                adjustBreaker(LongRateState.bytesUsed(len));
+                curr = new LongRateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len));
+                curr.reset = other.reset;
+                states.set(groupId, curr);
+            } else {
+                states.set(groupId, mergeState(curr, other));
+            }
+        }
+
+        LongRateState mergeState(LongRateState s1, LongRateState s2) {
+            var newLen = s1.entries() + s2.entries();
+            adjustBreaker(LongRateState.bytesUsed(newLen));
+            var dst = new LongRateState(newLen);
+            dst.reset = s1.reset + s2.reset;
+            int i = 0, j = 0, k = 0;
+            while (i < s1.entries() && j < s2.entries()) {
+                if (s1.timestamps[i] > s2.timestamps[j]) {
+                    dst.timestamps[k] = s1.timestamps[i];
+                    dst.values[k] = s1.values[i];
+                    ++i;
+                } else {
+                    dst.timestamps[k] = s2.timestamps[j];
+                    dst.values[k] = s2.values[j];
+                    ++j;
+                }
+                ++k;
+            }
+            System.arraycopy(s1.timestamps, i, dst.timestamps, k, s1.entries() - i);
+            System.arraycopy(s1.values, i, dst.values, k, s1.entries() - i);
+            System.arraycopy(s2.timestamps, j, dst.timestamps, k, s2.entries() - j);
+            System.arraycopy(s2.values, j, dst.values, k, s2.entries() - j);
+            return dst;
+        }
+
         @Override
         public long ramBytesUsed() {
             return states.ramBytesUsed() + stateBytes;

+ 50 - 4
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateAggregator.java.st

@@ -27,6 +27,8 @@ import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 
+import java.util.Arrays;
+
 /**
  * A rate grouping aggregation definition for $type$.
  * This class is generated. Edit `X-RateAggregator.java.st` instead.
@@ -62,10 +64,10 @@ public class Rate$Type$Aggregator {
     public static void combineStates(
         $Type$RateGroupingState current,
         int currentGroupId, // make the stylecheck happy
-        $Type$RateGroupingState state,
-        int statePosition
+        $Type$RateGroupingState otherState,
+        int otherGroupId
     ) {
-        throw new UnsupportedOperationException("ordinals grouping is not supported yet");
+        current.combineState(currentGroupId, otherState, otherGroupId);
     }
 
     public static Block evaluateFinal($Type$RateGroupingState state, IntVector selected, DriverContext driverContext) {
@@ -166,6 +168,7 @@ public class Rate$Type$Aggregator {
             if (state == null) {
                 adjustBreaker($Type$RateState.bytesUsed(valueCount));
                 state = new $Type$RateState(valueCount);
+                state.reset = reset;
                 states.set(groupId, state);
                 // TODO: add bulk_copy to Block
                 for (int i = 0; i < valueCount; i++) {
@@ -175,11 +178,11 @@ public class Rate$Type$Aggregator {
             } else {
                 adjustBreaker($Type$RateState.bytesUsed(state.entries() + valueCount));
                 var newState = new $Type$RateState(state.entries() + valueCount);
+                newState.reset = state.reset + reset;
                 states.set(groupId, newState);
                 merge(state, newState, firstIndex, valueCount, timestamps, values);
                 adjustBreaker(-$Type$RateState.bytesUsed(state.entries())); // old state
             }
-            state.reset += reset;
         }
 
         void merge($Type$RateState curr, $Type$RateState dst, int firstIndex, int rightCount, LongBlock timestamps, $Type$Block values) {
@@ -211,6 +214,49 @@ public class Rate$Type$Aggregator {
             }
         }
 
+        void combineState(int groupId, $Type$RateGroupingState otherState, int otherGroupId) {
+            var other = otherGroupId < otherState.states.size() ? otherState.states.get(otherGroupId) : null;
+            if (other == null) {
+                return;
+            }
+            ensureCapacity(groupId);
+            var curr = states.get(groupId);
+            if (curr == null) {
+                var len = other.entries();
+                adjustBreaker($Type$RateState.bytesUsed(len));
+                curr = new $Type$RateState(Arrays.copyOf(other.timestamps, len), Arrays.copyOf(other.values, len));
+                curr.reset = other.reset;
+                states.set(groupId, curr);
+            } else {
+                states.set(groupId, mergeState(curr, other));
+            }
+        }
+
+        $Type$RateState mergeState($Type$RateState s1, $Type$RateState s2) {
+            var newLen = s1.entries() + s2.entries();
+            adjustBreaker($Type$RateState.bytesUsed(newLen));
+            var dst = new $Type$RateState(newLen);
+            dst.reset = s1.reset + s2.reset;
+            int i = 0, j = 0, k = 0;
+            while (i < s1.entries() && j < s2.entries()) {
+                if (s1.timestamps[i] > s2.timestamps[j]) {
+                    dst.timestamps[k] = s1.timestamps[i];
+                    dst.values[k] = s1.values[i];
+                    ++i;
+                } else {
+                    dst.timestamps[k] = s2.timestamps[j];
+                    dst.values[k] = s2.values[j];
+                    ++j;
+                }
+                ++k;
+            }
+            System.arraycopy(s1.timestamps, i, dst.timestamps, k, s1.entries() - i);
+            System.arraycopy(s1.values, i, dst.values, k, s1.entries() - i);
+            System.arraycopy(s2.timestamps, j, dst.timestamps, k, s2.entries() - j);
+            System.arraycopy(s2.values, j, dst.values, k, s2.entries() - j);
+            return dst;
+        }
+
         @Override
         public long ramBytesUsed() {
             return states.ramBytesUsed() + stateBytes;

+ 1 - 5
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSortedSourceOperatorFactory.java

@@ -143,14 +143,11 @@ public record TimeSeriesSortedSourceOperatorFactory(int limit, int maxPageSize,
                 }
                 iterator.consume();
                 shard = blockFactory.newConstantIntBlockWith(iterator.slice.shardContext().index(), currentPagePos);
-                boolean singleSegmentNonDecreasing;
                 if (iterator.slice.numLeaves() == 1) {
-                    singleSegmentNonDecreasing = true;
                     int segmentOrd = iterator.slice.getLeaf(0).leafReaderContext().ord;
                     leaf = blockFactory.newConstantIntBlockWith(segmentOrd, currentPagePos).asVector();
                 } else {
                     // Due to the multi segment nature of time series source operator singleSegmentNonDecreasing must be false
-                    singleSegmentNonDecreasing = false;
                     leaf = segmentsBuilder.build();
                     segmentsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
                 }
@@ -161,10 +158,9 @@ public record TimeSeriesSortedSourceOperatorFactory(int limit, int maxPageSize,
                 timestampIntervalBuilder = blockFactory.newLongVectorBuilder(Math.min(remainingDocs, maxPageSize));
                 tsids = tsOrdBuilder.build();
                 tsOrdBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
-
                 page = new Page(
                     currentPagePos,
-                    new DocVector(shard.asVector(), leaf, docs, singleSegmentNonDecreasing).asBlock(),
+                    new DocVector(shard.asVector(), leaf, docs, leaf.isConstant()).asBlock(),
                     tsids.asBlock(),
                     timestampIntervals.asBlock()
                 );

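Note on the DocVector change above: the leaf vector carries one segment ordinal per document, so leaf.isConstant() is exactly the condition the removed singleSegmentNonDecreasing local used to track, and within a single segment this source emits doc ids in non-decreasing order. A minimal sketch of the equivalence, with a plain array standing in for the compute framework's IntVector:

    // Hypothetical stand-in for IntVector.isConstant(); assumes a non-empty page.
    static boolean singleSegmentNonDecreasing(int[] segmentOrdPerDoc) {
        // All docs sharing one segment ordinal is the constant-vector case.
        for (int ord : segmentOrdPerDoc) {
            if (ord != segmentOrdPerDoc[0]) {
                return false;
            }
        }
        return true;
    }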
+ 57 - 23
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/TimeSeriesSortedSourceOperatorTests.java

@@ -43,14 +43,17 @@ import org.elasticsearch.compute.operator.DriverContext;
 import org.elasticsearch.compute.operator.HashAggregationOperator;
 import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.compute.operator.OperatorTestCase;
+import org.elasticsearch.compute.operator.OrdinalsGroupingOperator;
 import org.elasticsearch.compute.operator.TestResultPageSinkOperator;
 import org.elasticsearch.core.CheckedFunction;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.mapper.BlockDocValuesReader;
 import org.elasticsearch.index.mapper.DataStreamTimestampFieldMapper;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.mapper.TimeSeriesIdFieldMapper;
 import org.junit.After;
 
@@ -285,17 +288,6 @@ public class TimeSeriesSortedSourceOperatorTests extends AnyOperatorTestCase {
             return docs.size();
         });
         var ctx = driverContext();
-        HashAggregationOperator initialHash = new HashAggregationOperator(
-            List.of(new RateLongAggregatorFunctionSupplier(List.of(4, 2), unitInMillis).groupingAggregatorFactory(AggregatorMode.INITIAL)),
-            () -> BlockHash.build(
-                List.of(new HashAggregationOperator.GroupSpec(3, ElementType.BYTES_REF)),
-                ctx.blockFactory(),
-                randomIntBetween(1, 1000),
-                randomBoolean()
-            ),
-            ctx
-        );
-
         HashAggregationOperator finalHash = new HashAggregationOperator(
             List.of(new RateLongAggregatorFunctionSupplier(List.of(1, 2, 3), unitInMillis).groupingAggregatorFactory(AggregatorMode.FINAL)),
             () -> BlockHash.build(
@@ -309,20 +301,62 @@ public class TimeSeriesSortedSourceOperatorTests extends AnyOperatorTestCase {
         List<Page> results = new ArrayList<>();
         var requestsField = new NumberFieldMapper.NumberFieldType("requests", NumberFieldMapper.NumberType.LONG);
         var podField = new KeywordFieldMapper.KeywordFieldType("pod");
-        OperatorTestCase.runDriver(
-            new Driver(
-                ctx,
-                sourceOperatorFactory.get(ctx),
+        if (randomBoolean()) {
+            HashAggregationOperator initialHash = new HashAggregationOperator(
                 List.of(
-                    ValuesSourceReaderOperatorTests.factory(reader, podField, ElementType.BYTES_REF).get(ctx),
-                    ValuesSourceReaderOperatorTests.factory(reader, requestsField, ElementType.LONG).get(ctx),
-                    initialHash,
-                    finalHash
+                    new RateLongAggregatorFunctionSupplier(List.of(4, 2), unitInMillis).groupingAggregatorFactory(AggregatorMode.INITIAL)
                 ),
-                new TestResultPageSinkOperator(results::add),
-                () -> {}
-            )
-        );
+                () -> BlockHash.build(
+                    List.of(new HashAggregationOperator.GroupSpec(3, ElementType.BYTES_REF)),
+                    ctx.blockFactory(),
+                    randomIntBetween(1, 1000),
+                    randomBoolean()
+                ),
+                ctx
+            );
+            OperatorTestCase.runDriver(
+                new Driver(
+                    ctx,
+                    sourceOperatorFactory.get(ctx),
+                    List.of(
+                        ValuesSourceReaderOperatorTests.factory(reader, podField, ElementType.BYTES_REF).get(ctx),
+                        ValuesSourceReaderOperatorTests.factory(reader, requestsField, ElementType.LONG).get(ctx),
+                        initialHash,
+                        finalHash
+                    ),
+                    new TestResultPageSinkOperator(results::add),
+                    () -> {}
+                )
+            );
+        } else {
+            var blockLoader = new BlockDocValuesReader.BytesRefsFromOrdsBlockLoader("pod");
+            var shardContext = new ValuesSourceReaderOperator.ShardContext(reader, () -> SourceLoader.FROM_STORED_SOURCE);
+            var ordinalGrouping = new OrdinalsGroupingOperator(
+                shardIdx -> blockLoader,
+                List.of(shardContext),
+                ElementType.BYTES_REF,
+                0,
+                "pod",
+                List.of(
+                    new RateLongAggregatorFunctionSupplier(List.of(3, 2), unitInMillis).groupingAggregatorFactory(AggregatorMode.INITIAL)
+                ),
+                randomIntBetween(1, 1000),
+                ctx
+            );
+            OperatorTestCase.runDriver(
+                new Driver(
+                    ctx,
+                    sourceOperatorFactory.get(ctx),
+                    List.of(
+                        ValuesSourceReaderOperatorTests.factory(reader, requestsField, ElementType.LONG).get(ctx),
+                        ordinalGrouping,
+                        finalHash
+                    ),
+                    new TestResultPageSinkOperator(results::add),
+                    () -> {}
+                )
+            );
+        }
         Map<String, Double> rates = new HashMap<>();
         for (Page result : results) {
             BytesRefBlock keysBlock = result.getBlock(0);