Browse Source

Simplify computation of resets in rate aggregation (#134700)

This change reworks the computation of resets to remove the need
for the `dv` delta helper.
Nhat Nguyen 1 month ago
parent
commit
8cb29e0184

+ 78 - 107
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateDoubleGroupingAggregatorFunction.java

@@ -65,7 +65,7 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
     static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
         new IntermediateStateDesc("timestamps", ElementType.LONG),
         new IntermediateStateDesc("values", ElementType.DOUBLE),
-        new IntermediateStateDesc("sampleCounts", ElementType.INT),
+        new IntermediateStateDesc("sampleCounts", ElementType.LONG),
         new IntermediateStateDesc("resets", ElementType.DOUBLE)
     );
 
@@ -272,11 +272,11 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -287,7 +287,7 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
                 state = new ReducedState();
                 reducedStates.set(groupId, state);
             }
-            state.appendValuesFromBlocks(timestamps, values, valuePosition);
+            state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
             state.samples += sampleCount;
             state.resets += resets.getDouble(valuePosition);
         }
@@ -301,11 +301,11 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -322,46 +322,44 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
                     state = new ReducedState();
                     reducedStates.set(groupId, state);
                 }
-                state.appendValuesFromBlocks(timestamps, values, valuePosition);
+                state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
                 state.samples += sampleCount;
-                state.resets += resets.getDouble(groupPosition);
+                state.resets += resets.getDouble(valuePosition);
             }
         }
     }
 
     @Override
-    public final void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (
             var timestamps = blockFactory.newLongBlockBuilder(positionCount * 2);
             var values = blockFactory.newDoubleBlockBuilder(positionCount * 2);
-            var sampleCounts = blockFactory.newIntVectorFixedBuilder(positionCount);
+            var sampleCounts = blockFactory.newLongVectorFixedBuilder(positionCount);
             var resets = blockFactory.newDoubleVectorFixedBuilder(positionCount)
         ) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state != null && state.timestamps.length > 0) {
-                    if (state.samples > 1) {
-                        timestamps.beginPositionEntry();
-                        values.beginPositionEntry();
-                        for (int s = 0; s < state.timestamps.length; s++) {
-                            timestamps.appendLong(state.timestamps[s]);
-                            values.appendDouble(state.values[s]);
-                        }
-                        timestamps.endPositionEntry();
-                        values.endPositionEntry();
-                    } else {
-                        timestamps.appendLong(state.timestamps[0]);
-                        values.appendDouble(state.values[0]);
+                // Do not combine intervals across shards because intervals from different indices may overlap.
+                if (state != null && state.samples > 0) {
+                    timestamps.beginPositionEntry();
+                    values.beginPositionEntry();
+                    for (Interval interval : state.intervals) {
+                        timestamps.appendLong(interval.t1);
+                        timestamps.appendLong(interval.t2);
+                        values.appendDouble(interval.v1);
+                        values.appendDouble(interval.v2);
                     }
-                    sampleCounts.appendInt(state.samples);
+                    timestamps.endPositionEntry();
+                    values.endPositionEntry();
+                    sampleCounts.appendLong(state.samples);
                     resets.appendDouble(state.resets);
                 } else {
                     timestamps.appendLong(0);
                     values.appendDouble(0);
-                    sampleCounts.appendInt(0);
+                    sampleCounts.appendLong(0);
                     resets.appendDouble(0);
                 }
             }
@@ -449,7 +447,9 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
             }
             if (pendingCount == 1) {
                 state.samples++;
-                state.appendOneValue(timestamps.get(0), values.get(0));
+                long t = timestamps.get(0);
+                double v = values.get(0);
+                state.appendInterval(new Interval(t, v, t, v));
                 return;
             }
             PriorityQueue<Slice> pq = mergeQueue();
@@ -468,7 +468,6 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
                 }
             }
             var prevValue = lastValue;
-            double reset = 0;
             int position = -1;
             while (pq.size() > 0) {
                 Slice top = pq.top();
@@ -479,12 +478,13 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
                     pq.updateTop();
                 }
                 var val = values.get(position);
-                reset += dv(val, prevValue) + dv(prevValue, lastValue) - dv(val, lastValue);
+                if (val > prevValue) {
+                    state.resets += val;
+                }
                 prevValue = val;
             }
             state.samples += pendingCount;
-            state.resets += reset;
-            state.appendTwoValues(lastTimestamp, lastValue, timestamps.get(position), prevValue);
+            state.appendInterval(new Interval(lastTimestamp, lastValue, timestamps.get(position), prevValue));
         }
 
         private PriorityQueue<Slice> mergeQueue() {
@@ -537,17 +537,27 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
     }
 
     @Override
