Browse Source

Fix esql enrich memory leak (#109275) (#110450)

This PR was reviewed in #109275

Block and Vector use a non-thread-safe RefCounted. Threads that increase
or decrease the references must have a happen-before relationship.
However, this order is not guaranteed in the enrich lookup for the
reference of selectedPositions. The driver can complete the
MergePositionsOperator, which decreases the reference count of
selectedPositions, while the finally block may also decrease it in a
separate thread. These actions occur without a defined happen-before
relationship.

Closes #108532
Nhat Nguyen 1 year ago
parent
commit
c8ece6a78e

+ 91 - 73
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java

@@ -31,7 +31,6 @@ import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.BlockStreamInput;
 import org.elasticsearch.compute.data.BlockStreamInput;
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.BytesRefBlock;
 import org.elasticsearch.compute.data.ElementType;
 import org.elasticsearch.compute.data.ElementType;
-import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.LocalCircuitBreaker;
 import org.elasticsearch.compute.data.LocalCircuitBreaker;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
 import org.elasticsearch.compute.data.OrdinalBytesRefBlock;
@@ -43,6 +42,7 @@ import org.elasticsearch.compute.operator.Operator;
 import org.elasticsearch.compute.operator.OutputOperator;
 import org.elasticsearch.compute.operator.OutputOperator;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.BlockLoader;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MappedFieldType;
@@ -247,30 +247,53 @@ public class EnrichLookupService {
         ActionListener<Page> listener
         ActionListener<Page> listener
     ) {
     ) {
         Block inputBlock = inputPage.getBlock(0);
         Block inputBlock = inputPage.getBlock(0);
-        final IntBlock selectedPositions;
-        final OrdinalBytesRefBlock ordinalsBytesRefBlock;
-        if (inputBlock instanceof BytesRefBlock bytesRefBlock && (ordinalsBytesRefBlock = bytesRefBlock.asOrdinals()) != null) {
-            inputBlock = ordinalsBytesRefBlock.getDictionaryVector().asBlock();
-            selectedPositions = ordinalsBytesRefBlock.getOrdinalsBlock();
-            selectedPositions.mustIncRef();
-        } else {
-            selectedPositions = IntVector.range(0, inputBlock.getPositionCount(), blockFactory).asBlock();
+        if (inputBlock.areAllValuesNull()) {
+            listener.onResponse(createNullResponse(inputPage.getPositionCount(), extractFields));
+            return;
         }
         }
-        LocalCircuitBreaker localBreaker = null;
+        final List<Releasable> releasables = new ArrayList<>(6);
+        boolean started = false;
         try {
         try {
-            if (inputBlock.areAllValuesNull()) {
-                listener.onResponse(createNullResponse(inputPage.getPositionCount(), extractFields));
-                return;
-            }
-            ShardSearchRequest shardSearchRequest = new ShardSearchRequest(shardId, 0, AliasFilter.EMPTY);
-            SearchContext searchContext = searchService.createSearchContext(shardSearchRequest, SearchService.NO_TIMEOUT);
-            listener = ActionListener.runBefore(listener, searchContext::close);
-            localBreaker = new LocalCircuitBreaker(
+            final ShardSearchRequest shardSearchRequest = new ShardSearchRequest(shardId, 0, AliasFilter.EMPTY);
+            final SearchContext searchContext = searchService.createSearchContext(shardSearchRequest, SearchService.NO_TIMEOUT);
+            releasables.add(searchContext);
+            final LocalCircuitBreaker localBreaker = new LocalCircuitBreaker(
                 blockFactory.breaker(),
                 blockFactory.breaker(),
                 localBreakerSettings.overReservedBytes(),
                 localBreakerSettings.overReservedBytes(),
                 localBreakerSettings.maxOverReservedBytes()
                 localBreakerSettings.maxOverReservedBytes()
             );
             );
-            DriverContext driverContext = new DriverContext(bigArrays, blockFactory.newChildFactory(localBreaker));
+            releasables.add(localBreaker);
+            final DriverContext driverContext = new DriverContext(bigArrays, blockFactory.newChildFactory(localBreaker));
+            final ElementType[] mergingTypes = new ElementType[extractFields.size()];
+            for (int i = 0; i < extractFields.size(); i++) {
+                mergingTypes[i] = PlannerUtils.toElementType(extractFields.get(i).dataType());
+            }
+            final int[] mergingChannels = IntStream.range(0, extractFields.size()).map(i -> i + 2).toArray();
+            final MergePositionsOperator mergePositionsOperator;
+            final OrdinalBytesRefBlock ordinalsBytesRefBlock;
+            if (inputBlock instanceof BytesRefBlock bytesRefBlock && (ordinalsBytesRefBlock = bytesRefBlock.asOrdinals()) != null) {
+                inputBlock = ordinalsBytesRefBlock.getDictionaryVector().asBlock();
+                var selectedPositions = ordinalsBytesRefBlock.getOrdinalsBlock();
+                mergePositionsOperator = new MergePositionsOperator(
+                    1,
+                    mergingChannels,
+                    mergingTypes,
+                    selectedPositions,
+                    driverContext.blockFactory()
+                );
+
+            } else {
+                try (var selectedPositions = IntVector.range(0, inputBlock.getPositionCount(), blockFactory).asBlock()) {
+                    mergePositionsOperator = new MergePositionsOperator(
+                        1,
+                        mergingChannels,
+                        mergingTypes,
+                        selectedPositions,
+                        driverContext.blockFactory()
+                    );
+                }
+            }
+            releasables.add(mergePositionsOperator);
             SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext();
             SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext();
             MappedFieldType fieldType = searchExecutionContext.getFieldType(matchField);
             MappedFieldType fieldType = searchExecutionContext.getFieldType(matchField);
             var queryList = switch (matchType) {
             var queryList = switch (matchType) {
@@ -284,57 +307,13 @@ public class EnrichLookupService {
                 queryList,
                 queryList,
                 searchExecutionContext.getIndexReader()
                 searchExecutionContext.getIndexReader()
             );
             );
-            List<Operator> intermediateOperators = new ArrayList<>(extractFields.size() + 2);
-            final ElementType[] mergingTypes = new ElementType[extractFields.size()];
-            // load the fields
-            List<ValuesSourceReaderOperator.FieldInfo> fields = new ArrayList<>(extractFields.size());
-            for (int i = 0; i < extractFields.size(); i++) {
-                NamedExpression extractField = extractFields.get(i);
-                final ElementType elementType = PlannerUtils.toElementType(extractField.dataType());
-                mergingTypes[i] = elementType;
-                EsPhysicalOperationProviders.ShardContext ctx = new EsPhysicalOperationProviders.DefaultShardContext(
-                    0,
-                    searchContext.getSearchExecutionContext(),
-                    searchContext.request().getAliasFilter()
-                );
-                BlockLoader loader = ctx.blockLoader(
-                    extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(),
-                    extractField.dataType() == DataType.UNSUPPORTED,
-                    MappedFieldType.FieldExtractPreference.NONE
-                );
-                fields.add(
-                    new ValuesSourceReaderOperator.FieldInfo(
-                        extractField.name(),
-                        PlannerUtils.toElementType(extractField.dataType()),
-                        shardIdx -> {
-                            if (shardIdx != 0) {
-                                throw new IllegalStateException("only one shard");
-                            }
-                            return loader;
-                        }
-                    )
-                );
-            }
-            intermediateOperators.add(
-                new ValuesSourceReaderOperator(
-                    driverContext.blockFactory(),
-                    fields,
-                    List.of(
-                        new ValuesSourceReaderOperator.ShardContext(
-                            searchContext.searcher().getIndexReader(),
-                            searchContext::newSourceLoader
-                        )
-                    ),
-                    0
-                )
-            );
-            // merging field-values by position
-            final int[] mergingChannels = IntStream.range(0, extractFields.size()).map(i -> i + 2).toArray();
-            intermediateOperators.add(
-                new MergePositionsOperator(1, mergingChannels, mergingTypes, selectedPositions, driverContext.blockFactory())
-            );
+            releasables.add(queryOperator);
+            var extractFieldsOperator = extractFieldsOperator(searchContext, driverContext, extractFields);
+            releasables.add(extractFieldsOperator);
+
             AtomicReference<Page> result = new AtomicReference<>();
             AtomicReference<Page> result = new AtomicReference<>();
             OutputOperator outputOperator = new OutputOperator(List.of(), Function.identity(), result::set);
             OutputOperator outputOperator = new OutputOperator(List.of(), Function.identity(), result::set);
+            releasables.add(outputOperator);
             Driver driver = new Driver(
             Driver driver = new Driver(
                 "enrich-lookup:" + sessionId,
                 "enrich-lookup:" + sessionId,
                 System.currentTimeMillis(),
                 System.currentTimeMillis(),
@@ -350,18 +329,16 @@ public class EnrichLookupService {
                     inputPage.getPositionCount()
                     inputPage.getPositionCount()
                 ),
                 ),
                 queryOperator,
                 queryOperator,
-                intermediateOperators,
+                List.of(extractFieldsOperator, mergePositionsOperator),
                 outputOperator,
                 outputOperator,
                 Driver.DEFAULT_STATUS_INTERVAL,
                 Driver.DEFAULT_STATUS_INTERVAL,
-                localBreaker
+                Releasables.wrap(searchContext, localBreaker)
             );
             );
             task.addListener(() -> {
             task.addListener(() -> {
                 String reason = Objects.requireNonNullElse(task.getReasonCancelled(), "task was cancelled");
                 String reason = Objects.requireNonNullElse(task.getReasonCancelled(), "task was cancelled");
                 driver.cancel(reason);
                 driver.cancel(reason);
             });
             });
-
             var threadContext = transportService.getThreadPool().getThreadContext();
             var threadContext = transportService.getThreadPool().getThreadContext();
-            localBreaker = null;
             Driver.start(threadContext, executor, driver, Driver.DEFAULT_MAX_ITERATIONS, listener.map(ignored -> {
             Driver.start(threadContext, executor, driver, Driver.DEFAULT_MAX_ITERATIONS, listener.map(ignored -> {
                 Page out = result.get();
                 Page out = result.get();
                 if (out == null) {
                 if (out == null) {
@@ -369,11 +346,52 @@ public class EnrichLookupService {
                 }
                 }
                 return out;
                 return out;
             }));
             }));
+            started = true;
         } catch (Exception e) {
         } catch (Exception e) {
             listener.onFailure(e);
             listener.onFailure(e);
         } finally {
         } finally {
-            Releasables.close(selectedPositions, localBreaker);
+            if (started == false) {
+                Releasables.close(releasables);
+            }
+        }
+    }
+
+    private static Operator extractFieldsOperator(
+        SearchContext searchContext,
+        DriverContext driverContext,
+        List<NamedExpression> extractFields
+    ) {
+        EsPhysicalOperationProviders.ShardContext shardContext = new EsPhysicalOperationProviders.DefaultShardContext(
+            0,
+            searchContext.getSearchExecutionContext(),
+            searchContext.request().getAliasFilter()
+        );
+        List<ValuesSourceReaderOperator.FieldInfo> fields = new ArrayList<>(extractFields.size());
+        for (NamedExpression extractField : extractFields) {
+            BlockLoader loader = shardContext.blockLoader(
+                extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(),
+                extractField.dataType() == DataType.UNSUPPORTED,
+                MappedFieldType.FieldExtractPreference.NONE
+            );
+            fields.add(
+                new ValuesSourceReaderOperator.FieldInfo(
+                    extractField.name(),
+                    PlannerUtils.toElementType(extractField.dataType()),
+                    shardIdx -> {
+                        if (shardIdx != 0) {
+                            throw new IllegalStateException("only one shard");
+                        }
+                        return loader;
+                    }
+                )
+            );
         }
         }
+        return new ValuesSourceReaderOperator(
+            driverContext.blockFactory(),
+            fields,
+            List.of(new ValuesSourceReaderOperator.ShardContext(searchContext.searcher().getIndexReader(), searchContext::newSourceLoader)),
+            0
+        );
     }
     }
 
 
     private Page createNullResponse(int positionCount, List<NamedExpression> extractFields) {
     private Page createNullResponse(int positionCount, List<NamedExpression> extractFields) {