
Simplify `IndicesShardStoresAction` (#94507)

- No need to use an `AsyncShardFetch` here, there is no caching
- Response may be very large, introduce chunking
- Fan-out may be very large, introduce throttling
- Processing time may be nontrivial, introduce cancellability
- Eliminate many unnecessary intermediate data structures
- Do shard-level response processing more eagerly
- Determine allocation from `RoutingTable` not `RoutingNodes`
- Add tests

Relates #81081
David Turner committed 2 years ago
parent commit e377a8601a

+ 10 - 0
docs/reference/indices/shard-stores.asciidoc

@@ -93,6 +93,16 @@ regardless of health status.
 Defaults to `yellow,red`.
 --
 
+`max_concurrent_shard_requests`::
++
+--
+(Optional, integer)
+Maximum number of concurrent shard-level requests sent by the coordinating
+node. Defaults to `100`. Larger values may yield a quicker response to requests
+that target many shards, but may also cause a larger impact on other cluster
+operations.
+--
+
 [[index-shard-stores-api-example]]
 ==== {api-examples-title}
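
The new REST parameter corresponds to `IndicesShardStoresRequest#maxConcurrentShardRequests` (added in the request class below). As a minimal sketch, assuming a `NodeClient` is available, the same request can be driven from Java as follows; the helper class and index name are illustrative and not part of the commit:

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresRequest;
import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse;
import org.elasticsearch.client.internal.node.NodeClient;

class ShardStoresRequestSketch {
    // Illustrative helper: builds a shard-stores request with the new throttling knob set and
    // executes it via the node client, roughly equivalent to
    // GET /my-index/_shard_stores?status=yellow,red&max_concurrent_shard_requests=50
    static void fetchShardStores(NodeClient client, ActionListener<IndicesShardStoresResponse> listener) {
        IndicesShardStoresRequest request = new IndicesShardStoresRequest("my-index");
        request.shardStatuses("yellow", "red");     // the default statuses
        request.maxConcurrentShardRequests(50);     // new in this change; defaults to 100
        client.admin().indices().shardStores(request, listener);
    }
}
```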
 

+ 38 - 0
server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresRequest.java

@@ -7,6 +7,7 @@
  */
 package org.elasticsearch.action.admin.indices.shards;
 
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.IndicesRequest;
 import org.elasticsearch.action.support.IndicesOptions;
@@ -15,18 +16,25 @@ import org.elasticsearch.cluster.health.ClusterHealthStatus;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 
 import java.io.IOException;
 import java.util.EnumSet;
+import java.util.Map;
 
 /**
  * Request for {@link IndicesShardStoresAction}
  */
 public class IndicesShardStoresRequest extends MasterNodeReadRequest<IndicesShardStoresRequest> implements IndicesRequest.Replaceable {
 
+    static final int DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS = 100;
+
     private String[] indices = Strings.EMPTY_ARRAY;
     private IndicesOptions indicesOptions = IndicesOptions.strictExpand();
     private EnumSet<ClusterHealthStatus> statuses = EnumSet.of(ClusterHealthStatus.YELLOW, ClusterHealthStatus.RED);
+    private int maxConcurrentShardRequests = DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS;
 
     /**
      * Create a request for shard stores info for <code>indices</code>
@@ -46,6 +54,12 @@ public class IndicesShardStoresRequest extends MasterNodeReadRequest<IndicesShar
             statuses.add(ClusterHealthStatus.readFrom(in));
         }
         indicesOptions = IndicesOptions.readIndicesOptions(in);
+        if (in.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            maxConcurrentShardRequests = in.readVInt();
+        } else {
+            // earlier versions had unlimited concurrency
+            maxConcurrentShardRequests = Integer.MAX_VALUE;
+        }
     }
 
     @Override
@@ -54,6 +68,17 @@ public class IndicesShardStoresRequest extends MasterNodeReadRequest<IndicesShar
         out.writeStringArrayNullable(indices);
         out.writeCollection(statuses, (o, v) -> o.writeByte(v.value()));
         indicesOptions.writeIndicesOptions(out);
+        if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
+            out.writeVInt(maxConcurrentShardRequests);
+        } else if (maxConcurrentShardRequests != DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS) {
+            throw new IllegalArgumentException(
+                "support for maxConcurrentShardRequests=["
+                    + maxConcurrentShardRequests
+                    + "] was added in version [8.8.0], cannot send this request using transport version ["
+                    + out.getTransportVersion()
+                    + "]"
+            );
+        } // else just drop the value and use the default behaviour
     }
 
     /**
@@ -114,8 +139,21 @@ public class IndicesShardStoresRequest extends MasterNodeReadRequest<IndicesShar
         return indicesOptions;
     }
 
+    public void maxConcurrentShardRequests(int maxConcurrentShardRequests) {
+        this.maxConcurrentShardRequests = maxConcurrentShardRequests;
+    }
+
+    public int maxConcurrentShardRequests() {
+        return maxConcurrentShardRequests;
+    }
+
     @Override
     public ActionRequestValidationException validate() {
         return null;
     }
+
+    @Override
+    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+        return new CancellableTask(id, type, action, "", parentTaskId, headers);
+    }
 }

+ 47 - 34
server/src/main/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoresResponse.java

@@ -12,13 +12,19 @@ import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.support.DefaultShardOperationFailedException;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
+import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
+import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
@@ -28,7 +34,7 @@ import java.util.Map;
  * Consists of {@link StoreStatus}s for requested indices grouped by
  * indices and shard ids and a list of encountered node {@link Failure}s
  */
-public class IndicesShardStoresResponse extends ActionResponse implements ToXContentFragment {
+public class IndicesShardStoresResponse extends ActionResponse implements ChunkedToXContentObject {
 
     /**
      * Shard store information from a node
@@ -196,7 +202,7 @@ public class IndicesShardStoresResponse extends ActionResponse implements ToXCon
      * Single node failure while retrieving shard store information
      */
     public static class Failure extends DefaultShardOperationFailedException {
-        private String nodeId;
+        private final String nodeId;
 
         public Failure(String nodeId, String index, int shardId, Throwable reason) {
             super(index, shardId, reason);
@@ -273,38 +279,45 @@ public class IndicesShardStoresResponse extends ActionResponse implements ToXCon
     }
 
     @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        if (failures.size() > 0) {
-            builder.startArray(Fields.FAILURES);
-            for (Failure failure : failures) {
-                failure.toXContent(builder, params);
-            }
-            builder.endArray();
-        }
-
-        builder.startObject(Fields.INDICES);
-        for (Map.Entry<String, Map<Integer, List<StoreStatus>>> indexShards : storeStatuses.entrySet()) {
-            builder.startObject(indexShards.getKey());
-
-            builder.startObject(Fields.SHARDS);
-            for (Map.Entry<Integer, List<StoreStatus>> shardStatusesEntry : indexShards.getValue().entrySet()) {
-                builder.startObject(String.valueOf(shardStatusesEntry.getKey()));
-                builder.startArray(Fields.STORES);
-                for (StoreStatus storeStatus : shardStatusesEntry.getValue()) {
-                    builder.startObject();
-                    storeStatus.toXContent(builder, params);
-                    builder.endObject();
-                }
-                builder.endArray();
-
-                builder.endObject();
-            }
-            builder.endObject();
-
-            builder.endObject();
-        }
-        builder.endObject();
-        return builder;
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params outerParams) {
+        return Iterators.concat(
+            ChunkedToXContentHelper.startObject(),
+
+            failures.isEmpty()
+                ? Collections.emptyIterator()
+                : Iterators.concat(
+                    ChunkedToXContentHelper.startArray(Fields.FAILURES),
+                    failures.iterator(),
+                    ChunkedToXContentHelper.endArray()
+                ),
+
+            ChunkedToXContentHelper.startObject(Fields.INDICES),
+
+            Iterators.flatMap(
+                storeStatuses.entrySet().iterator(),
+                indexShards -> Iterators.concat(
+                    ChunkedToXContentHelper.startObject(indexShards.getKey()),
+                    ChunkedToXContentHelper.startObject(Fields.SHARDS),
+                    Iterators.flatMap(
+                        indexShards.getValue().entrySet().iterator(),
+                        shardStatusesEntry -> Iterators.single((ToXContent) (builder, params) -> {
+                            builder.startObject(String.valueOf(shardStatusesEntry.getKey())).startArray(Fields.STORES);
+                            for (StoreStatus storeStatus : shardStatusesEntry.getValue()) {
+                                builder.startObject();
+                                storeStatus.toXContent(builder, params);
+                                builder.endObject();
+                            }
+                            return builder.endArray().endObject();
+                        })
+                    ),
+                    ChunkedToXContentHelper.endObject(),
+                    ChunkedToXContentHelper.endObject()
+                )
+            ),
+
+            ChunkedToXContentHelper.endObject(),
+            ChunkedToXContentHelper.endObject()
+        );
     }
 
     static final class Fields {
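
To make the chunking concrete: the response is now emitted as an iterator of small `ToXContent` fragments (roughly one per shard entry, plus bracketing object/array chunks and one per failure), so the REST layer can write and flush the body incrementally instead of building one huge `XContentBuilder`. A minimal sketch of draining the chunk iterator, which is essentially what `ChunkedToXContent.wrapAsToXContent` does in the test further down; the helper class name is illustrative:

```java
import java.io.IOException;
import java.util.Iterator;

import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

class ChunkedRenderingSketch {
    // Writes every chunk into a single builder. In production the REST layer drives the same
    // iterator but can yield to the network between chunks, bounding per-chunk memory use.
    static XContentBuilder renderAllChunks(IndicesShardStoresResponse response) throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder();
        Iterator<? extends ToXContent> chunks = response.toXContentChunked(ToXContent.EMPTY_PARAMS);
        while (chunks.hasNext()) {
            chunks.next().toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        return builder;
    }
}
```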

+ 178 - 154
server/src/main/java/org/elasticsearch/action/admin/indices/shards/TransportIndicesShardStoresAction.java

@@ -9,16 +9,14 @@ package org.elasticsearch.action.admin.indices.shards;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.apache.lucene.util.CollectionUtil;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse.Failure;
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse.StoreStatus;
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse.StoreStatus.AllocationStatus;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.RefCountingRunnable;
+import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.master.TransportMasterNodeReadAction;
-import org.elasticsearch.action.support.nodes.BaseNodesResponse;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.block.ClusterBlockException;
@@ -27,33 +25,32 @@ import org.elasticsearch.cluster.health.ClusterHealthStatus;
 import org.elasticsearch.cluster.health.ClusterShardHealth;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
-import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
-import org.elasticsearch.cluster.routing.RoutingNodes;
 import org.elasticsearch.cluster.routing.RoutingTable;
-import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.util.Maps;
+import org.elasticsearch.common.util.concurrent.ThrottledIterator;
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Tuple;
-import org.elasticsearch.gateway.AsyncShardFetch;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
 import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards.NodeGatewayStartedShards;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.EnumSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
 
 /**
@@ -98,35 +95,22 @@ public class TransportIndicesShardStoresAction extends TransportMasterNodeReadAc
         ClusterState state,
         ActionListener<IndicesShardStoresResponse> listener
     ) {
-        final RoutingTable routingTables = state.routingTable();
-        final RoutingNodes routingNodes = state.getRoutingNodes();
+        final DiscoveryNode[] nodes = state.nodes().getDataNodes().values().toArray(new DiscoveryNode[0]);
         final String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, request);
-        final Set<Tuple<ShardId, String>> shardsToFetch = new HashSet<>();
-
+        final RoutingTable routingTable = state.routingTable();
+        final Metadata metadata = state.metadata();
         logger.trace("using cluster state version [{}] to determine shards", state.version());
-        // collect relevant shard ids of the requested indices for fetching store infos
-        for (String index : concreteIndices) {
-            IndexRoutingTable indexShardRoutingTables = routingTables.index(index);
-            if (indexShardRoutingTables == null) {
-                continue;
-            }
-            final String customDataPath = IndexMetadata.INDEX_DATA_PATH_SETTING.get(state.metadata().index(index).getSettings());
-            for (int i = 0; i < indexShardRoutingTables.size(); i++) {
-                IndexShardRoutingTable routing = indexShardRoutingTables.shard(i);
-                final int shardId = routing.shardId().id();
-                ClusterShardHealth shardHealth = new ClusterShardHealth(shardId, routing);
-                if (request.shardStatuses().contains(shardHealth.getStatus())) {
-                    shardsToFetch.add(Tuple.tuple(routing.shardId(), customDataPath));
-                }
-            }
-        }
-
-        // async fetch store infos from all the nodes
-        // NOTE: instead of fetching shard store info one by one from every node (nShards * nNodes requests)
-        // we could fetch all shard store info from every node once (nNodes requests)
-        // we have to implement a TransportNodesAction instead of using TransportNodesListGatewayStartedShards
-        // for fetching shard stores info, that operates on a list of shards instead of a single shard
-        new AsyncShardStoresInfoFetches(state.nodes(), routingNodes, shardsToFetch, listener).start();
+        assert task instanceof CancellableTask;
+        new AsyncAction(
+            (CancellableTask) task,
+            concreteIndices,
+            request.shardStatuses(),
+            nodes,
+            routingTable,
+            metadata,
+            request.maxConcurrentShardRequests(),
+            listener
+        ).run();
     }
 
     @Override
@@ -135,152 +119,192 @@ public class TransportIndicesShardStoresAction extends TransportMasterNodeReadAc
             .indicesBlockedException(ClusterBlockLevel.METADATA_READ, indexNameExpressionResolver.concreteIndexNames(state, request));
     }
 
-    private class AsyncShardStoresInfoFetches {
-        private final DiscoveryNodes nodes;
-        private final RoutingNodes routingNodes;
-        private final Set<Tuple<ShardId, String>> shards;
-        private final ActionListener<IndicesShardStoresResponse> listener;
-        private final RefCountingRunnable refs = new RefCountingRunnable(this::finish);
-        private final Queue<InternalAsyncFetch.Response> fetchResponses;
+    // exposed for tests
+    void listShardStores(
+        TransportNodesListGatewayStartedShards.Request request,
+        ActionListener<TransportNodesListGatewayStartedShards.NodesGatewayStartedShards> listener
+    ) {
+        // async fetch store infos from all the nodes for this one shard
+        // NOTE: instead of fetching shard store info one by one from every node (nShards * nNodes requests)
+        // we could fetch all shard store info from every node once (nNodes requests)
+        // we have to implement a TransportNodesAction instead of using TransportNodesListGatewayStartedShards
+        // for fetching shard stores info, that operates on a list of shards instead of a single shard
+
+        client.executeLocally(TransportNodesListGatewayStartedShards.TYPE, request, listener);
+    }
+
+    private record ShardRequestContext(
+        ShardId shardId,
+        String customDataPath,
+        ActionListener<TransportNodesListGatewayStartedShards.NodesGatewayStartedShards> listener
+    ) {}
 
-        AsyncShardStoresInfoFetches(
-            DiscoveryNodes nodes,
-            RoutingNodes routingNodes,
-            Set<Tuple<ShardId, String>> shards,
+    private final class AsyncAction {
+        private final CancellableTask task;
+        private final DiscoveryNode[] nodes;
+        private final String[] concreteIndices;
+        private final RoutingTable routingTable;
+        private final Metadata metadata;
+        private final Map<String, Map<Integer, List<StoreStatus>>> indicesStatuses;
+        private final int maxConcurrentShardRequests;
+        private final Queue<Failure> failures;
+        private final EnumSet<ClusterHealthStatus> requestedStatuses;
+        private final RefCountingListener outerListener;
+
+        private AsyncAction(
+            CancellableTask task,
+            String[] concreteIndices,
+            EnumSet<ClusterHealthStatus> requestedStatuses,
+            DiscoveryNode[] nodes,
+            RoutingTable routingTable,
+            Metadata metadata,
+            int maxConcurrentShardRequests,
             ActionListener<IndicesShardStoresResponse> listener
         ) {
+            this.task = task;
             this.nodes = nodes;
-            this.routingNodes = routingNodes;
-            this.shards = shards;
-            this.listener = listener;
-            this.fetchResponses = new ConcurrentLinkedQueue<>();
+            this.concreteIndices = concreteIndices;
+            this.routingTable = routingTable;
+            this.metadata = metadata;
+            this.requestedStatuses = requestedStatuses;
+
+            this.indicesStatuses = Collections.synchronizedMap(Maps.newHashMapWithExpectedSize(concreteIndices.length));
+            this.maxConcurrentShardRequests = maxConcurrentShardRequests;
+            this.failures = new ConcurrentLinkedQueue<>();
+            this.outerListener = new RefCountingListener(1, listener.map(ignored -> {
+                task.ensureNotCancelled();
+                return new IndicesShardStoresResponse(Map.copyOf(indicesStatuses), List.copyOf(failures));
+            }));
         }
 
-        void start() {
-            try {
-                for (Tuple<ShardId, String> shard : shards) {
-                    new InternalAsyncFetch(logger, "shard_stores", shard.v1(), shard.v2(), routingNodes.size()).fetchData(
-                        nodes,
-                        Collections.emptySet()
-                    );
-                }
-            } finally {
-                refs.close();
-            }
+        private boolean isFailing() {
+            return outerListener.isFailing() || task.isCancelled();
         }
 
-        private void listStartedShards(
-            ShardId shardId,
-            String customDataPath,
-            DiscoveryNode[] nodes,
-            ActionListener<BaseNodesResponse<NodeGatewayStartedShards>> listener
-        ) {
-            var request = new TransportNodesListGatewayStartedShards.Request(shardId, customDataPath, nodes);
-            client.executeLocally(
-                TransportNodesListGatewayStartedShards.TYPE,
-                request,
-                ActionListener.wrap(listener::onResponse, listener::onFailure)
+        void run() {
+            ThrottledIterator.run(
+                Iterators.flatMap(Iterators.forArray(concreteIndices), this::getIndexIterator),
+                this::doShardRequest,
+                maxConcurrentShardRequests,
+                () -> {},
+                outerListener::close
             );
         }
 
-        private class InternalAsyncFetch extends AsyncShardFetch<NodeGatewayStartedShards> {
-
-            private final Releasable ref = refs.acquire();
-
-            InternalAsyncFetch(Logger logger, String type, ShardId shardId, String customDataPath, int expectedSize) {
-                super(logger, type, shardId, customDataPath, expectedSize);
+        private Iterator<ShardRequestContext> getIndexIterator(String indexName) {
+            if (isFailing()) {
+                return Collections.emptyIterator();
             }
 
-            @Override
-            protected synchronized void processAsyncFetch(
-                List<NodeGatewayStartedShards> responses,
-                List<FailedNodeException> failures,
-                long fetchingRound
-            ) {
-                fetchResponses.add(new Response(shardId, responses, failures));
-                ref.close();
+            final var indexRoutingTable = routingTable.index(indexName);
+            if (indexRoutingTable == null) {
+                return Collections.emptyIterator();
             }
 
-            @Override
-            protected void list(
-                ShardId shardId,
-                String customDataPath,
-                DiscoveryNode[] nodes,
-                ActionListener<BaseNodesResponse<NodeGatewayStartedShards>> listener
-            ) {
-                listStartedShards(shardId, customDataPath, nodes, listener);
-            }
+            return new IndexRequestContext(indexRoutingTable).getShardRequestContexts();
+        }
 
-            @Override
-            protected void reroute(ShardId shardId, String reason) {
-                // no-op
-            }
+        private void doShardRequest(Releasable ref, ShardRequestContext shardRequestContext) {
+            ActionListener.run(ActionListener.releaseAfter(shardRequestContext.listener(), ref), l -> {
+                if (isFailing()) {
+                    l.onResponse(null);
+                } else {
+                    listShardStores(
+                        new TransportNodesListGatewayStartedShards.Request(
+                            shardRequestContext.shardId(),
+                            shardRequestContext.customDataPath(),
+                            nodes
+                        ),
+                        l
+                    );
+                }
+            });
+        }
 
-            public class Response {
-                private final ShardId shardId;
-                private final List<NodeGatewayStartedShards> responses;
-                private final List<FailedNodeException> failures;
+        private class IndexRequestContext {
+            private final IndexRoutingTable indexRoutingTable;
+            private final Map<Integer, List<StoreStatus>> indexResults;
+
+            IndexRequestContext(IndexRoutingTable indexRoutingTable) {
+                this.indexRoutingTable = indexRoutingTable;
+                this.indexResults = Collections.synchronizedMap(Maps.newHashMapWithExpectedSize(indexRoutingTable.size()));
+            }
 
-                Response(ShardId shardId, List<NodeGatewayStartedShards> responses, List<FailedNodeException> failures) {
-                    this.shardId = shardId;
-                    this.responses = responses;
-                    this.failures = failures;
+            Iterator<ShardRequestContext> getShardRequestContexts() {
+                try (var shardListeners = new RefCountingListener(1, outerListener.acquire(ignored -> putResults()))) {
+                    final var customDataPath = IndexMetadata.INDEX_DATA_PATH_SETTING.get(
+                        metadata.index(indexRoutingTable.getIndex()).getSettings()
+                    );
+                    final var shardRequestContexts = new ArrayList<ShardRequestContext>(indexRoutingTable.size());
+                    for (int shardNum = 0; shardNum < indexRoutingTable.size(); shardNum++) {
+                        final var indexShardRoutingTable = indexRoutingTable.shard(shardNum);
+                        final var clusterShardHealth = new ClusterShardHealth(shardNum, indexShardRoutingTable);
+                        if (requestedStatuses.contains(clusterShardHealth.getStatus())) {
+                            shardRequestContexts.add(
+                                new ShardRequestContext(
+                                    indexShardRoutingTable.shardId(),
+                                    customDataPath,
+                                    shardListeners.acquire(fetchResponse -> handleFetchResponse(indexShardRoutingTable, fetchResponse))
+                                )
+                            );
+                        }
+                    }
+                    return shardRequestContexts.iterator();
                 }
             }
-        }
 
-        void finish() {
-            Map<String, Map<Integer, List<StoreStatus>>> indicesStatuses = new HashMap<>();
-            List<Failure> failures = new ArrayList<>();
-            for (InternalAsyncFetch.Response fetchResponse : fetchResponses) {
-                var indexName = fetchResponse.shardId.getIndexName();
-                var shardId = fetchResponse.shardId.id();
-                var indexStatuses = indicesStatuses.computeIfAbsent(indexName, k -> new HashMap<>());
-                var storeStatuses = indexStatuses.computeIfAbsent(shardId, k -> new ArrayList<>());
+            private void handleFetchResponse(
+                IndexShardRoutingTable indexShardRoutingTable,
+                TransportNodesListGatewayStartedShards.NodesGatewayStartedShards fetchResponse
+            ) {
+                if (isFailing()) {
+                    return;
+                }
 
-                for (NodeGatewayStartedShards r : fetchResponse.responses) {
-                    if (shardExistsInNode(r)) {
-                        var allocationStatus = getAllocationStatus(indexName, shardId, r.getNode());
-                        storeStatuses.add(new StoreStatus(r.getNode(), r.allocationId(), allocationStatus, r.storeException()));
-                    }
+                final var shardId = indexShardRoutingTable.shardId();
+
+                for (FailedNodeException failure : fetchResponse.failures()) {
+                    failures.add(new Failure(failure.nodeId(), shardId.getIndexName(), shardId.getId(), failure.getCause()));
                 }
 
-                for (FailedNodeException failure : fetchResponse.failures) {
-                    failures.add(new Failure(failure.nodeId(), indexName, shardId, failure.getCause()));
+                final var shardStores = fetchResponse.getNodes()
+                    .stream()
+                    .filter(IndexRequestContext::shardExistsInNode)
+                    .map(
+                        nodeResponse -> new StoreStatus(
+                            nodeResponse.getNode(),
+                            nodeResponse.allocationId(),
+                            getAllocationStatus(indexShardRoutingTable, nodeResponse.getNode()),
+                            nodeResponse.storeException()
+                        )
+                    )
+                    .sorted()
+                    .toList();
+
+                indexResults.put(shardId.getId(), shardStores);
+            }
+
+            private void putResults() {
+                if (isFailing() == false && indexResults.isEmpty() == false) {
+                    indicesStatuses.put(indexRoutingTable.getIndex().getName(), Map.copyOf(indexResults));
                 }
             }
-            // make the status structure immutable
-            indicesStatuses.replaceAll((k, v) -> {
-                v.replaceAll((s, l) -> {
-                    CollectionUtil.timSort(l);
-                    return List.copyOf(l);
-                });
-                return Map.copyOf(v);
-            });
-            listener.onResponse(new IndicesShardStoresResponse(Map.copyOf(indicesStatuses), List.copyOf(failures)));
-        }
 
-        private AllocationStatus getAllocationStatus(String index, int shardID, DiscoveryNode node) {
-            for (ShardRouting shardRouting : routingNodes.node(node.getId())) {
-                ShardId shardId = shardRouting.shardId();
-                if (shardId.id() == shardID && shardId.getIndexName().equals(index)) {
-                    if (shardRouting.primary()) {
-                        return AllocationStatus.PRIMARY;
-                    } else if (shardRouting.assignedToNode()) {
-                        return AllocationStatus.REPLICA;
-                    } else {
-                        return AllocationStatus.UNUSED;
+            /**
+             * A shard exists/existed in a node only if shard state file exists in the node
+             */
+            private static boolean shardExistsInNode(final NodeGatewayStartedShards response) {
+                return response.storeException() != null || response.allocationId() != null;
+            }
+
+            private static AllocationStatus getAllocationStatus(IndexShardRoutingTable indexShardRoutingTable, DiscoveryNode node) {
+                for (final var shardRouting : indexShardRoutingTable.assignedShards()) {
+                    if (node.getId().equals(shardRouting.currentNodeId())) {
+                        return shardRouting.primary() ? AllocationStatus.PRIMARY : AllocationStatus.REPLICA;
                     }
                 }
+                return AllocationStatus.UNUSED;
             }
-            return AllocationStatus.UNUSED;
-        }
-
-        /**
-         * A shard exists/existed in a node only if shard state file exists in the node
-         */
-        private static boolean shardExistsInNode(final NodeGatewayStartedShards response) {
-            return response.storeException() != null || response.allocationId() != null;
         }
     }
 }
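
The fan-out is bounded by `ThrottledIterator.run`, which keeps at most `maxConcurrentShardRequests` shard-level fetches in flight and invokes a completion callback once every item has finished. The following is a simplified, self-contained sketch of that pattern, assuming each item processor calls its completion callback exactly once; it illustrates the idea and is not the actual `ThrottledIterator` implementation:

```java
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Simplified sketch of the throttled fan-out pattern: at most {@code maxConcurrent} items are in
 * flight at once, and {@code onCompletion} runs exactly once after the last item has finished.
 * Illustration only; the production code uses org.elasticsearch.common.util.concurrent.ThrottledIterator.
 */
final class ThrottledFanOutSketch<T> {

    /** Processes one item and must call {@code onItemDone} exactly once, possibly on another thread. */
    interface ItemProcessor<T> {
        void process(T item, Runnable onItemDone);
    }

    private final Iterator<T> items;
    private final ItemProcessor<T> processor;
    private final Runnable onCompletion;
    // One outstanding "permit" for the dispatch loop itself plus one per in-flight item.
    private final AtomicInteger outstanding = new AtomicInteger(1);

    ThrottledFanOutSketch(Iterator<T> items, ItemProcessor<T> processor, Runnable onCompletion) {
        this.items = items;
        this.processor = processor;
        this.onCompletion = onCompletion;
    }

    void run(int maxConcurrent) {
        for (int i = 0; i < maxConcurrent && startNext(); i++) {
            // keep priming work until the concurrency limit is reached or the iterator is drained
        }
        releasePermit(); // the dispatch loop's own permit; fires onCompletion if there was nothing to do
    }

    /** Starts the next item if one remains; returns false once the iterator is exhausted. */
    private boolean startNext() {
        final T item;
        synchronized (items) {
            if (items.hasNext() == false) {
                return false;
            }
            item = items.next();
        }
        outstanding.incrementAndGet(); // one permit per in-flight item
        processor.process(item, () -> {
            startNext();      // pull the next pending item (taking its permit) before...
            releasePermit();  // ...releasing this one, so the count never hits zero while work remains
        });
        return true;
    }

    private void releasePermit() {
        if (outstanding.decrementAndGet() == 0) {
            onCompletion.run();
        }
    }
}
```

Substituting the shard-level store fetch for the item processor and the `RefCountingListener` close for `onCompletion` gives roughly the shape of `AsyncAction#run` above.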

+ 5 - 1
server/src/main/java/org/elasticsearch/action/support/RefCountingListener.java

@@ -104,7 +104,7 @@ public final class RefCountingListener implements Releasable {
     }
 
     /**
-     * Release the original reference to this object, which commpletes the delegate {@link ActionListener} if there are no other references.
+     * Release the original reference to this object, which completes the delegate {@link ActionListener} if there are no other references.
      *
      * It is invalid to call this method more than once. Doing so will trip an assertion if assertions are enabled, but will be ignored
      * otherwise. This deviates from the contract of {@link java.io.Closeable}.
@@ -225,4 +225,8 @@ public final class RefCountingListener implements Releasable {
     public String toString() {
         return "refCounting[" + delegate + "]";
     }
+
+    public boolean isFailing() {
+        return exceptionRef.get() != null;
+    }
 }
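
The new `isFailing()` accessor lets a large fan-out stop scheduling further work once any child listener has failed, which is how `AsyncAction` above avoids issuing shard requests whose results would be discarded. A minimal sketch of the pattern, using a hypothetical helper and sub-task list that are not part of the commit:

```java
import java.util.List;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.RefCountingListener;

class RefCountingSketch {
    // Runs each sub-task under its own acquired listener and completes `done` once all of them
    // (and this dispatch loop) have finished; isFailing() short-circuits the remaining sub-tasks
    // after the first failure. Illustrative only.
    static void runAll(List<Runnable> subTasks, ActionListener<Void> done) {
        try (var refs = new RefCountingListener(1, done)) { // keep at most one exception, as the action above does
            for (Runnable subTask : subTasks) {
                if (refs.isFailing()) {
                    break; // something already failed; skip the rest of the fan-out
                }
                ActionListener.run(refs.acquire(), l -> {
                    subTask.run();
                    l.onResponse(null);
                });
            }
        }
    }
}
```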

+ 7 - 15
server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestIndicesShardStoresAction.java

@@ -10,21 +10,18 @@ package org.elasticsearch.rest.action.admin.indices;
 
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresAction;
 import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresRequest;
-import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
-import org.elasticsearch.rest.RestResponse;
-import org.elasticsearch.rest.action.RestBuilderListener;
-import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.rest.action.RestCancellableNodeClient;
+import org.elasticsearch.rest.action.RestChunkedToXContentListener;
 
 import java.io.IOException;
 import java.util.List;
 
 import static org.elasticsearch.rest.RestRequest.Method.GET;
-import static org.elasticsearch.rest.RestStatus.OK;
 
 /**
  * Rest action for {@link IndicesShardStoresAction}
@@ -54,17 +51,12 @@ public class RestIndicesShardStoresAction extends BaseRestHandler {
         if (request.hasParam("status")) {
             indicesShardStoresRequest.shardStatuses(Strings.splitStringByCommaToArray(request.param("status")));
         }
+        indicesShardStoresRequest.maxConcurrentShardRequests(
+            request.paramAsInt("max_concurrent_shard_requests", indicesShardStoresRequest.maxConcurrentShardRequests())
+        );
         indicesShardStoresRequest.indicesOptions(IndicesOptions.fromRequest(request, indicesShardStoresRequest.indicesOptions()));
-        return channel -> client.admin()
+        return channel -> new RestCancellableNodeClient(client, request.getHttpChannel()).admin()
             .indices()
-            .shardStores(indicesShardStoresRequest, new RestBuilderListener<IndicesShardStoresResponse>(channel) {
-                @Override
-                public RestResponse buildResponse(IndicesShardStoresResponse response, XContentBuilder builder) throws Exception {
-                    builder.startObject();
-                    response.toXContent(builder, request);
-                    builder.endObject();
-                    return new RestResponse(OK, builder);
-                }
-            });
+            .shardStores(indicesShardStoresRequest, new RestChunkedToXContentListener<>(channel));
     }
 }

+ 9 - 3
server/src/test/java/org/elasticsearch/action/admin/indices/shards/IndicesShardStoreResponseTests.java

@@ -16,6 +16,8 @@ import org.elasticsearch.action.admin.indices.shards.IndicesShardStoresResponse.
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.test.AbstractChunkedSerializingTestCase;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.transport.NodeDisconnectedException;
 import org.elasticsearch.xcontent.ToXContent;
@@ -48,10 +50,10 @@ public class IndicesShardStoreResponseTests extends ESTestCase {
         var failures = List.of(new Failure("node1", "test", 3, new NodeDisconnectedException(node1, "")));
         var storesResponse = new IndicesShardStoresResponse(indexStoreStatuses, failures);
 
+        AbstractChunkedSerializingTestCase.assertChunkCount(storesResponse, this::getExpectedChunkCount);
+
         XContentBuilder contentBuilder = XContentFactory.jsonBuilder();
-        contentBuilder.startObject();
-        storesResponse.toXContent(contentBuilder, ToXContent.EMPTY_PARAMS);
-        contentBuilder.endObject();
+        ChunkedToXContent.wrapAsToXContent(storesResponse).toXContent(contentBuilder, ToXContent.EMPTY_PARAMS);
         BytesReference bytes = BytesReference.bytes(contentBuilder);
 
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, bytes)) {
@@ -97,6 +99,10 @@ public class IndicesShardStoreResponseTests extends ESTestCase {
         }
     }
 
+    private int getExpectedChunkCount(IndicesShardStoresResponse response) {
+        return 6 + response.getFailures().size() + response.getStoreStatuses().values().stream().mapToInt(m -> 4 + m.size()).sum();
+    }
+
     public void testStoreStatusOrdering() throws Exception {
         DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
         List<StoreStatus> orderedStoreStatuses = new ArrayList<>();

+ 295 - 0
server/src/test/java/org/elasticsearch/action/admin/indices/shards/TransportIndicesShardStoresActionTests.java

@@ -0,0 +1,295 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.admin.indices.shards;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.FailedNodeException;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.TestShardRoutingRoleStrategies;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.IndexRoutingTable;
+import org.elasticsearch.cluster.routing.RoutingTable;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
+import org.elasticsearch.gateway.TransportNodesListGatewayStartedShards;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.indices.SystemIndices;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.TaskCancelHelper;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.test.ClusterServiceUtils;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.transport.FakeTransport;
+import org.elasticsearch.transport.TransportService;
+
+import java.io.Closeable;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Semaphore;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.action.admin.indices.shards.IndicesShardStoresRequest.DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS;
+import static org.hamcrest.Matchers.anEmptyMap;
+import static org.hamcrest.Matchers.empty;
+
+public class TransportIndicesShardStoresActionTests extends ESTestCase {
+
+    public void testEmpty() {
+        runTest(new TestHarness() {
+            @Override
+            void runTest() {
+                final var request = new IndicesShardStoresRequest();
+                request.shardStatuses("green", "red"); // newly-created shards are in yellow health so this matches none of them
+                final var future = new PlainActionFuture<IndicesShardStoresResponse>();
+                action.execute(
+                    new CancellableTask(1, "transport", IndicesShardStoresAction.NAME, "", TaskId.EMPTY_TASK_ID, Map.of()),
+                    request,
+                    future
+                );
+                assertTrue(future.isDone());
+
+                final var response = future.actionGet(0L);
+                assertThat(response.getFailures(), empty());
+                assertThat(response.getStoreStatuses(), anEmptyMap());
+                assertThat(shardsWithFailures, empty());
+                assertThat(foundShards, empty());
+            }
+        });
+    }
+
+    public void testNonempty() {
+        runTest(new TestHarness() {
+            @Override
+            void runTest() {
+                final var request = new IndicesShardStoresRequest();
+                request.shardStatuses(randomFrom("yellow", "all")); // newly-created shards are in yellow health so this matches all of them
+                final var future = new PlainActionFuture<IndicesShardStoresResponse>();
+                action.execute(
+                    new CancellableTask(1, "transport", IndicesShardStoresAction.NAME, "", TaskId.EMPTY_TASK_ID, Map.of()),
+                    request,
+                    future
+                );
+                assertFalse(future.isDone());
+
+                deterministicTaskQueue.runAllTasks();
+                assertTrue(future.isDone());
+                final var response = future.actionGet();
+
+                assertEquals(
+                    shardsWithFailures,
+                    response.getFailures().stream().map(f -> f.index() + "/" + f.shardId()).collect(Collectors.toSet())
+                );
+
+                for (final var indexRoutingTable : clusterState.routingTable()) {
+                    final var indexResponse = response.getStoreStatuses().get(indexRoutingTable.getIndex().getName());
+                    assertNotNull(indexResponse);
+                    for (int shardNum = 0; shardNum < indexRoutingTable.size(); shardNum++) {
+                        final var shardResponse = indexResponse.get(shardNum);
+                        assertNotNull(shardResponse);
+                        if (foundShards.contains(indexRoutingTable.shard(shardNum).shardId())) {
+                            assertEquals(1, shardResponse.size());
+                            assertSame(localNode, shardResponse.get(0).getNode());
+                        } else {
+                            assertThat(shardResponse, empty());
+                        }
+                    }
+                }
+            }
+        });
+    }
+
+    public void testCancellation() {
+        runTest(new TestHarness() {
+            @Override
+            void runTest() {
+                final var task = new CancellableTask(1, "transport", IndicesShardStoresAction.NAME, "", TaskId.EMPTY_TASK_ID, Map.of());
+                final var request = new IndicesShardStoresRequest();
+                request.shardStatuses(randomFrom("yellow", "all"));
+                final var future = new PlainActionFuture<IndicesShardStoresResponse>();
+                action.execute(task, request, future);
+                TaskCancelHelper.cancel(task, "testing");
+                listExpected = false;
+                assertFalse(future.isDone());
+                deterministicTaskQueue.runAllTasks();
+                assertTrue(future.isDone());
+                expectThrows(TaskCancelledException.class, () -> future.actionGet(0L));
+            }
+        });
+    }
+
+    public void testFailure() {
+        runTest(new TestHarness() {
+            @Override
+            void runTest() {
+                final var request = new IndicesShardStoresRequest();
+                request.shardStatuses(randomFrom("yellow", "all"));
+                final var future = new PlainActionFuture<IndicesShardStoresResponse>();
+                action.execute(
+                    new CancellableTask(1, "transport", IndicesShardStoresAction.NAME, "", TaskId.EMPTY_TASK_ID, Map.of()),
+                    request,
+                    future
+                );
+                assertFalse(future.isDone());
+                failOneRequest = true;
+                deterministicTaskQueue.runAllTasks();
+                assertTrue(future.isDone());
+                assertFalse(failOneRequest);
+                assertEquals("simulated", expectThrows(ElasticsearchException.class, () -> future.actionGet(0L)).getMessage());
+            }
+        });
+    }
+
+    private static void runTest(TestHarness testHarness) {
+        try (testHarness) {
+            testHarness.runTest();
+        }
+    }
+
+    private abstract static class TestHarness implements Closeable {
+        final DeterministicTaskQueue deterministicTaskQueue;
+        final DiscoveryNode localNode;
+        final ClusterState clusterState;
+        final HashSet<String> shardsWithFailures = new HashSet<>();
+        final HashSet<ShardId> foundShards = new HashSet<>();
+        final TransportIndicesShardStoresAction action;
+        final ClusterService clusterService;
+
+        boolean listExpected = true;
+        boolean failOneRequest = false;
+
+        TestHarness() {
+            this.deterministicTaskQueue = new DeterministicTaskQueue();
+            this.localNode = new DiscoveryNode("local", buildNewFakeTransportAddress(), Version.CURRENT);
+
+            final var threadPool = deterministicTaskQueue.getThreadPool();
+
+            final var settings = Settings.EMPTY;
+            final var clusterSettings = ClusterSettings.createBuiltInClusterSettings(settings);
+
+            final var transportService = new TransportService(
+                settings,
+                new FakeTransport(),
+                threadPool,
+                TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+                ignored -> localNode,
+                clusterSettings,
+                Set.of()
+            );
+
+            final var nodes = DiscoveryNodes.builder();
+            nodes.add(localNode).localNodeId(localNode.getId()).masterNodeId(localNode.getId());
+
+            final var indexCount = between(1, 100);
+            final var metadata = Metadata.builder();
+            final var routingTable = RoutingTable.builder();
+            for (int i = 0; i < indexCount; i++) {
+                final var indexMetadata = IndexMetadata.builder("index-" + i)
+                    .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT))
+                    .numberOfShards(between(1, 3))
+                    .numberOfReplicas(between(0, 2))
+                    .build();
+                metadata.put(indexMetadata, false);
+
+                final var irt = IndexRoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY, indexMetadata.getIndex())
+                    .initializeAsNew(indexMetadata);
+                routingTable.add(irt);
+            }
+            this.clusterState = ClusterState.builder(ClusterName.DEFAULT)
+                .nodes(nodes)
+                .metadata(metadata)
+                .routingTable(routingTable)
+                .build();
+
+            this.clusterService = ClusterServiceUtils.createClusterService(clusterState, threadPool, clusterSettings);
+
+            this.action = new TransportIndicesShardStoresAction(
+                transportService,
+                clusterService,
+                threadPool,
+                new ActionFilters(Set.of()),
+                new IndexNameExpressionResolver(threadPool.getThreadContext(), new SystemIndices(List.of())),
+                null
+            ) {
+                private final Semaphore pendingActionPermits = new Semaphore(DEFAULT_MAX_CONCURRENT_SHARD_REQUESTS);
+
+                @Override
+                void listShardStores(
+                    TransportNodesListGatewayStartedShards.Request request,
+                    ActionListener<TransportNodesListGatewayStartedShards.NodesGatewayStartedShards> listener
+                ) {
+                    assertTrue(pendingActionPermits.tryAcquire());
+                    assertTrue(listExpected);
+                    deterministicTaskQueue.scheduleNow(() -> {
+                        pendingActionPermits.release();
+
+                        final var simulateNodeFailure = rarely();
+                        if (simulateNodeFailure) {
+                            assertTrue(shardsWithFailures.add(request.shardId().getIndexName() + "/" + request.shardId().getId()));
+                        }
+
+                        final var foundShardStore = rarely();
+                        if (foundShardStore) {
+                            assertTrue(foundShards.add(request.shardId()));
+                        }
+
+                        if (failOneRequest) {
+                            failOneRequest = false;
+                            listener.onFailure(new ElasticsearchException("simulated"));
+                        } else {
+                            listener.onResponse(
+                                new TransportNodesListGatewayStartedShards.NodesGatewayStartedShards(
+                                    clusterService.getClusterName(),
+                                    foundShardStore
+                                        ? List.of(
+                                            new TransportNodesListGatewayStartedShards.NodeGatewayStartedShards(
+                                                localNode,
+                                                randomAlphaOfLength(10),
+                                                randomBoolean()
+                                            )
+                                        )
+                                        : List.of(),
+                                    simulateNodeFailure
+                                        ? List.of(
+                                            new FailedNodeException(
+                                                randomAlphaOfLength(10),
+                                                "test failure",
+                                                new ElasticsearchException("simulated")
+                                            )
+                                        )
+                                        : List.of()
+                                )
+                            );
+                        }
+                    });
+                }
+            };
+        }
+
+        abstract void runTest();
+
+        @Override
+        public void close() {
+            clusterService.close();
+        }
+    }
+}

+ 6 - 0
server/src/test/java/org/elasticsearch/action/support/RefCountingListenerTests.java

@@ -78,11 +78,13 @@ public class RefCountingListenerTests extends ESTestCase {
 
             var listener = refs.acquire();
             assertThat(listener.toString(), containsString("refCounting[test listener]"));
+            assertFalse(refs.isFailing());
             if (randomBoolean()) {
                 listener.onResponse(null);
             } else {
                 listener.onFailure(new ElasticsearchException("simulated"));
                 exceptionCount.incrementAndGet();
+                assertTrue(refs.isFailing());
             }
 
             var reachChecker = new ReachabilityChecker();
@@ -109,6 +111,9 @@ public class RefCountingListenerTests extends ESTestCase {
                 assertFalse(consumed.get());
                 exceptionCount.incrementAndGet();
             }
+
+            assertEquals(exceptionCount.get() > 0, refs.isFailing());
+
             reachChecker.ensureUnreachable();
             assertThat(consumingListener.toString(), containsString("refCounting[test listener][null]"));
 
@@ -133,6 +138,7 @@ public class RefCountingListenerTests extends ESTestCase {
                 }
             }
 
+            assertEquals(exceptionCount.get() > 0, refs.isFailing());
             assertFalse(executed.get());
         }