-    public final void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
+    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (var rates = blockFactory.newDoubleBlockBuilder(positionCount)) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state == null || state.timestamps.length < 2) {
+                if (state == null || state.samples < 2) {
                     rates.appendNull();
                     continue;
                 }
+                // combine intervals for the final evaluation
+                Interval[] intervals = state.intervals;
+                ArrayUtil.timSort(intervals);
+                for (int i = 1; i < intervals.length; i++) {
+                    Interval next = intervals[i - 1]; // reversed
+                    Interval prev = intervals[i];
+                    if (prev.v2 > next.v2) {
+                        state.resets += prev.v2;
+                    }
+                }
                 final double rate;
                 if (evalContext instanceof TimeSeriesGroupingAggregatorEvaluationContext tsContext) {
                     rate = extrapolateRate(state, tsContext.rangeStartInMillis(group), tsContext.rangeEndInMillis(group));
@@ -583,77 +593,50 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
         return sb.toString();
     }
 
+    record Interval(long t1, double v1, long t2, double v2) implements Comparable<Interval> {
+        @Override
+        public int compareTo(Interval other) {
+            return Long.compare(other.t1, t1); // want most recent first
+        }
+    }
+
     static final class ReducedState {
-        private static final long[] EMPTY_LONGS = new long[0];
-        private static final double[] EMPTY_VALUES = new double[0];
-        int samples;
+        private static final Interval[] EMPTY_INTERVALS = new Interval[0];
+        long samples;
         double resets;
-        long[] timestamps = EMPTY_LONGS;
-        double[] values = EMPTY_VALUES;
+        Interval[] intervals = EMPTY_INTERVALS;
 
-        void appendOneValue(long t, double v) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 1);
-            this.values = ArrayUtil.growExact(values, currentSize + 1);
-            this.timestamps[currentSize] = t;
-            this.values[currentSize] = v;
+        void appendInterval(Interval interval) {
+            int currentSize = intervals.length;
+            this.intervals = ArrayUtil.growExact(intervals, currentSize + 1);
+            this.intervals[currentSize] = interval;
         }
 
-        void appendTwoValues(long t1, double v1, long t2, double v2) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 2);
-            this.values = ArrayUtil.growExact(values, currentSize + 2);
-            this.timestamps[currentSize] = t1;
-            this.values[currentSize] = v1;
-            currentSize++;
-            this.timestamps[currentSize] = t2;
-            this.values[currentSize] = v2;
-        }
-
-        void appendValuesFromBlocks(LongBlock ts, DoubleBlock vs, int position) {
+        void appendIntervalsFromBlocks(LongBlock ts, DoubleBlock vs, int position) {
             int tsFirst = ts.getFirstValueIndex(position);
             int vsFirst = vs.getFirstValueIndex(position);
             int count = ts.getValueCount(position);
-            int total = timestamps.length + count;
-            long[] mergedTimestamps = new long[total];
-            double[] mergedValues = new double[total];
-            int i = 0, j = 0, k = 0;
-            while (i < timestamps.length && j < count) {
-                long t = ts.getLong(tsFirst + j);
-                if (timestamps[i] > t) {
-                    mergedTimestamps[k] = timestamps[i];
-                    mergedValues[k++] = values[i++];
-                } else {
-                    mergedTimestamps[k] = t;
-                    mergedValues[k++] = vs.getDouble(vsFirst + j++);
-                }
-            }
-            while (i < timestamps.length) {
-                mergedTimestamps[k] = timestamps[i];
-                mergedValues[k++] = values[i++];
-            }
-            while (j < count) {
-                mergedTimestamps[k] = ts.getLong(tsFirst + j);
-                mergedValues[k++] = vs.getDouble(vsFirst + j++);
+            assert count % 2 == 0 : "expected even number of values for intervals, got " + count + " in " + ts;
+            int currentSize = intervals.length;
+            intervals = ArrayUtil.growExact(intervals, currentSize + (count / 2));
+            for (int i = 0; i < count; i += 2) {
+                Interval interval = new Interval(
+                    ts.getLong(tsFirst + i),
+                    vs.getDouble(vsFirst + i),
+                    ts.getLong(tsFirst + i + 1),
+                    vs.getDouble(vsFirst + i + 1)
+                );
+                intervals[currentSize++] = interval;
             }
-            this.timestamps = mergedTimestamps;
-            this.values = mergedValues;
         }
     }
 
     private static double computeRateWithoutExtrapolate(ReducedState state) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        final double firstValue = state.values[len - 1];
-        final double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         return (lastValue - firstValue) * 1000.0 / (lastTS - firstTS);
     }
 
@@ -667,18 +650,11 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
      * samples (which is our guess for where the series actually starts or ends).
      */
     private static double extrapolateRate(ReducedState state, long rangeStart, long rangeEnd) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        double firstValue = state.values[len - 1];
-        double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         final double sampleTS = lastTS - firstTS;
         final double averageSampleInterval = sampleTS / state.samples;
         final double slope = (lastValue - firstValue) / sampleTS;
@@ -698,9 +674,4 @@ public final class RateDoubleGroupingAggregatorFunction implements GroupingAggre
         }
         return (lastValue - firstValue) * 1000.0 / (rangeEnd - rangeStart);
     }
-
-    // TODO: copied from old rate - simplify this or explain why we need it?
-    static double dv(double v0, double v1) {
-        return v0 > v1 ? v1 : v1 - v0;
-    }
 }

+ 78 - 107
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateIntGroupingAggregatorFunction.java

@@ -65,7 +65,7 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
     static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
         new IntermediateStateDesc("timestamps", ElementType.LONG),
         new IntermediateStateDesc("values", ElementType.INT),
-        new IntermediateStateDesc("sampleCounts", ElementType.INT),
+        new IntermediateStateDesc("sampleCounts", ElementType.LONG),
         new IntermediateStateDesc("resets", ElementType.DOUBLE)
     );
 
