Browse Source

Deduplicate ShardRouting instances when building ClusterInfo (#89190)

The equality checks on these in `DiskThresholdDecider` become very expensive
during reroute in a large cluster. Deduplicating these when building the `ClusterInfo`
saves more than 2% CPU time during many-shards benchmark bootstrapping because
the lookup of the shard data path by shard routing now mostly hits instance equality.
Also, this saves a little memory.

This PR also moves the callback that builds `ClusterInfo` from the stats response onto
the management pool, as the callback is now more expensive (though its overall CPU use is trivial
relative to the cost savings during reroute) and was questionable to run on
a transport thread in a large cluster to begin with.

Co-authored-by: David Turner <david.turner@elastic.co>
Armin Braun 3 năm trước cách đây
mục cha
commit
c6c05bb625

+ 88 - 67
server/src/main/java/org/elasticsearch/cluster/InternalClusterInfoService.java

@@ -18,12 +18,13 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsRequest;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
-import org.elasticsearch.action.support.DefaultShardOperationFailedException;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.routing.RoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.allocation.DiskThresholdSettings;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -93,6 +94,8 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt
 
     private final Object mutex = new Object();
     private final List<ActionListener<ClusterInfo>> nextRefreshListeners = new ArrayList<>();
+
+    private final ClusterService clusterService;
     private AsyncRefresh currentRefresh;
     private RefreshScheduler refreshScheduler;
 
@@ -102,6 +105,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt
         this.indicesStatsSummary = IndicesStatsSummary.EMPTY;
         this.threadPool = threadPool;
         this.client = client;
+        this.clusterService = clusterService;
         this.updateFrequency = INTERNAL_CLUSTER_INFO_UPDATE_INTERVAL_SETTING.get(settings);
         this.fetchTimeout = INTERNAL_CLUSTER_INFO_TIMEOUT_SETTING.get(settings);
         this.enabled = DiskThresholdSettings.CLUSTER_ROUTING_ALLOCATION_DISK_THRESHOLD_ENABLED_SETTING.get(settings);
@@ -191,76 +195,92 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt
             indicesStatsRequest.store(true);
             indicesStatsRequest.indicesOptions(IndicesOptions.STRICT_EXPAND_OPEN_CLOSED_HIDDEN);
             indicesStatsRequest.timeout(fetchTimeout);
-            client.admin().indices().stats(indicesStatsRequest, ActionListener.runAfter(new ActionListener<>() {
-                @Override
-                public void onResponse(IndicesStatsResponse indicesStatsResponse) {
-                    logger.trace("received indices stats response");
-
-                    if (indicesStatsResponse.getShardFailures().length > 0) {
-                        final Set<String> failedNodeIds = new HashSet<>();
-                        for (final DefaultShardOperationFailedException shardFailure : indicesStatsResponse.getShardFailures()) {
-                            if (shardFailure.getCause()instanceof final FailedNodeException failedNodeException) {
-                                if (failedNodeIds.add(failedNodeException.nodeId())) {
-                                    logger.warn(
-                                        () -> format("failed to retrieve shard stats from node [%s]", failedNodeException.nodeId()),
-                                        failedNodeException.getCause()
-                                    );
+            client.admin()
+                .indices()
+                .stats(
+                    indicesStatsRequest,
+                    new ThreadedActionListener<>(
+                        logger,
+                        threadPool,
+                        ThreadPool.Names.MANAGEMENT,
+                        ActionListener.runAfter(new ActionListener<>() {
+                            @Override
+                            public void onResponse(IndicesStatsResponse indicesStatsResponse) {
+                                logger.trace("received indices stats response");
+
+                                if (indicesStatsResponse.getShardFailures().length > 0) {
+                                    final Set<String> failedNodeIds = new HashSet<>();
+                                    for (final var shardFailure : indicesStatsResponse.getShardFailures()) {
+                                        if (shardFailure.getCause()instanceof final FailedNodeException failedNodeException) {
+                                            if (failedNodeIds.add(failedNodeException.nodeId())) {
+                                                logger.warn(
+                                                    () -> format(
+                                                        "failed to retrieve shard stats from node [%s]",
+                                                        failedNodeException.nodeId()
+                                                    ),
+                                                    failedNodeException.getCause()
+                                                );
+                                            }
+                                            logger.trace(
+                                                () -> format(
+                                                    "failed to retrieve stats for shard [%s][%s]",
+                                                    shardFailure.index(),
+                                                    shardFailure.shardId()
+                                                ),
+                                                shardFailure.getCause()
+                                            );
+                                        } else {
+                                            logger.warn(
+                                                () -> format(
+                                                    "failed to retrieve stats for shard [%s][%s]",
+                                                    shardFailure.index(),
+                                                    shardFailure.shardId()
+                                                ),
+                                                shardFailure.getCause()
+                                            );
+                                        }
+                                    }
                                 }
-                                logger.trace(
-                                    () -> format(
-                                        "failed to retrieve stats for shard [%s][%s]",
-                                        shardFailure.index(),
-                                        shardFailure.shardId()
-                                    ),
-                                    shardFailure.getCause()
-                                );
-                            } else {
-                                logger.warn(
-                                    () -> format(
-                                        "failed to retrieve stats for shard [%s][%s]",
-                                        shardFailure.index(),
-                                        shardFailure.shardId()
-                                    ),
-                                    shardFailure.getCause()
-                                );
-                            }
-                        }
-                    }
 
-                    final ShardStats[] stats = indicesStatsResponse.getShards();
-                    final Map<String, Long> shardSizeByIdentifierBuilder = new HashMap<>();
-                    final Map<ShardId, Long> shardDataSetSizeBuilder = new HashMap<>();
-                    final Map<ShardRouting, String> dataPathByShardRoutingBuilder = new HashMap<>();
-                    final Map<ClusterInfo.NodeAndPath, ClusterInfo.ReservedSpace.Builder> reservedSpaceBuilders = new HashMap<>();
-                    buildShardLevelInfo(
-                        stats,
-                        shardSizeByIdentifierBuilder,
-                        shardDataSetSizeBuilder,
-                        dataPathByShardRoutingBuilder,
-                        reservedSpaceBuilders
-                    );
+                                final ShardStats[] stats = indicesStatsResponse.getShards();
+                                final Map<String, Long> shardSizeByIdentifierBuilder = new HashMap<>();
+                                final Map<ShardId, Long> shardDataSetSizeBuilder = new HashMap<>();
+                                final Map<ShardRouting, String> dataPathByShardRoutingBuilder = new HashMap<>();
+                                final Map<ClusterInfo.NodeAndPath, ClusterInfo.ReservedSpace.Builder> reservedSpaceBuilders =
+                                    new HashMap<>();
+                                buildShardLevelInfo(
+                                    clusterService.state().routingTable(),
+                                    stats,
+                                    shardSizeByIdentifierBuilder,
+                                    shardDataSetSizeBuilder,
+                                    dataPathByShardRoutingBuilder,
+                                    reservedSpaceBuilders
+                                );
 
-                    final Map<ClusterInfo.NodeAndPath, ClusterInfo.ReservedSpace> rsrvdSpace = new HashMap<>();
-                    reservedSpaceBuilders.forEach((nodeAndPath, builder) -> rsrvdSpace.put(nodeAndPath, builder.build()));
+                                final Map<ClusterInfo.NodeAndPath, ClusterInfo.ReservedSpace> rsrvdSpace = new HashMap<>();
+                                reservedSpaceBuilders.forEach((nodeAndPath, builder) -> rsrvdSpace.put(nodeAndPath, builder.build()));
 
-                    indicesStatsSummary = new IndicesStatsSummary(
-                        Map.copyOf(shardSizeByIdentifierBuilder),
-                        Map.copyOf(shardDataSetSizeBuilder),
-                        Map.copyOf(dataPathByShardRoutingBuilder),
-                        Map.copyOf(rsrvdSpace)
-                    );
-                }
+                                indicesStatsSummary = new IndicesStatsSummary(
+                                    Map.copyOf(shardSizeByIdentifierBuilder),
+                                    Map.copyOf(shardDataSetSizeBuilder),
+                                    Map.copyOf(dataPathByShardRoutingBuilder),
+                                    Map.copyOf(rsrvdSpace)
+                                );
+                            }
 
-                @Override
-                public void onFailure(Exception e) {
-                    if (e instanceof ClusterBlockException) {
-                        logger.trace("failed to retrieve indices stats", e);
-                    } else {
-                        logger.warn("failed to retrieve indices stats", e);
-                    }
-                    indicesStatsSummary = IndicesStatsSummary.EMPTY;
-                }
-            }, this::onStatsProcessed));
+                            @Override
+                            public void onFailure(Exception e) {
+                                if (e instanceof ClusterBlockException) {
+                                    logger.trace("failed to retrieve indices stats", e);
+                                } else {
+                                    logger.warn("failed to retrieve indices stats", e);
+                                }
+                                indicesStatsSummary = IndicesStatsSummary.EMPTY;
+                            }
+                        }, this::onStatsProcessed),
+                        false
+                    )
+                );
         }
 
         private void fetchNodeStats() {
@@ -426,6 +446,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt
     }
 
     static void buildShardLevelInfo(
+        RoutingTable routingTable,
         ShardStats[] stats,
         Map<String, Long> shardSizes,
         Map<ShardId, Long> shardDataSetSizeBuilder,
@@ -433,7 +454,7 @@ public class InternalClusterInfoService implements ClusterInfoService, ClusterSt
         Map<ClusterInfo.NodeAndPath, ClusterInfo.ReservedSpace.Builder> reservedSpaceByShard
     ) {
         for (ShardStats s : stats) {
-            final ShardRouting shardRouting = s.getShardRouting();
+            final ShardRouting shardRouting = routingTable.deduplicate(s.getShardRouting());
             newShardRoutingToDataPath.put(shardRouting, s.getDataPath());
 
             final StoreStats storeStats = s.getStats().getStore();

+ 27 - 0
server/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java

@@ -147,6 +147,33 @@ public class RoutingTable implements Iterable<IndexRoutingTable>, Diffable<Routi
         return shard;
     }
 
+    /**
+     * Try to deduplicate the given shard routing with an equal instance found in this routing table. This is used by the logic of the
+     * {@link org.elasticsearch.cluster.routing.allocation.decider.DiskThresholdDecider} and
+     * {@link org.elasticsearch.cluster.InternalClusterInfoService} to deduplicate instances created by a master node and those read from
+     * the network to speed up the use of {@link ShardRouting} as a map key in {@link org.elasticsearch.cluster.ClusterInfo#getDataPath}.
+     *
+     * @param shardRouting shard routing to deduplicate
+     * @return deduplicated shard routing from this routing table if an equivalent shard routing was found or the given instance otherwise
+     */
+    public ShardRouting deduplicate(ShardRouting shardRouting) {
+        final IndexRoutingTable indexShardRoutingTable = indicesRouting.get(shardRouting.index().getName()); // per-index table, despite the name
+        if (indexShardRoutingTable == null) {
+            return shardRouting; // index is not in this routing table: nothing to deduplicate against
+        }
+        final IndexShardRoutingTable shardRoutingTable = indexShardRoutingTable.shard(shardRouting.id()); // entries for this shard id
+        if (shardRoutingTable == null) {
+            return shardRouting; // shard id is not in the index's routing table: keep the caller's instance
+        }
+        for (int i = 0; i < shardRoutingTable.size(); i++) { // scan every routing entry held for this shard id
+            ShardRouting found = shardRoutingTable.shard(i);
+            if (shardRouting.equals(found)) { // match via equals(), then hand back the table's own instance
+                return found;
+            }
+        }
+        return shardRouting; // no equal entry found: fall back to the given instance
+    }
+
     @Nullable
     public ShardRouting getByAllocationId(ShardId shardId, String allocationId) {
         final IndexRoutingTable indexRoutingTable = index(shardId.getIndex());

+ 1 - 1
server/src/main/java/org/elasticsearch/cluster/routing/ShardRouting.java

@@ -671,7 +671,7 @@ public final class ShardRouting implements Writeable, ToXContentObject {
     /** returns true if the current routing is identical to the other routing in all but meta fields, i.e., unassigned info */
     public boolean equalsIgnoringMetadata(ShardRouting other) {
         return primary == other.primary
-            && Objects.equals(shardId, other.shardId)
+            && shardId.equals(other.shardId)
             && Objects.equals(currentNodeId, other.currentNodeId)
             && Objects.equals(relocatingNodeId, other.relocatingNodeId)
             && Objects.equals(allocationId, other.allocationId)

+ 0 - 8
server/src/main/java/org/elasticsearch/cluster/routing/allocation/decider/DiskThresholdDecider.java

@@ -605,14 +605,6 @@ public class DiskThresholdDecider extends AllocationDecider {
             return diskUsage.getPath();
         }
 
-        String getNodeId() {
-            return diskUsage.getNodeId();
-        }
-
-        String getNodeName() {
-            return diskUsage.getNodeName();
-        }
-
         long getTotalBytes() {
             return diskUsage.getTotalBytes();
         }

+ 9 - 1
server/src/test/java/org/elasticsearch/cluster/DiskUsageTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.admin.indices.stats.CommonStats;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.routing.RecoverySource.PeerRecoverySource;
+import org.elasticsearch.cluster.routing.RoutingTable;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingHelper;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
@@ -124,7 +125,14 @@ public class DiskUsageTests extends ESTestCase {
         Map<String, Long> shardSizes = new HashMap<>();
         Map<ShardId, Long> shardDataSetSizes = new HashMap<>();
         Map<ShardRouting, String> routingToPath = new HashMap<>();
-        InternalClusterInfoService.buildShardLevelInfo(stats, shardSizes, shardDataSetSizes, routingToPath, new HashMap<>());
+        InternalClusterInfoService.buildShardLevelInfo(
+            RoutingTable.EMPTY_ROUTING_TABLE,
+            stats,
+            shardSizes,
+            shardDataSetSizes,
+            routingToPath,
+            new HashMap<>()
+        );
         assertEquals(2, shardSizes.size());
         assertTrue(shardSizes.containsKey(ClusterInfo.shardIdentifierFromRouting(test_0)));
         assertTrue(shardSizes.containsKey(ClusterInfo.shardIdentifierFromRouting(test_1)));

+ 1 - 0
server/src/test/java/org/elasticsearch/cluster/InternalClusterInfoServiceSchedulingTests.java

@@ -105,6 +105,7 @@ public class InternalClusterInfoServiceSchedulingTests extends ESTestCase {
             setFlagOnSuccess(becameMaster2)
         );
         runUntilFlag(deterministicTaskQueue, becameMaster2);
+        deterministicTaskQueue.runAllRunnableTasks();
 
         for (int i = 0; i < 3; i++) {
             final int initialRequestCount = client.requestCount;