
ESQL: Reserve memory TopN (#134235)

Tracks more of the memory that's involved in TopN.

## Lucene TopN
Lucene doesn't track memory usage for TopN and can use a fair bit of it.
Try this query:
```
FROM big_table
| SORT a, b, c, d, e
| LIMIT 1000000
| STATS MAX(a)
```

We attempt to return all one million documents from Lucene. If we did this
with the compute engine we'd track all of the memory usage. With Lucene
we have to reserve it up front instead.

In the case of the query above the sort keys weigh 8 bytes each, 40
bytes total. Add another 72 bytes for Lucene's `FieldDoc`, and at least
another 40 for copying the values into the `FieldDoc`. That totals
something like 152 bytes apiece, or about 145MiB for a million
documents. Worth tracking!
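
For concreteness, here's that arithmetic as a back-of-the-envelope sketch. The numbers are the estimates from the paragraph above, not values measured from Lucene:

```java
// Rough per-row cost of Lucene's TopN for the query above.
// All numbers are estimates from the prose, not exact Lucene measurements.
long sortKeys = 5 * 8L;                        // five long sort keys, 8 bytes each
long fieldDoc = 72L;                           // shallow size of Lucene's FieldDoc
long copies = 5 * 8L;                          // values copied into the FieldDoc
long perRow = sortKeys + fieldDoc + copies;    // ~152 bytes per document
long total = perRow * 1_000_000L;              // ~152MB, or about 145MiB
```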

## ESQL Engine TopN

ESQL *does* track memory for TopN, but it doesn't track the memory used by the min heap itself. The heap is just a big array of pointers, but it can get very big! See the sketch below.
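
The fix reserves that array in the circuit breaker before the queue is built. Here's a minimal sketch of the size estimate, mirroring the `Queue.sizeOf` helper added to `TopNOperator.java` in the diff below (the wrapper class name is illustrative):

```java
import org.apache.lucene.util.RamUsageEstimator;

class QueueSizeSketch {
    // One object reference per heap slot, plus one because Lucene's
    // PriorityQueue leaves index 0 of its backing array unused.
    static long heapArraySize(int topCount) {
        return RamUsageEstimator.alignObjectSize(
            RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
                + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1)
        );
    }
}
```

Reserve this with `breaker.addEstimateBytesAndMaybeBreak(...)` when the queue is built and give it back with `breaker.addWithoutBreaking(...)` on close, which is exactly what the `Queue.build`/`Queue.close` pair below does.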
Nik Everett, 1 month ago · commit e9c145b71f

+ 5 - 0
docs/changelog/134235.yaml

@@ -0,0 +1,5 @@
+pr: 134235
+summary: Reserve memory for Lucene's TopN
+area: ES|QL
+type: bug
+issues: []

+ 66 - 19
test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java

@@ -91,7 +91,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * This used to fail, but we've since compacted top n so it actually succeeds now.
      */
     public void testSortByManyLongsSuccess() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         Map<String, Object> response = sortByManyLongs(500);
         ListMatcher columns = matchesList().item(matchesMap().entry("name", "a").entry("type", "long"))
             .item(matchesMap().entry("name", "b").entry("type", "long"));
@@ -108,7 +108,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * This used to crash the node with an out of memory, but now it just trips a circuit breaker.
      */
     public void testSortByManyLongsTooMuchMemory() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         // 5000 is plenty to break on most nodes
         assertCircuitBreaks(attempt -> sortByManyLongs(attempt * 5000));
     }
@@ -117,7 +117,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * This should record an async response with a {@link CircuitBreakingException}.
      */
     public void testSortByManyLongsTooMuchMemoryAsync() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         Request request = new Request("POST", "/_query/async");
         request.addParameter("error_trace", "");
         request.setJsonEntity(makeSortByManyLongs(5000).toString().replace("\n", "\\n"));
@@ -194,6 +194,29 @@ public class HeapAttackIT extends ESRestTestCase {
         );
     }
 
+    public void testSortByManyLongsGiantTopN() throws IOException {
+        initManyLongs(10);
+        assertMap(
+            sortBySomeLongsLimit(100000),
+            matchesMap().entry("took", greaterThan(0))
+                .entry("is_partial", false)
+                .entry("columns", List.of(Map.of("name", "MAX(a)", "type", "long")))
+                .entry("values", List.of(List.of(9)))
+                .entry("documents_found", greaterThan(0))
+                .entry("values_loaded", greaterThan(0))
+        );
+    }
+
+    public void testSortByManyLongsGiantTopNTooMuchMemory() throws IOException {
+        initManyLongs(20);
+        assertCircuitBreaks(attempt -> sortBySomeLongsLimit(attempt * 500000));
+    }
+
+    public void testStupidTopN() throws IOException {
+        initManyLongs(1); // Doesn't actually matter how much data there is.
+        assertCircuitBreaks(attempt -> sortBySomeLongsLimit(2147483630));
+    }
+
     private static final int MAX_ATTEMPTS = 5;
 
     interface TryCircuitBreaking {
@@ -252,11 +275,25 @@ public class HeapAttackIT extends ESRestTestCase {
         return query;
     }
 
+    private Map<String, Object> sortBySomeLongsLimit(int count) throws IOException {
+        logger.info("sorting by 5 longs, keeping {}", count);
+        return responseAsMap(query(makeSortBySomeLongsLimit(count), null));
+    }
+
+    private String makeSortBySomeLongsLimit(int count) {
+        StringBuilder query = new StringBuilder("{\"query\": \"FROM manylongs\n");
+        query.append("| SORT a, b, c, d, e\n");
+        query.append("| LIMIT ").append(count).append("\n");
+        query.append("| STATS MAX(a)\n");
+        query.append("\"}");
+        return query.toString();
+    }
+
     /**
      * This groups on about 200 columns which is a lot but has never caused us trouble.
      */
     public void testGroupOnSomeLongs() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         Response resp = groupOnManyLongs(200);
         Map<String, Object> map = responseAsMap(resp);
         ListMatcher columns = matchesList().item(matchesMap().entry("name", "MAX(a)").entry("type", "long"));
@@ -268,7 +305,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * This groups on 5000 columns which used to throw a {@link StackOverflowError}.
      */
     public void testGroupOnManyLongs() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         Response resp = groupOnManyLongs(5000);
         Map<String, Object> map = responseAsMap(resp);
         ListMatcher columns = matchesList().item(matchesMap().entry("name", "MAX(a)").entry("type", "long"));
@@ -336,7 +373,7 @@ public class HeapAttackIT extends ESRestTestCase {
      */
     public void testManyConcat() throws IOException {
         int strings = 300;
-        initManyLongs();
+        initManyLongs(10);
         assertManyStrings(manyConcat("FROM manylongs", strings), strings);
     }
 
@@ -344,7 +381,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * Hits a circuit breaker by building many moderately long strings.
      */
     public void testHugeManyConcat() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         // 2000 is plenty to break on most nodes
         assertCircuitBreaks(attempt -> manyConcat("FROM manylongs", attempt * 2000));
     }
@@ -415,7 +452,7 @@ public class HeapAttackIT extends ESRestTestCase {
      */
     public void testManyRepeat() throws IOException {
         int strings = 30;
-        initManyLongs();
+        initManyLongs(10);
         assertManyStrings(manyRepeat("FROM manylongs", strings), 30);
     }
 
@@ -423,7 +460,7 @@ public class HeapAttackIT extends ESRestTestCase {
      * Hits a circuit breaker by building many moderately long strings.
      */
     public void testHugeManyRepeat() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         // 75 is plenty to break on most nodes
         assertCircuitBreaks(attempt -> manyRepeat("FROM manylongs", attempt * 75));
     }
@@ -481,7 +518,7 @@ public class HeapAttackIT extends ESRestTestCase {
     }
 
     public void testManyEval() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         Map<String, Object> response = manyEval(1);
         ListMatcher columns = matchesList();
         columns = columns.item(matchesMap().entry("name", "a").entry("type", "long"));
@@ -496,7 +533,7 @@ public class HeapAttackIT extends ESRestTestCase {
     }
 
     public void testTooManyEval() throws IOException {
-        initManyLongs();
+        initManyLongs(10);
         // 490 is plenty to fail on most nodes
         assertCircuitBreaks(attempt -> manyEval(attempt * 490));
     }
@@ -855,24 +892,34 @@ public class HeapAttackIT extends ESRestTestCase {
         }
     }
 
-    private void initManyLongs() throws IOException {
+    private void initManyLongs(int countPerLong) throws IOException {
         logger.info("loading many documents with longs");
         StringBuilder bulk = new StringBuilder();
-        for (int a = 0; a < 10; a++) {
-            for (int b = 0; b < 10; b++) {
-                for (int c = 0; c < 10; c++) {
-                    for (int d = 0; d < 10; d++) {
-                        for (int e = 0; e < 10; e++) {
+        int flush = 0;
+        for (int a = 0; a < countPerLong; a++) {
+            for (int b = 0; b < countPerLong; b++) {
+                for (int c = 0; c < countPerLong; c++) {
+                    for (int d = 0; d < countPerLong; d++) {
+                        for (int e = 0; e < countPerLong; e++) {
                             bulk.append(String.format(Locale.ROOT, """
                                 {"create":{}}
                                 {"a":%d,"b":%d,"c":%d,"d":%d,"e":%d}
                                 """, a, b, c, d, e));
+                            flush++;
+                            if (flush % 10_000 == 0) {
+                                bulk("manylongs", bulk.toString());
+                                bulk.setLength(0);
+                                logger.info(
+                                    "flushing {}/{} to manylongs",
+                                    flush,
+                                    countPerLong * countPerLong * countPerLong * countPerLong * countPerLong
+                                );
+
+                            }
                         }
                     }
                 }
             }
-            bulk("manylongs", bulk.toString());
-            bulk.setLength(0);
         }
         initIndex("manylongs", bulk.toString());
     }

+ 150 - 28
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

@@ -16,10 +16,14 @@ import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.SortedNumericSortField;
+import org.apache.lucene.search.SortedSetSortField;
 import org.apache.lucene.search.TopDocsCollector;
 import org.apache.lucene.search.TopFieldCollectorManager;
 import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.util.RamUsageEstimator;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
@@ -44,7 +48,11 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 
 /**
- * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
+ * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN).
+ * <p>
+ *     Makes {@link Page}s of the shape {@code (docBlock)} or {@code (docBlock, score)}.
+ *     Lucene loads the sort keys, but we don't read them from lucene. Yet. We should.
+ * </p>
  */
 public final class LuceneTopNSourceOperator extends LuceneOperator {
 
@@ -52,6 +60,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         private final List<? extends ShardContext> contexts;
         private final int maxPageSize;
         private final List<SortBuilder<?>> sorts;
+        private final long estimatedPerRowSortSize;
 
         public Factory(
             List<? extends ShardContext> contexts,
@@ -61,6 +70,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
             int maxPageSize,
             int limit,
             List<SortBuilder<?>> sorts,
+            long estimatedPerRowSortSize,
             boolean needsScore
         ) {
             super(
@@ -76,11 +86,22 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
             this.contexts = contexts;
             this.maxPageSize = maxPageSize;
             this.sorts = sorts;
+            this.estimatedPerRowSortSize = estimatedPerRowSortSize;
         }
 
         @Override
         public SourceOperator get(DriverContext driverContext) {
-            return new LuceneTopNSourceOperator(contexts, driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
+            return new LuceneTopNSourceOperator(
+                contexts,
+                driverContext.breaker(),
+                driverContext.blockFactory(),
+                maxPageSize,
+                sorts,
+                estimatedPerRowSortSize,
+                limit,
+                sliceQueue,
+                needsScore
+            );
         }
 
         public int maxPageSize() {
@@ -104,10 +125,16 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         }
     }
 
+    private final CircuitBreaker breaker;
+    private final List<SortBuilder<?>> sorts;
+    private final long estimatedPerRowSortSize;
+    private final int limit;
+    private final boolean needsScore;
+
     /**
-     * Collected docs. {@code null} until we're {@link #emit(boolean)}.
+     * Collected docs. {@code null} until we're ready to {@link #emit()}.
      */
-    private ScoreDoc[] scoreDocs;
+    private ScoreDoc[] topDocs;
 
     /**
      * {@link ShardRefCounted} for collected docs.
@@ -115,28 +142,30 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
     private ShardRefCounted shardRefCounted;
 
     /**
-     * The offset in {@link #scoreDocs} of the next page.
+     * The offset in {@link #topDocs} of the next page.
      */
     private int offset = 0;
 
     private PerShardCollector perShardCollector;
-    private final List<SortBuilder<?>> sorts;
-    private final int limit;
-    private final boolean needsScore;
 
     public LuceneTopNSourceOperator(
         List<? extends ShardContext> contexts,
+        CircuitBreaker breaker,
         BlockFactory blockFactory,
         int maxPageSize,
         List<SortBuilder<?>> sorts,
+        long estimatedPerRowSortSize,
         int limit,
         LuceneSliceQueue sliceQueue,
         boolean needsScore
     ) {
         super(contexts, blockFactory, maxPageSize, sliceQueue);
+        this.breaker = breaker;
         this.sorts = sorts;
+        this.estimatedPerRowSortSize = estimatedPerRowSortSize;
         this.limit = limit;
         this.needsScore = needsScore;
+        breaker.addEstimateBytesAndMaybeBreak(reserveSize(), "esql lucene topn");
     }
 
     @Override
@@ -147,7 +176,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
     @Override
     public void finish() {
         doneCollecting = true;
-        scoreDocs = null;
+        topDocs = null;
         shardRefCounted = null;
         assert isFinished();
     }
@@ -160,7 +189,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         long start = System.nanoTime();
         try {
             if (isEmitting()) {
-                return emit(false);
+                return emit();
             } else {
                 return collect();
             }
@@ -174,7 +203,8 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         var scorer = getCurrentOrLoadNextScorer();
         if (scorer == null) {
             doneCollecting = true;
-            return emit(true);
+            startEmitting();
+            return emit();
         }
         try {
             if (scorer.tags().isEmpty() == false) {
@@ -193,32 +223,47 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         if (scorer.isDone()) {
             var nextScorer = getCurrentOrLoadNextScorer();
             if (nextScorer == null || nextScorer.shardContext().index() != scorer.shardContext().index()) {
-                return emit(true);
+                startEmitting();
+                return emit();
             }
         }
         return null;
     }
 
     private boolean isEmitting() {
-        return scoreDocs != null && offset < scoreDocs.length;
+        return topDocs != null;
     }
 
-    private Page emit(boolean startEmitting) {
-        if (startEmitting) {
-            assert isEmitting() == false : "offset=" + offset + " score_docs=" + Arrays.toString(scoreDocs);
-            offset = 0;
-            if (perShardCollector != null) {
-                scoreDocs = perShardCollector.collector.topDocs().scoreDocs;
-                int shardId = perShardCollector.shardContext.index();
-                shardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
-            } else {
-                scoreDocs = new ScoreDoc[0];
-            }
+    private void startEmitting() {
+        assert isEmitting() == false : "offset=" + offset + " score_docs=" + Arrays.toString(topDocs);
+        offset = 0;
+        if (perShardCollector != null) {
+            /*
+             * Important note for anyone who looks at this and has bright ideas:
+             * There *is* a method in lucene to return topDocs with an offset
+             * and a limit. So you'd *think* you can scroll the top docs there.
+             * But you can't. It's expressly forbidden to call any of the `topDocs`
+             * methods more than once. You *must* call `topDocs` once and use the
+             * array.
+             */
+            topDocs = perShardCollector.collector.topDocs().scoreDocs;
+            int shardId = perShardCollector.shardContext.index();
+            shardRefCounted = new ShardRefCounted.Single(shardId, shardContextCounters.get(shardId));
+        } else {
+            topDocs = new ScoreDoc[0];
         }
-        if (offset >= scoreDocs.length) {
+    }
+
+    private void stopEmitting() {
+        topDocs = null;
+    }
+
+    private Page emit() {
+        if (offset >= topDocs.length) {
+            stopEmitting();
             return null;
         }
-        int size = Math.min(maxPageSize, scoreDocs.length - offset);
+        int size = Math.min(maxPageSize, topDocs.length - offset);
         IntBlock shard = null;
         IntVector segments = null;
         IntVector docs = null;
@@ -234,14 +279,16 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
             offset += size;
             List<LeafReaderContext> leafContexts = perShardCollector.shardContext.searcher().getLeafContexts();
             for (int i = start; i < offset; i++) {
-                int doc = scoreDocs[i].doc;
+                int doc = topDocs[i].doc;
                 int segment = ReaderUtil.subIndex(doc, leafContexts);
                 currentSegmentBuilder.appendInt(segment);
                 currentDocsBuilder.appendInt(doc - leafContexts.get(segment).docBase); // the offset inside the segment
                 if (currentScoresBuilder != null) {
-                    float score = getScore(scoreDocs[i]);
+                    float score = getScore(topDocs[i]);
                     currentScoresBuilder.appendDouble(score);
                 }
+                // Null the top doc so it can be GCed early, just in case.
+                topDocs[i] = null;
             }
 
             int shardId = perShardCollector.shardContext.index();
@@ -298,9 +345,20 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         sb.append(", sorts = [").append(notPrettySorts).append("]");
     }
 
+    @Override
+    protected void additionalClose() {
+        Releasables.close(() -> breaker.addWithoutBreaking(-reserveSize()));
+    }
+
+    private long reserveSize() {
+        long perRowSize = FIELD_DOC_SIZE + estimatedPerRowSortSize;
+        return limit * perRowSize;
+    }
+
     abstract static class PerShardCollector {
         private final ShardContext shardContext;
         private final TopDocsCollector<?> collector;
+
         private int leafIndex;
         private LeafCollector leafCollector;
         private Thread currentThread;
@@ -366,4 +424,68 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
         sort = new Sort(l.toArray(SortField[]::new));
         return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
     }
+
+    private static int perDocMemoryUsage(SortField[] sorts) {
+        int usage = FIELD_DOC_SIZE;
+        for (SortField sort : sorts) {
+            usage += perDocMemoryUsage(sort);
+        }
+        return usage;
+    }
+
+    private static int perDocMemoryUsage(SortField sort) {
+        if (sort.getType() == SortField.Type.CUSTOM) {
+            return perDocMemoryUsageForCustom(sort);
+        }
+        return perDocMemoryUsageByType(sort, sort.getType());
+
+    }
+
+    private static int perDocMemoryUsageByType(SortField sort, SortField.Type type) {
+        return switch (type) {
+            case SCORE, DOC ->
+                /* SCORE and DOC are always part of ScoreDoc/FieldDoc
+                 * So they are in FIELD_DOC_SIZE already.
+                 * And they can't be removed. */
+                0;
+            case DOUBLE, LONG ->
+                // 8 for the long, 8 for the long copied to the topDoc.
+                16;
+            case INT, FLOAT ->
+                // 4 for the int, 8 for the boxed object copied to topDoc.
+                12;
+            case STRING ->
+                /* `keyword`-like fields. Compares ordinals when possible, otherwise
+                 * the strings. Does a bunch of deduplication, but in the worst
+                 * case we end up with the string itself, plus two BytesRefs. Let's
+                 * presume short-ish strings. */
+                1024;
+            case STRING_VAL ->
+                /* Other string fields. Compares the string itself. Let's assume
+                 * 2kb per string because they tend to be bigger than the keyword
+                 * versions. */
+                2048;
+            case CUSTOM -> throw new IllegalArgumentException("unsupported type " + sort.getClass() + ": " + sort);
+            case REWRITEABLE -> {
+                assert false : "rewriteable  " + sort.getClass() + ": " + sort;
+                yield 2048;
+            }
+        };
+    }
+
+    private static int perDocMemoryUsageForCustom(SortField sort) {
+        return switch (sort) {
+            case SortedNumericSortField f -> perDocMemoryUsageByType(f, f.getNumericType());
+            case SortedSetSortField f -> perDocMemoryUsageByType(f, SortField.Type.STRING);
+            default -> {
+                if (sort.getClass().getName().equals("org.apache.lucene.document.LatLonPointSortField")) {
+                    yield perDocMemoryUsageByType(sort, SortField.Type.DOUBLE);
+                }
+                assert false : "unknown type " + sort.getClass() + ": " + sort;
+                yield 2048;
+            }
+        };
+    }
+
+    private static final int FIELD_DOC_SIZE = Math.toIntExact(RamUsageEstimator.shallowSizeOf(FieldDoc.class));
 }

+ 54 - 22
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/topn/TopNOperator.java

@@ -294,7 +294,6 @@ public class TopNOperator implements Operator, Accountable {
 
     private final BlockFactory blockFactory;
     private final CircuitBreaker breaker;
-    private final Queue inputQueue;
 
     private final int maxPageSize;
 
@@ -302,6 +301,7 @@ public class TopNOperator implements Operator, Accountable {
     private final List<TopNEncoder> encoders;
     private final List<SortOrder> sortOrders;
 
+    private Queue inputQueue;
     private Row spare;
     private int spareValuesPreAllocSize = 0;
     private int spareKeysPreAllocSize = 0;
@@ -346,7 +346,7 @@ public class TopNOperator implements Operator, Accountable {
         this.elementTypes = elementTypes;
         this.encoders = encoders;
         this.sortOrders = sortOrders;
-        this.inputQueue = new Queue(topCount);
+        this.inputQueue = Queue.build(breaker, topCount);
     }
 
     static int compareRows(Row r1, Row r2) {
@@ -457,6 +457,8 @@ public class TopNOperator implements Operator, Accountable {
                 list.add(inputQueue.pop());
             }
             Collections.reverse(list);
+            inputQueue.close();
+            inputQueue = null;
 
             int p = 0;
             int size = 0;
@@ -563,19 +565,27 @@ public class TopNOperator implements Operator, Accountable {
 
     @Override
     public void close() {
-        /*
-         * If we close before calling finish then spare and inputQueue will be live rows
-         * that need closing. If we close after calling finish then the output iterator
-         * will contain pages of results that have yet to be returned.
-         */
         Releasables.closeExpectNoException(
+            /*
+             * The spare is used during most collections. It's cleared when this Operator
+             * is finish()ed. So it could be null here.
+             */
             spare,
-            inputQueue == null ? null : Releasables.wrap(inputQueue),
+            /*
+             * The inputQueue is a min heap of all live rows. Closing it will close all
+             * the rows it contains and decrement the breaker for the size of
+             * the heap itself.
+             */
+            inputQueue,
+            /*
+             * If we're in the process of outputting pages then output will contain all
+             * allocated but un-emitted pages.
+             */
             output == null ? null : Releasables.wrap(() -> Iterators.map(output, p -> p::releaseBlocks))
         );
     }
 
-    private static long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class) + RamUsageEstimator
+    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TopNOperator.class) + RamUsageEstimator
         .shallowSizeOfInstance(List.class) * 3;
 
     @Override
@@ -589,7 +599,9 @@ public class TopNOperator implements Operator, Accountable {
         size += RamUsageEstimator.alignObjectSize(arrHeader + ref * encoders.size());
         size += RamUsageEstimator.alignObjectSize(arrHeader + ref * sortOrders.size());
         size += sortOrders.size() * SortOrder.SHALLOW_SIZE;
-        size += inputQueue.ramBytesUsed();
+        if (inputQueue != null) {
+            size += inputQueue.ramBytesUsed();
+        }
         return size;
     }
 
@@ -598,7 +610,7 @@ public class TopNOperator implements Operator, Accountable {
         return new TopNOperatorStatus(
             receiveNanos,
             emitNanos,
-            inputQueue.size(),
+            inputQueue != null ? inputQueue.size() : 0,
             ramBytesUsed(),
             pagesReceived,
             pagesEmitted,
@@ -620,17 +632,23 @@ public class TopNOperator implements Operator, Accountable {
             + "]";
     }
 
-    CircuitBreaker breaker() {
-        return breaker;
-    }
-
-    private static class Queue extends PriorityQueue<Row> implements Accountable {
+    private static class Queue extends PriorityQueue<Row> implements Accountable, Releasable {
         private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Queue.class);
-        private final int maxSize;
+        private final CircuitBreaker breaker;
+        private final int topCount;
 
-        Queue(int maxSize) {
-            super(maxSize);
-            this.maxSize = maxSize;
+        /**
+         * Track memory usage in the breaker then build the {@link Queue}.
+         */
+        static Queue build(CircuitBreaker breaker, int topCount) {
+            breaker.addEstimateBytesAndMaybeBreak(Queue.sizeOf(topCount), "esql engine topn");
+            return new Queue(breaker, topCount);
+        }
+
+        private Queue(CircuitBreaker breaker, int topCount) {
+            super(topCount);
+            this.breaker = breaker;
+            this.topCount = topCount;
         }
 
         @Override
@@ -640,19 +658,33 @@ public class TopNOperator implements Operator, Accountable {
 
         @Override
         public String toString() {
-            return size() + "/" + maxSize;
+            return size() + "/" + topCount;
         }
 
         @Override
         public long ramBytesUsed() {
             long total = SHALLOW_SIZE;
             total += RamUsageEstimator.alignObjectSize(
-                RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * (maxSize + 1)
+                RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1)
             );
             for (Row r : this) {
                 total += r == null ? 0 : r.ramBytesUsed();
             }
             return total;
         }
+
+        @Override
+        public void close() {
+            Releasables.close(Releasables.wrap(this), () -> breaker.addWithoutBreaking(-Queue.sizeOf(topCount)));
+
+        }
+
+        public static long sizeOf(int topCount) {
+            long total = SHALLOW_SIZE;
+            total += RamUsageEstimator.alignObjectSize(
+                RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + RamUsageEstimator.NUM_BYTES_OBJECT_REF * ((long) topCount + 1)
+            );
+            return total;
+        }
     }
 }

+ 2 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorScoringTests.java

@@ -97,6 +97,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperat
         int taskConcurrency = 0;
         int maxPageSize = between(10, Math.max(10, size));
         List<SortBuilder<?>> sorts = List.of(new FieldSortBuilder("s"));
+        long estimatedPerRowSortSize = 16;
         return new LuceneTopNSourceOperator.Factory(
             List.of(ctx),
             queryFunction,
@@ -105,6 +106,7 @@ public class LuceneTopNSourceOperatorScoringTests extends LuceneTopNSourceOperat
             maxPageSize,
             limit,
             sorts,
+            estimatedPerRowSortSize,
             true // scoring
         );
     }

+ 2 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorTests.java

@@ -103,6 +103,7 @@ public class LuceneTopNSourceOperatorTests extends SourceOperatorTestCase {
         int taskConcurrency = 0;
         int maxPageSize = between(10, Math.max(10, size));
         List<SortBuilder<?>> sorts = List.of(new FieldSortBuilder("s"));
+        long estimatedPerRowSortSize = 16;
         return new LuceneTopNSourceOperator.Factory(
             List.of(ctx),
             queryFunction,
@@ -111,6 +112,7 @@ public class LuceneTopNSourceOperatorTests extends SourceOperatorTestCase {
             maxPageSize,
             limit,
             sorts,
+            estimatedPerRowSortSize,
             scoring
         );
     }

+ 3 - 1
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/topn/TopNOperatorTests.java

@@ -1489,7 +1489,9 @@ public class TopNOperatorTests extends OperatorTestCase {
             block.decRef();
             op.addInput(new Page(blocks));
 
-            assertThat(breaker.getMemoryRequestCount(), is(94L));
+            // 94 are from the collection process
+            // 1 is for the min-heap itself
+            assertThat(breaker.getMemoryRequestCount(), is(95L));
         }
     }
 

+ 21 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/EsQueryExec.java

@@ -69,6 +69,12 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize {
         Order.OrderDirection direction();
 
         FieldAttribute field();
+
+        /**
+         * Type of the <strong>result</strong> of the sort. For example,
+         * geo distance will be {@link DataType#DOUBLE}.
+         */
+        DataType resultType();
     }
 
     public record FieldSort(FieldAttribute field, Order.OrderDirection direction, Order.NullsPosition nulls) implements Sort {
@@ -80,6 +86,11 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize {
             builder.unmappedType(field.dataType().esType());
             return builder;
         }
+
+        @Override
+        public DataType resultType() {
+            return field.dataType();
+        }
     }
 
     public record GeoDistanceSort(FieldAttribute field, Order.OrderDirection direction, double lat, double lon) implements Sort {
@@ -89,6 +100,11 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize {
             builder.order(Direction.from(direction).asOrder());
             return builder;
         }
+
+        @Override
+        public DataType resultType() {
+            return DataType.DOUBLE;
+        }
     }
 
     public record ScoreSort(Order.OrderDirection direction) implements Sort {
@@ -102,6 +118,11 @@ public class EsQueryExec extends LeafExec implements EstimatesRowSize {
             // TODO: refactor this: not all Sorts are backed by FieldAttributes
             return null;
         }
+
+        @Override
+        public DataType resultType() {
+            return DataType.DOUBLE;
+        }
     }
 
     public record QueryBuilderAndTags(QueryBuilder query, List<Object> tags) {

+ 10 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java

@@ -73,6 +73,7 @@ import org.elasticsearch.xpack.esql.core.type.PotentiallyUnmappedKeywordEsField;
 import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
 import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec.Sort;
+import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
 import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec;
 import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesFieldExtractExec;
@@ -294,9 +295,17 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
         boolean scoring = esQueryExec.hasScoring();
         if ((sorts != null && sorts.isEmpty() == false)) {
             List<SortBuilder<?>> sortBuilders = new ArrayList<>(sorts.size());
+            long estimatedPerRowSortSize = 0;
             for (Sort sort : sorts) {
                 sortBuilders.add(sort.sortBuilder());
+                estimatedPerRowSortSize += EstimatesRowSize.estimateSize(sort.resultType());
             }
+            /*
+             * In the worst case Lucene's TopN keeps each value in memory twice. Once
+             * for the actual sort and once for the top doc. In the best case they share
+             * references to the same underlying data, but we're being a bit paranoid here.
+             */
+            estimatedPerRowSortSize *= 2;
             // LuceneTopNSourceOperator does not support QueryAndTags, if there are multiple queries or if the single query has tags,
             // UnsupportedOperationException will be thrown by esQueryExec.query()
             luceneFactory = new LuceneTopNSourceOperator.Factory(
@@ -307,6 +316,7 @@ public class EsPhysicalOperationProviders extends AbstractPhysicalOperationProvi
                 context.pageSize(rowEstimatedSize),
                 limit,
                 sortBuilders,
+                estimatedPerRowSortSize,
                 scoring
             );
         } else {