@@ -272,11 +272,11 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -287,7 +287,7 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
                 state = new ReducedState();
                 reducedStates.set(groupId, state);
             }
-            state.appendValuesFromBlocks(timestamps, values, valuePosition);
+            state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
             state.samples += sampleCount;
             state.resets += resets.getDouble(valuePosition);
         }
@@ -301,11 +301,11 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -322,46 +322,44 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
                     state = new ReducedState();
                     reducedStates.set(groupId, state);
                 }
-                state.appendValuesFromBlocks(timestamps, values, valuePosition);
+                state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
                 state.samples += sampleCount;
-                state.resets += resets.getDouble(groupPosition);
+                state.resets += resets.getDouble(valuePosition);
             }
         }
     }
 
     @Override
-    public final void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (
             var timestamps = blockFactory.newLongBlockBuilder(positionCount * 2);
             var values = blockFactory.newIntBlockBuilder(positionCount * 2);
-            var sampleCounts = blockFactory.newIntVectorFixedBuilder(positionCount);
+            var sampleCounts = blockFactory.newLongVectorFixedBuilder(positionCount);
             var resets = blockFactory.newDoubleVectorFixedBuilder(positionCount)
         ) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state != null && state.timestamps.length > 0) {
-                    if (state.samples > 1) {
-                        timestamps.beginPositionEntry();
-                        values.beginPositionEntry();
-                        for (int s = 0; s < state.timestamps.length; s++) {
-                            timestamps.appendLong(state.timestamps[s]);
-                            values.appendInt(state.values[s]);
-                        }
-                        timestamps.endPositionEntry();
-                        values.endPositionEntry();
-                    } else {
-                        timestamps.appendLong(state.timestamps[0]);
-                        values.appendInt(state.values[0]);
+                // Do not combine intervals across shards because intervals from different indices may overlap.
+                if (state != null && state.samples > 0) {
+                    timestamps.beginPositionEntry();
+                    values.beginPositionEntry();
+                    for (Interval interval : state.intervals) {
+                        timestamps.appendLong(interval.t1);
+                        timestamps.appendLong(interval.t2);
+                        values.appendInt(interval.v1);
+                        values.appendInt(interval.v2);
                     }
-                    sampleCounts.appendInt(state.samples);
+                    timestamps.endPositionEntry();
+                    values.endPositionEntry();
+                    sampleCounts.appendLong(state.samples);
                     resets.appendDouble(state.resets);
                 } else {
                     timestamps.appendLong(0);
                     values.appendInt(0);
-                    sampleCounts.appendInt(0);
+                    sampleCounts.appendLong(0);
                     resets.appendDouble(0);
                 }
             }
@@ -449,7 +447,9 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
             }
             if (pendingCount == 1) {
                 state.samples++;
-                state.appendOneValue(timestamps.get(0), values.get(0));
+                long t = timestamps.get(0);
+                int v = values.get(0);
+                state.appendInterval(new Interval(t, v, t, v));
                 return;
             }
             PriorityQueue<Slice> pq = mergeQueue();
@@ -468,7 +468,6 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
                 }
             }
             var prevValue = lastValue;
-            double reset = 0;
             int position = -1;
             while (pq.size() > 0) {
                 Slice top = pq.top();
@@ -479,12 +478,13 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
                     pq.updateTop();
                 }
                 var val = values.get(position);
-                reset += dv(val, prevValue) + dv(prevValue, lastValue) - dv(val, lastValue);
+                if (val > prevValue) {
+                    state.resets += val;
+                }
                 prevValue = val;
             }
             state.samples += pendingCount;
-            state.resets += reset;
-            state.appendTwoValues(lastTimestamp, lastValue, timestamps.get(position), prevValue);
+            state.appendInterval(new Interval(lastTimestamp, lastValue, timestamps.get(position), prevValue));
         }
 
         private PriorityQueue<Slice> mergeQueue() {
@@ -537,17 +537,27 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
     }
 
     @Override
