Browse Source

Limit concurrent shards per node for ESQL (#104832)

Today, we allow ESQL to execute against an unlimited number of shards
concurrently on each node. This can lead to cases where we open and hold
too many shards at once, which amounts to holding too many open file
descriptors or using too much memory for FieldInfos in
ValuesSourceReaderOperator.

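This change limits the number of concurrent shards per node to 10 by
default; the limit can be adjusted through the
`max_concurrent_shards_per_node` query pragma (between 1 and 100). The
default of 10 was chosen based on the _search API, which limits it to 5.
Besides the primary reason stated above, this change has other
implications: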

Queries with only a LIMIT might now execute on fewer shards, leading to
scenarios where we execute only some high-priority shards and then stop.
For now, we don't have a partial reduce at the node level, but if we
introduce one in the future, it might not be as efficient as when all
shards are executed at the same time. There are also pauses between
batches because batches are executed sequentially, one after another.
However, I believe the performance of queries executing against many
shards (after can_match) is less important than resiliency.

Closes #103666
Nhat Nguyen 1 year ago
parent
commit
aea4684b52

+ 6 - 0
docs/changelog/104832.yaml

@@ -0,0 +1,6 @@
+pr: 104832
+summary: Limit concurrent shards per node for ESQL
+area: ES|QL
+type: bug
+issues:
+ - 103666

+ 1 - 1
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -1059,7 +1059,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         return context;
     }
 
-    public DefaultSearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
+    public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
         final IndexService indexService = indicesService.indexServiceSafe(request.shardId().getIndex());
         final IndexShard indexShard = indexService.getShard(request.shardId().getId());
         final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();

+ 1 - 1
server/src/test/java/org/elasticsearch/search/SearchServiceTests.java

@@ -1252,7 +1252,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
             nowInMillis,
             clusterAlias
         );
-        try (DefaultSearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) {
+        try (SearchContext searchContext = service.createSearchContext(request, new TimeValue(System.currentTimeMillis()))) {
             SearchShardTarget searchShardTarget = searchContext.shardTarget();
             SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext();
             String expectedIndexName = clusterAlias == null ? index : clusterAlias + ":" + index;

+ 15 - 0
test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java

@@ -11,6 +11,7 @@ package org.elasticsearch.search;
 import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.indices.ExecutorSelector;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -41,6 +42,7 @@ public class MockSearchService extends SearchService {
     private static final Map<ReaderContext, Throwable> ACTIVE_SEARCH_CONTEXTS = new ConcurrentHashMap<>();
 
     private Consumer<ReaderContext> onPutContext = context -> {};
+    private Consumer<ReaderContext> onRemoveContext = context -> {};
 
     private Consumer<SearchContext> onCreateSearchContext = context -> {};
 
@@ -110,6 +112,7 @@ public class MockSearchService extends SearchService {
     protected ReaderContext removeReaderContext(long id) {
         final ReaderContext removed = super.removeReaderContext(id);
         if (removed != null) {
+            onRemoveContext.accept(removed);
             removeActiveContext(removed);
         }
         return removed;
@@ -119,6 +122,10 @@ public class MockSearchService extends SearchService {
         this.onPutContext = onPutContext;
     }
 
+    public void setOnRemoveContext(Consumer<ReaderContext> onRemoveContext) {
+        this.onRemoveContext = onRemoveContext;
+    }
+
     public void setOnCreateSearchContext(Consumer<SearchContext> onCreateSearchContext) {
         this.onCreateSearchContext = onCreateSearchContext;
     }
@@ -141,6 +148,14 @@ public class MockSearchService extends SearchService {
         return searchContext;
     }
 
+    @Override
+    public SearchContext createSearchContext(ShardSearchRequest request, TimeValue timeout) throws IOException {
+        SearchContext searchContext = super.createSearchContext(request, timeout);
+        onPutContext.accept(searchContext.readerContext());
+        searchContext.addReleasable(() -> onRemoveContext.accept(searchContext.readerContext()));
+        return searchContext;
+    }
+
     public void setOnCheckCancelled(Function<SearchShardTask, SearchShardTask> onCheckCancelled) {
         this.onCheckCancelled = onCheckCancelled;
     }

+ 4 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSinkHandler.java

@@ -108,7 +108,10 @@ public final class ExchangeSinkHandler {
         completionFuture.addListener(listener);
     }
 
-    boolean isFinished() {
+    /**
+     * Returns {@code true} if this exchange sink handler has finished, i.e. its completion future is done.
+     */
+    public boolean isFinished() {
         return completionFuture.isDone();
     }
 

+ 3 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java

@@ -185,6 +185,9 @@ public abstract class AbstractEsqlIntegTestCase extends ESIntegTestCase {
                 };
                 settings.put("page_size", pageSize);
             }
+            if (randomBoolean()) {
+                settings.put("max_concurrent_shards_per_node", randomIntBetween(1, 10));
+            }
         }
         return new QueryPragmas(settings.build());
     }

+ 74 - 1
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/ManyShardsIT.java

@@ -13,11 +13,20 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.search.MockSearchService;
+import org.elasticsearch.search.SearchService;
 import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
+import org.hamcrest.Matchers;
+import org.junit.Before;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * Makes sure that we can run many concurrent requests with a large number of shards with any data_partitioning.
@@ -25,7 +34,15 @@ import java.util.concurrent.TimeUnit;
 @LuceneTestCase.SuppressFileSystems(value = "HandleLimitFS")
 public class ManyShardsIT extends AbstractEsqlIntegTestCase {
 
-    public void testConcurrentQueries() throws Exception {
+    @Override
+    protected Collection<Class<? extends Plugin>> getMockPlugins() {
+        var plugins = new ArrayList<>(super.getMockPlugins());
+        plugins.add(MockSearchService.TestPlugin.class);
+        return plugins;
+    }
+
+    @Before
+    public void setupIndices() {
         int numIndices = between(10, 20);
         for (int i = 0; i < numIndices; i++) {
             String index = "test-" + i;
@@ -49,6 +66,9 @@ public class ManyShardsIT extends AbstractEsqlIntegTestCase {
             }
             bulk.get();
         }
+    }
+
+    public void testConcurrentQueries() throws Exception {
         int numQueries = between(10, 20);
         Thread[] threads = new Thread[numQueries];
         CountDownLatch latch = new CountDownLatch(1);
@@ -76,4 +96,57 @@ public class ManyShardsIT extends AbstractEsqlIntegTestCase {
             thread.join();
         }
     }
+
+    static class SearchContextCounter {
+        private final int maxAllowed;
+        private final AtomicInteger current = new AtomicInteger();
+
+        SearchContextCounter(int maxAllowed) {
+            this.maxAllowed = maxAllowed;
+        }
+
+        void onNewContext() {
+            int total = current.incrementAndGet();
+            assertThat("opening more shards than the limit", total, Matchers.lessThanOrEqualTo(maxAllowed));
+        }
+
+        void onContextReleased() {
+            int total = current.decrementAndGet();
+            assertThat(total, Matchers.greaterThanOrEqualTo(0));
+        }
+    }
+
+    public void testLimitConcurrentShards() {
+        Iterable<SearchService> searchServices = internalCluster().getInstances(SearchService.class);
+        try {
+            var queries = List.of(
+                "from test-* | stats count(user) by tags",
+                "from test-* | stats count(user) by tags | LIMIT 0",
+                "from test-* | stats count(user) by tags | LIMIT 1",
+                "from test-* | stats count(user) by tags | LIMIT 1000",
+                "from test-* | LIMIT 0",
+                "from test-* | LIMIT 1",
+                "from test-* | LIMIT 1000",
+                "from test-* | SORT tags | LIMIT 0",
+                "from test-* | SORT tags | LIMIT 1",
+                "from test-* | SORT tags | LIMIT 1000"
+            );
+            for (String q : queries) {
+                QueryPragmas pragmas = randomPragmas();
+                for (SearchService searchService : searchServices) {
+                    SearchContextCounter counter = new SearchContextCounter(pragmas.maxConcurrentShardsPerNode());
+                    var mockSearchService = (MockSearchService) searchService;
+                    mockSearchService.setOnPutContext(r -> counter.onNewContext());
+                    mockSearchService.setOnRemoveContext(r -> counter.onContextReleased());
+                }
+                run(q, pragmas).close();
+            }
+        } finally {
+            for (SearchService searchService : searchServices) {
+                var mockSearchService = (MockSearchService) searchService;
+                mockSearchService.setOnPutContext(r -> {});
+                mockSearchService.setOnRemoveContext(r -> {});
+            }
+        }
+    }
 }

+ 11 - 2
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/WarningsIT.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.esql.action;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.transport.TransportService;
@@ -38,7 +39,11 @@ public class WarningsIT extends AbstractEsqlIntegTestCase {
             client().admin()
                 .indices()
                 .prepareCreate("index-1")
-                .setSettings(Settings.builder().put("index.routing.allocation.require._name", node1))
+                .setSettings(
+                    Settings.builder()
+                        .put("index.routing.allocation.require._name", node1)
+                        .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))
+                )
                 .setMapping("host", "type=keyword")
         );
         for (int i = 0; i < numDocs1; i++) {
@@ -49,7 +54,11 @@ public class WarningsIT extends AbstractEsqlIntegTestCase {
             client().admin()
                 .indices()
                 .prepareCreate("index-2")
-                .setSettings(Settings.builder().put("index.routing.allocation.require._name", node2))
+                .setSettings(
+                    Settings.builder()
+                        .put("index.routing.allocation.require._name", node2)
+                        .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 5))
+                )
                 .setMapping("host", "type=keyword")
         );
         for (int i = 0; i < numDocs2; i++) {

+ 88 - 37
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

@@ -31,6 +31,7 @@ import org.elasticsearch.compute.operator.DriverProfile;
 import org.elasticsearch.compute.operator.DriverTaskRunner;
 import org.elasticsearch.compute.operator.ResponseHeadersCollector;
 import org.elasticsearch.compute.operator.exchange.ExchangeService;
+import org.elasticsearch.compute.operator.exchange.ExchangeSink;
 import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
 import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
 import org.elasticsearch.core.IOUtils;
@@ -369,7 +370,7 @@ public class ComputeService {
     }
 
     void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener<List<DriverProfile>> listener) {
-        listener = ActionListener.runAfter(listener, () -> Releasables.close(context.searchContexts));
+        listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts));
         List<EsPhysicalOperationProviders.ShardContext> contexts = new ArrayList<>(context.searchContexts.size());
         for (int i = 0; i < context.searchContexts.size(); i++) {
             SearchContext searchContext = context.searchContexts.get(i);
@@ -457,6 +458,8 @@ public class ComputeService {
                         aliasFilter,
                         clusterAlias
                     );
+                    // TODO: `searchService.createSearchContext` allows opening search contexts without limits,
+                    // we need to limit the number of active search contexts here or in SearchService
                     SearchContext context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT);
                     searchContexts.add(context);
                 }
@@ -576,46 +579,94 @@ public class ComputeService {
     // TODO: Use an internal action here
     public static final String DATA_ACTION_NAME = EsqlQueryAction.NAME + "/data";
 
-    private class DataNodeRequestHandler implements TransportRequestHandler<DataNodeRequest> {
-        @Override
-        public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
-            final var parentTask = (CancellableTask) task;
-            final var sessionId = request.sessionId();
-            final var exchangeSink = exchangeService.getSinkHandler(sessionId);
+    private class DataNodeRequestExecutor {
+        private final DataNodeRequest request;
+        private final CancellableTask parentTask;
+        private final ExchangeSinkHandler exchangeSink;
+        private final ActionListener<ComputeResponse> listener;
+        private final List<DriverProfile> driverProfiles;
+        private final int maxConcurrentShards;
+        private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data
+
+        DataNodeRequestExecutor(
+            DataNodeRequest request,
+            CancellableTask parentTask,
+            ExchangeSinkHandler exchangeSink,
+            int maxConcurrentShards,
+            ActionListener<ComputeResponse> listener
+        ) {
+            this.request = request;
+            this.parentTask = parentTask;
+            this.exchangeSink = exchangeSink;
+            this.listener = listener;
+            this.driverProfiles = request.configuration().profile() ? Collections.synchronizedList(new ArrayList<>()) : List.of();
+            this.maxConcurrentShards = maxConcurrentShards;
+            this.blockingSink = exchangeSink.createExchangeSink();
+        }
+
+        void start() {
             parentTask.addListener(
-                () -> exchangeService.finishSinkHandler(sessionId, new TaskCancelledException(parentTask.getReasonCancelled()))
+                () -> exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(parentTask.getReasonCancelled()))
             );
-            final ActionListener<ComputeResponse> listener = new ChannelActionListener<>(channel);
+            runBatch(0);
+        }
+
+        private void runBatch(int startBatchIndex) {
             final EsqlConfiguration configuration = request.configuration();
-            String clusterAlias = request.clusterAlias();
-            acquireSearchContexts(
-                clusterAlias,
-                request.shardIds(),
-                configuration,
-                request.aliasFilters(),
-                ActionListener.wrap(searchContexts -> {
-                    assert ThreadPool.assertCurrentThreadPool(ESQL_THREAD_POOL_NAME);
-                    var computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, null, exchangeSink);
-                    runCompute(parentTask, computeContext, request.plan(), ActionListener.wrap(driverProfiles -> {
-                        // don't return until all pages are fetched
-                        exchangeSink.addCompletionListener(
-                            ContextPreservingActionListener.wrapPreservingContext(
-                                ActionListener.releaseAfter(
-                                    listener.map(nullValue -> new ComputeResponse(driverProfiles)),
-                                    () -> exchangeService.finishSinkHandler(sessionId, null)
-                                ),
-                                transportService.getThreadPool().getThreadContext()
-                            )
-                        );
-                    }, e -> {
-                        exchangeService.finishSinkHandler(sessionId, e);
-                        listener.onFailure(e);
-                    }));
-                }, e -> {
-                    exchangeService.finishSinkHandler(sessionId, e);
-                    listener.onFailure(e);
-                })
+            final String clusterAlias = request.clusterAlias();
+            final var sessionId = request.sessionId();
+            final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size());
+            List<ShardId> shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex);
+            acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> {
+                assert ThreadPool.assertCurrentThreadPool(ESQL_THREAD_POOL_NAME, ESQL_WORKER_THREAD_POOL_NAME);
+                var computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, null, exchangeSink);
+                runCompute(
+                    parentTask,
+                    computeContext,
+                    request.plan(),
+                    ActionListener.wrap(profiles -> onBatchCompleted(endBatchIndex, profiles), this::onFailure)
+                );
+            }, this::onFailure));
+        }
+
+        private void onBatchCompleted(int lastBatchIndex, List<DriverProfile> batchProfiles) {
+            if (request.configuration().profile()) {
+                driverProfiles.addAll(batchProfiles);
+            }
+            if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) {
+                runBatch(lastBatchIndex);
+            } else {
+                blockingSink.finish();
+                // don't return until all pages are fetched
+                exchangeSink.addCompletionListener(
+                    ContextPreservingActionListener.wrapPreservingContext(
+                        ActionListener.runBefore(
+                            listener.map(nullValue -> new ComputeResponse(driverProfiles)),
+                            () -> exchangeService.finishSinkHandler(request.sessionId(), null)
+                        ),
+                        transportService.getThreadPool().getThreadContext()
+                    )
+                );
+            }
+        }
+
+        private void onFailure(Exception e) {
+            exchangeService.finishSinkHandler(request.sessionId(), e);
+            listener.onFailure(e);
+        }
+    }
+
+    private class DataNodeRequestHandler implements TransportRequestHandler<DataNodeRequest> {
+        @Override
+        public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
+            DataNodeRequestExecutor executor = new DataNodeRequestExecutor(
+                request,
+                (CancellableTask) task,
+                exchangeService.getSinkHandler(request.sessionId()),
+                request.configuration().pragmas().maxConcurrentShardsPerNode(),
+                new ChannelActionListener<>(channel)
             );
+            executor.start();
         }
     }
 

+ 10 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java

@@ -53,6 +53,8 @@ public final class QueryPragmas implements Writeable {
      */
     public static final Setting<TimeValue> STATUS_INTERVAL = Setting.timeSetting("status_interval", Driver.DEFAULT_STATUS_INTERVAL);
 
+    public static final Setting<Integer> MAX_CONCURRENT_SHARDS_PER_NODE = Setting.intSetting("max_concurrent_shards_per_node", 10, 1, 100);
+
     public static final QueryPragmas EMPTY = new QueryPragmas(Settings.EMPTY);
 
     private final Settings settings;
@@ -114,6 +116,14 @@ public final class QueryPragmas implements Writeable {
         return ENRICH_MAX_WORKERS.get(settings);
     }
 
+    /**
+     * The maximum number of shards that can be executed concurrently on a single node by this query. This is a safeguard to avoid
+     * opening and holding many shards (equivalent to many file descriptors) or having too many field infos created by a single query.
+     */
+    public int maxConcurrentShardsPerNode() {
+        return MAX_CONCURRENT_SHARDS_PER_NODE.get(settings);
+    }
+
     public boolean isEmpty() {
         return settings.isEmpty();
     }