-    public final void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
+    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (var rates = blockFactory.newDoubleBlockBuilder(positionCount)) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state == null || state.timestamps.length < 2) {
+                if (state == null || state.samples < 2) {
                     rates.appendNull();
                     continue;
                 }
+                // combine intervals for the final evaluation
+                Interval[] intervals = state.intervals;
+                ArrayUtil.timSort(intervals);
+                for (int i = 1; i < intervals.length; i++) {
+                    Interval next = intervals[i - 1]; // reversed
+                    Interval prev = intervals[i];
+                    if (prev.v2 > next.v2) {
+                        state.resets += prev.v2;
+                    }
+                }
                 final double rate;
                 if (evalContext instanceof TimeSeriesGroupingAggregatorEvaluationContext tsContext) {
                     rate = extrapolateRate(state, tsContext.rangeStartInMillis(group), tsContext.rangeEndInMillis(group));
@@ -583,77 +593,50 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
         return sb.toString();
     }
 
+    record Interval(long t1, int v1, long t2, int v2) implements Comparable<Interval> {
+        @Override
+        public int compareTo(Interval other) {
+            return Long.compare(other.t1, t1); // want most recent first
+        }
+    }
+
     static final class ReducedState {
-        private static final long[] EMPTY_LONGS = new long[0];
-        private static final int[] EMPTY_VALUES = new int[0];
-        int samples;
+        private static final Interval[] EMPTY_INTERVALS = new Interval[0];
+        long samples;
         double resets;
-        long[] timestamps = EMPTY_LONGS;
-        int[] values = EMPTY_VALUES;
+        Interval[] intervals = EMPTY_INTERVALS;
 
-        void appendOneValue(long t, int v) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 1);
-            this.values = ArrayUtil.growExact(values, currentSize + 1);
-            this.timestamps[currentSize] = t;
-            this.values[currentSize] = v;
+        void appendInterval(Interval interval) {
+            int currentSize = intervals.length;
+            this.intervals = ArrayUtil.growExact(intervals, currentSize + 1);
+            this.intervals[currentSize] = interval;
         }
 
-        void appendTwoValues(long t1, int v1, long t2, int v2) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 2);
-            this.values = ArrayUtil.growExact(values, currentSize + 2);
-            this.timestamps[currentSize] = t1;
-            this.values[currentSize] = v1;
-            currentSize++;
-            this.timestamps[currentSize] = t2;
-            this.values[currentSize] = v2;
-        }
-
-        void appendValuesFromBlocks(LongBlock ts, IntBlock vs, int position) {
+        void appendIntervalsFromBlocks(LongBlock ts, IntBlock vs, int position) {
             int tsFirst = ts.getFirstValueIndex(position);
             int vsFirst = vs.getFirstValueIndex(position);
             int count = ts.getValueCount(position);
-            int total = timestamps.length + count;
-            long[] mergedTimestamps = new long[total];
-            int[] mergedValues = new int[total];
-            int i = 0, j = 0, k = 0;
-            while (i < timestamps.length && j < count) {
-                long t = ts.getLong(tsFirst + j);
-                if (timestamps[i] > t) {
-                    mergedTimestamps[k] = timestamps[i];
-                    mergedValues[k++] = values[i++];
-                } else {
-                    mergedTimestamps[k] = t;
-                    mergedValues[k++] = vs.getInt(vsFirst + j++);
-                }
-            }
-            while (i < timestamps.length) {
-                mergedTimestamps[k] = timestamps[i];
-                mergedValues[k++] = values[i++];
-            }
-            while (j < count) {
-                mergedTimestamps[k] = ts.getLong(tsFirst + j);
-                mergedValues[k++] = vs.getInt(vsFirst + j++);
+            assert count % 2 == 0 : "expected even number of values for intervals, got " + count + " in " + ts;
+            int currentSize = intervals.length;
+            intervals = ArrayUtil.growExact(intervals, currentSize + (count / 2));
+            for (int i = 0; i < count; i += 2) {
+                Interval interval = new Interval(
+                    ts.getLong(tsFirst + i),
+                    vs.getInt(vsFirst + i),
+                    ts.getLong(tsFirst + i + 1),
+                    vs.getInt(vsFirst + i + 1)
+                );
+                intervals[currentSize++] = interval;
             }
-            this.timestamps = mergedTimestamps;
-            this.values = mergedValues;
         }
     }
 
     private static double computeRateWithoutExtrapolate(ReducedState state) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        final double firstValue = state.values[len - 1];
-        final double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         return (lastValue - firstValue) * 1000.0 / (lastTS - firstTS);
     }
 
@@ -667,18 +650,11 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
      * samples (which is our guess for where the series actually starts or ends).
      */
     private static double extrapolateRate(ReducedState state, long rangeStart, long rangeEnd) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        double firstValue = state.values[len - 1];
-        double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         final double sampleTS = lastTS - firstTS;
         final double averageSampleInterval = sampleTS / state.samples;
         final double slope = (lastValue - firstValue) / sampleTS;
@@ -698,9 +674,4 @@ public final class RateIntGroupingAggregatorFunction implements GroupingAggregat
         }
         return (lastValue - firstValue) * 1000.0 / (rangeEnd - rangeStart);
     }
-
-    // TODO: copied from old rate - simplify this or explain why we need it?
-    static double dv(double v0, double v1) {
-        return v0 > v1 ? v1 : v1 - v0;
-    }
 }

+ 78 - 107
x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/RateLongGroupingAggregatorFunction.java

@@ -65,7 +65,7 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
     static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
         new IntermediateStateDesc("timestamps", ElementType.LONG),
         new IntermediateStateDesc("values", ElementType.LONG),
-        new IntermediateStateDesc("sampleCounts", ElementType.INT),
+        new IntermediateStateDesc("sampleCounts", ElementType.LONG),
         new IntermediateStateDesc("resets", ElementType.DOUBLE)
     );
 
@@ -272,11 +272,11 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -287,7 +287,7 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
                 state = new ReducedState();
                 reducedStates.set(groupId, state);
             }
-            state.appendValuesFromBlocks(timestamps, values, valuePosition);
+            state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
             state.samples += sampleCount;
             state.resets += resets.getDouble(valuePosition);
         }
@@ -301,11 +301,11 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -322,46 +322,44 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
                     state = new ReducedState();
                     reducedStates.set(groupId, state);
                 }
-                state.appendValuesFromBlocks(timestamps, values, valuePosition);
+                state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
                 state.samples += sampleCount;
-                state.resets += resets.getDouble(groupPosition);
+                state.resets += resets.getDouble(valuePosition);
             }
         }
     }
 
     @Override
-    public final void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (
             var timestamps = blockFactory.newLongBlockBuilder(positionCount * 2);
             var values = blockFactory.newLongBlockBuilder(positionCount * 2);
-            var sampleCounts = blockFactory.newIntVectorFixedBuilder(positionCount);
+            var sampleCounts = blockFactory.newLongVectorFixedBuilder(positionCount);
             var resets = blockFactory.newDoubleVectorFixedBuilder(positionCount)
         ) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state != null && state.timestamps.length > 0) {
-                    if (state.samples > 1) {
-                        timestamps.beginPositionEntry();
-                        values.beginPositionEntry();
-                        for (int s = 0; s < state.timestamps.length; s++) {
-                            timestamps.appendLong(state.timestamps[s]);
-                            values.appendLong(state.values[s]);
-                        }
-                        timestamps.endPositionEntry();
-                        values.endPositionEntry();
-                    } else {
-                        timestamps.appendLong(state.timestamps[0]);
-                        values.appendLong(state.values[0]);
+                // Do not combine intervals across shards because intervals from different indices may overlap.
+                if (state != null && state.samples > 0) {
+                    timestamps.beginPositionEntry();
+                    values.beginPositionEntry();
+                    for (Interval interval : state.intervals) {
+                        timestamps.appendLong(interval.t1);
+                        timestamps.appendLong(interval.t2);
+                        values.appendLong(interval.v1);
+                        values.appendLong(interval.v2);
                     }
-                    sampleCounts.appendInt(state.samples);
+                    timestamps.endPositionEntry();
+                    values.endPositionEntry();
+                    sampleCounts.appendLong(state.samples);
                     resets.appendDouble(state.resets);
                 } else {
                     timestamps.appendLong(0);
                     values.appendLong(0);
-                    sampleCounts.appendInt(0);
+                    sampleCounts.appendLong(0);
                     resets.appendDouble(0);
                 }
             }
@@ -449,7 +447,9 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
             }
             if (pendingCount == 1) {
                 state.samples++;
-                state.appendOneValue(timestamps.get(0), values.get(0));
+                long t = timestamps.get(0);
+                long v = values.get(0);
+                state.appendInterval(new Interval(t, v, t, v));
                 return;
             }
             PriorityQueue<Slice> pq = mergeQueue();
@@ -468,7 +468,6 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
                 }
             }
             var prevValue = lastValue;
-            double reset = 0;
             int position = -1;
             while (pq.size() > 0) {
                 Slice top = pq.top();
@@ -479,12 +478,13 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
                     pq.updateTop();
                 }
                 var val = values.get(position);
-                reset += dv(val, prevValue) + dv(prevValue, lastValue) - dv(val, lastValue);
+                if (val > prevValue) {
+                    state.resets += val;
+                }
                 prevValue = val;
             }
             state.samples += pendingCount;
-            state.resets += reset;
-            state.appendTwoValues(lastTimestamp, lastValue, timestamps.get(position), prevValue);
+            state.appendInterval(new Interval(lastTimestamp, lastValue, timestamps.get(position), prevValue));
         }
 
         private PriorityQueue<Slice> mergeQueue() {
@@ -537,17 +537,27 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
     }
 
     @Override
-    public final void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
+    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (var rates = blockFactory.newDoubleBlockBuilder(positionCount)) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state == null || state.timestamps.length < 2) {
+                if (state == null || state.samples < 2) {
                     rates.appendNull();
                     continue;
                 }
+                // combine intervals for the final evaluation
+                Interval[] intervals = state.intervals;
+                ArrayUtil.timSort(intervals);
+                for (int i = 1; i < intervals.length; i++) {
+                    Interval next = intervals[i - 1]; // intervals are sorted most-recent-first, so i-1 is the later interval
+                    Interval prev = intervals[i];     // and i is the earlier interval
+                    if (prev.v2 > next.v2) {
+                        state.resets += prev.v2;
+                    }
+                }
                 final double rate;
                 if (evalContext instanceof TimeSeriesGroupingAggregatorEvaluationContext tsContext) {
                     rate = extrapolateRate(state, tsContext.rangeStartInMillis(group), tsContext.rangeEndInMillis(group));
@@ -583,77 +593,50 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
         return sb.toString();
     }
 
+    record Interval(long t1, long v1, long t2, long v2) implements Comparable<Interval> {
+        @Override
+        public int compareTo(Interval other) {
+            return Long.compare(other.t1, t1); // want most recent first
+        }
+    }
+
     static final class ReducedState {
-        private static final long[] EMPTY_LONGS = new long[0];
-        private static final long[] EMPTY_VALUES = new long[0];
-        int samples;
+        private static final Interval[] EMPTY_INTERVALS = new Interval[0];
+        long samples;
         double resets;
-        long[] timestamps = EMPTY_LONGS;
-        long[] values = EMPTY_VALUES;
+        Interval[] intervals = EMPTY_INTERVALS;
 
-        void appendOneValue(long t, long v) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 1);
-            this.values = ArrayUtil.growExact(values, currentSize + 1);
-            this.timestamps[currentSize] = t;
-            this.values[currentSize] = v;
+        void appendInterval(Interval interval) {
+            int currentSize = intervals.length;
+            this.intervals = ArrayUtil.growExact(intervals, currentSize + 1);
+            this.intervals[currentSize] = interval;
         }
 
-        void appendTwoValues(long t1, long v1, long t2, long v2) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 2);
-            this.values = ArrayUtil.growExact(values, currentSize + 2);
-            this.timestamps[currentSize] = t1;
-            this.values[currentSize] = v1;
-            currentSize++;
-            this.timestamps[currentSize] = t2;
-            this.values[currentSize] = v2;
-        }
-
-        void appendValuesFromBlocks(LongBlock ts, LongBlock vs, int position) {
+        void appendIntervalsFromBlocks(LongBlock ts, LongBlock vs, int position) {
             int tsFirst = ts.getFirstValueIndex(position);
             int vsFirst = vs.getFirstValueIndex(position);
             int count = ts.getValueCount(position);
-            int total = timestamps.length + count;
-            long[] mergedTimestamps = new long[total];
-            long[] mergedValues = new long[total];
-            int i = 0, j = 0, k = 0;
-            while (i < timestamps.length && j < count) {
-                long t = ts.getLong(tsFirst + j);
-                if (timestamps[i] > t) {
-                    mergedTimestamps[k] = timestamps[i];
-                    mergedValues[k++] = values[i++];
-                } else {
-                    mergedTimestamps[k] = t;
-                    mergedValues[k++] = vs.getLong(vsFirst + j++);
-                }
-            }
-            while (i < timestamps.length) {
-                mergedTimestamps[k] = timestamps[i];
-                mergedValues[k++] = values[i++];
-            }
-            while (j < count) {
-                mergedTimestamps[k] = ts.getLong(tsFirst + j);
-                mergedValues[k++] = vs.getLong(vsFirst + j++);
+            assert count % 2 == 0 : "expected even number of values for intervals, got " + count + " in " + ts;
+            int currentSize = intervals.length;
+            intervals = ArrayUtil.growExact(intervals, currentSize + (count / 2));
+            for (int i = 0; i < count; i += 2) {
+                Interval interval = new Interval(
+                    ts.getLong(tsFirst + i),
+                    vs.getLong(vsFirst + i),
+                    ts.getLong(tsFirst + i + 1),
+                    vs.getLong(vsFirst + i + 1)
+                );
+                intervals[currentSize++] = interval;
             }
-            this.timestamps = mergedTimestamps;
-            this.values = mergedValues;
         }
     }
 
     private static double computeRateWithoutExtrapolate(ReducedState state) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        final double firstValue = state.values[len - 1];
-        final double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         return (lastValue - firstValue) * 1000.0 / (lastTS - firstTS);
     }
 
@@ -667,18 +650,11 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
      * samples (which is our guess for where the series actually starts or ends).
      */
     private static double extrapolateRate(ReducedState state, long rangeStart, long rangeEnd) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        double firstValue = state.values[len - 1];
-        double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         final double sampleTS = lastTS - firstTS;
         final double averageSampleInterval = sampleTS / state.samples;
         final double slope = (lastValue - firstValue) / sampleTS;
@@ -698,9 +674,4 @@ public final class RateLongGroupingAggregatorFunction implements GroupingAggrega
         }
         return (lastValue - firstValue) * 1000.0 / (rangeEnd - rangeStart);
     }
-
-    // TODO: copied from old rate - simplify this or explain why we need it?
-    static double dv(double v0, double v1) {
-        return v0 > v1 ? v1 : v1 - v0;
-    }
 }

+ 78 - 107
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-RateGroupingAggregatorFunction.java.st

@@ -65,7 +65,7 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
     static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
         new IntermediateStateDesc("timestamps", ElementType.LONG),
         new IntermediateStateDesc("values", ElementType.$TYPE$),
-        new IntermediateStateDesc("sampleCounts", ElementType.INT),
+        new IntermediateStateDesc("sampleCounts", ElementType.LONG),
         new IntermediateStateDesc("resets", ElementType.DOUBLE)
     );
 
@@ -272,11 +272,11 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -287,7 +287,7 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
                 state = new ReducedState();
                 reducedStates.set(groupId, state);
             }
-            state.appendValuesFromBlocks(timestamps, values, valuePosition);
+            state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
             state.samples += sampleCount;
             state.resets += resets.getDouble(valuePosition);
         }
@@ -301,11 +301,11 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         if (values.areAllValuesNull()) {
             return;
         }
-        IntVector sampleCounts = ((IntBlock) page.getBlock(channels.get(2))).asVector();
+        LongVector sampleCounts = ((LongBlock) page.getBlock(channels.get(2))).asVector();
         DoubleVector resets = ((DoubleBlock) page.getBlock(channels.get(3))).asVector();
         for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
             int valuePosition = positionOffset + groupPosition;
-            int sampleCount = sampleCounts.getInt(valuePosition);
+            long sampleCount = sampleCounts.getLong(valuePosition);
             if (sampleCount == 0) {
                 continue;
             }
@@ -322,46 +322,44 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
                     state = new ReducedState();
                     reducedStates.set(groupId, state);
                 }
-                state.appendValuesFromBlocks(timestamps, values, valuePosition);
+                state.appendIntervalsFromBlocks(timestamps, values, valuePosition);
                 state.samples += sampleCount;
-                state.resets += resets.getDouble(groupPosition);
+                state.resets += resets.getDouble(valuePosition);
             }
         }
     }
 
     @Override
-    public final void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
+    public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (
             var timestamps = blockFactory.newLongBlockBuilder(positionCount * 2);
             var values = blockFactory.new$Type$BlockBuilder(positionCount * 2);
-            var sampleCounts = blockFactory.newIntVectorFixedBuilder(positionCount);
+            var sampleCounts = blockFactory.newLongVectorFixedBuilder(positionCount);
             var resets = blockFactory.newDoubleVectorFixedBuilder(positionCount)
         ) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state != null && state.timestamps.length > 0) {
-                    if (state.samples > 1) {
-                        timestamps.beginPositionEntry();
-                        values.beginPositionEntry();
-                        for (int s = 0; s < state.timestamps.length; s++) {
-                            timestamps.appendLong(state.timestamps[s]);
-                            values.append$Type$(state.values[s]);
-                        }
-                        timestamps.endPositionEntry();
-                        values.endPositionEntry();
-                    } else {
-                        timestamps.appendLong(state.timestamps[0]);
-                        values.append$Type$(state.values[0]);
+                // Do not combine intervals across shards because intervals from different indices may overlap.
+                if (state != null && state.samples > 0) {
+                    timestamps.beginPositionEntry();
+                    values.beginPositionEntry();
+                    for (Interval interval : state.intervals) {
+                        timestamps.appendLong(interval.t1);
+                        timestamps.appendLong(interval.t2);
+                        values.append$Type$(interval.v1);
+                        values.append$Type$(interval.v2);
                     }
-                    sampleCounts.appendInt(state.samples);
+                    timestamps.endPositionEntry();
+                    values.endPositionEntry();
+                    sampleCounts.appendLong(state.samples);
                     resets.appendDouble(state.resets);
                 } else {
                     timestamps.appendLong(0);
                     values.append$Type$(0);
-                    sampleCounts.appendInt(0);
+                    sampleCounts.appendLong(0);
                     resets.appendDouble(0);
                 }
             }
@@ -449,7 +447,9 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
             }
             if (pendingCount == 1) {
                 state.samples++;
-                state.appendOneValue(timestamps.get(0), values.get(0));
+                long t = timestamps.get(0);
+                $type$ v = values.get(0);
+                state.appendInterval(new Interval(t, v, t, v));
                 return;
             }
             PriorityQueue<Slice> pq = mergeQueue();
@@ -468,7 +468,6 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
                 }
             }
             var prevValue = lastValue;
-            double reset = 0;
             int position = -1;
             while (pq.size() > 0) {
                 Slice top = pq.top();
@@ -479,12 +478,13 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
                     pq.updateTop();
                 }
                 var val = values.get(position);
-                reset += dv(val, prevValue) + dv(prevValue, lastValue) - dv(val, lastValue);
+                if (val > prevValue) {
+                    state.resets += val;
+                }
                 prevValue = val;
             }
             state.samples += pendingCount;
-            state.resets += reset;
-            state.appendTwoValues(lastTimestamp, lastValue, timestamps.get(position), prevValue);
+            state.appendInterval(new Interval(lastTimestamp, lastValue, timestamps.get(position), prevValue));
         }
 
         private PriorityQueue<Slice> mergeQueue() {
@@ -537,17 +537,27 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
     }
 
     @Override
-    public final void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
+    public void evaluateFinal(Block[] blocks, int offset, IntVector selected, GroupingAggregatorEvaluationContext evalContext) {
         BlockFactory blockFactory = driverContext.blockFactory();
         int positionCount = selected.getPositionCount();
         try (var rates = blockFactory.newDoubleBlockBuilder(positionCount)) {
             for (int p = 0; p < positionCount; p++) {
                 int group = selected.getInt(p);
                 var state = flushAndCombineState(group);
-                if (state == null || state.timestamps.length < 2) {
+                if (state == null || state.samples < 2) {
                     rates.appendNull();
                     continue;
                 }
+                // combine intervals for the final evaluation
+                Interval[] intervals = state.intervals;
+                ArrayUtil.timSort(intervals);
+                for (int i = 1; i < intervals.length; i++) {
+                    Interval next = intervals[i - 1]; // intervals are sorted most-recent-first, so i-1 is the later interval
+                    Interval prev = intervals[i];     // and i is the earlier interval
+                    if (prev.v2 > next.v2) {
+                        state.resets += prev.v2;
+                    }
+                }
                 final double rate;
                 if (evalContext instanceof TimeSeriesGroupingAggregatorEvaluationContext tsContext) {
                     rate = extrapolateRate(state, tsContext.rangeStartInMillis(group), tsContext.rangeEndInMillis(group));
@@ -583,77 +593,50 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         return sb.toString();
     }
 
+    record Interval(long t1, $type$ v1, long t2, $type$ v2) implements Comparable<Interval> {
+        @Override
+        public int compareTo(Interval other) {
+            return Long.compare(other.t1, t1); // want most recent first
+        }
+    }
+
     static final class ReducedState {
-        private static final long[] EMPTY_LONGS = new long[0];
-        private static final $type$[] EMPTY_VALUES = new $type$[0];
-        int samples;
+        private static final Interval[] EMPTY_INTERVALS = new Interval[0];
+        long samples;
         double resets;
-        long[] timestamps = EMPTY_LONGS;
-        $type$[] values = EMPTY_VALUES;
+        Interval[] intervals = EMPTY_INTERVALS;
 
-        void appendOneValue(long t, $type$ v) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 1);
-            this.values = ArrayUtil.growExact(values, currentSize + 1);
-            this.timestamps[currentSize] = t;
-            this.values[currentSize] = v;
+        void appendInterval(Interval interval) {
+            int currentSize = intervals.length;
+            this.intervals = ArrayUtil.growExact(intervals, currentSize + 1);
+            this.intervals[currentSize] = interval;
         }
 
-        void appendTwoValues(long t1, $type$ v1, long t2, $type$ v2) {
-            int currentSize = timestamps.length;
-            this.timestamps = ArrayUtil.growExact(timestamps, currentSize + 2);
-            this.values = ArrayUtil.growExact(values, currentSize + 2);
-            this.timestamps[currentSize] = t1;
-            this.values[currentSize] = v1;
-            currentSize++;
-            this.timestamps[currentSize] = t2;
-            this.values[currentSize] = v2;
-        }
-
-        void appendValuesFromBlocks(LongBlock ts, $Type$Block vs, int position) {
+        void appendIntervalsFromBlocks(LongBlock ts, $Type$Block vs, int position) {
             int tsFirst = ts.getFirstValueIndex(position);
             int vsFirst = vs.getFirstValueIndex(position);
             int count = ts.getValueCount(position);
-            int total = timestamps.length + count;
-            long[] mergedTimestamps = new long[total];
-            $type$[] mergedValues = new $type$[total];
-            int i = 0, j = 0, k = 0;
-            while (i < timestamps.length && j < count) {
-                long t = ts.getLong(tsFirst + j);
-                if (timestamps[i] > t) {
-                    mergedTimestamps[k] = timestamps[i];
-                    mergedValues[k++] = values[i++];
-                } else {
-                    mergedTimestamps[k] = t;
-                    mergedValues[k++] = vs.get$Type$(vsFirst + j++);
-                }
-            }
-            while (i < timestamps.length) {
-                mergedTimestamps[k] = timestamps[i];
-                mergedValues[k++] = values[i++];
-            }
-            while (j < count) {
-                mergedTimestamps[k] = ts.getLong(tsFirst + j);
-                mergedValues[k++] = vs.get$Type$(vsFirst + j++);
+            assert count % 2 == 0 : "expected even number of values for intervals, got " + count + " in " + ts;
+            int currentSize = intervals.length;
+            intervals = ArrayUtil.growExact(intervals, currentSize + (count / 2));
+            for (int i = 0; i < count; i += 2) {
+                Interval interval = new Interval(
+                    ts.getLong(tsFirst + i),
+                    vs.get$Type$(vsFirst + i),
+                    ts.getLong(tsFirst + i + 1),
+                    vs.get$Type$(vsFirst + i + 1)
+                );
+                intervals[currentSize++] = interval;
             }
-            this.timestamps = mergedTimestamps;
-            this.values = mergedValues;
         }
     }
 
     private static double computeRateWithoutExtrapolate(ReducedState state) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        final double firstValue = state.values[len - 1];
-        final double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         return (lastValue - firstValue) * 1000.0 / (lastTS - firstTS);
     }
 
@@ -667,18 +650,11 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
      * samples (which is our guess for where the series actually starts or ends).
      */
     private static double extrapolateRate(ReducedState state, long rangeStart, long rangeEnd) {
-        final int len = state.timestamps.length;
-        assert len >= 2 : "rate requires at least two samples; got " + len;
-        final long firstTS = state.timestamps[state.timestamps.length - 1];
-        final long lastTS = state.timestamps[0];
-        double reset = state.resets;
-        for (int i = 1; i < len; i++) {
-            if (state.values[i - 1] < state.values[i]) {
-                reset += state.values[i];
-            }
-        }
-        double firstValue = state.values[len - 1];
-        double lastValue = state.values[0] + reset;
+        assert state.samples >= 2 : "rate requires at least two samples; got " + state.samples;
+        final long firstTS = state.intervals[state.intervals.length - 1].t2;
+        final long lastTS = state.intervals[0].t1;
+        double firstValue = state.intervals[state.intervals.length - 1].v2;
+        double lastValue = state.intervals[0].v1 + state.resets;
         final double sampleTS = lastTS - firstTS;
         final double averageSampleInterval = sampleTS / state.samples;
         final double slope = (lastValue - firstValue) / sampleTS;
@@ -698,9 +674,4 @@ public final class Rate$Type$GroupingAggregatorFunction implements GroupingAggre
         }
         return (lastValue - firstValue) * 1000.0 / (rangeEnd - rangeStart);
     }
-
-    // TODO: copied from old rate - simplify this or explain why we need it?
-    static double dv(double v0, double v1) {
-        return v0 > v1 ? v1 : v1 - v0;
-    }
 }