Use CacheService Persisted Cache Size during Searchable Snapshot Shard Allocation (#66237)

Searchable snapshot allocator that reaches out to all data nodes to get the cached size for a shard, similar to how it is done for normal shard `Store`s, but simpler: since we only care about the exact byte size for now, we do not inject the size into the disk threshold allocator and we leave out a few more tricks (see TODOs) that we use for normal allocation.
Armin Braun committed 4 years ago
commit 7caa471831
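
At a high level, the change gives the allocator the same fetch-then-decide loop used by the gateway allocators: ask every data node for its persisted cache size for the shard, report FETCHING_SHARD_DATA while responses are in flight, then prefer the node that already holds the most cached bytes. A minimal sketch of that loop, simplified from the full implementation in the diff below (explain mode, throttling, and the pre-flight allocation-decider check omitted):

    // Simplified sketch of SearchableSnapshotAllocator#decideAllocation from this commit.
    AllocateUnassignedDecision decide(ShardRouting shard, RoutingAllocation allocation) {
        // Asynchronously ask every data node for its persisted cache size for this shard.
        AsyncShardFetch.FetchResult<NodeCacheFilesMetadata> data = fetchData(shard, allocation);
        if (data.hasData() == false) {
            // Responses are still in flight; a reroute is triggered once they arrive.
            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, null);
        }
        // Prefer the node that already holds the most cached bytes for the shard.
        DiscoveryNode best = findMatchingNodes(shard, allocation, data, false).getNodeWithHighestMatch();
        return best == null
            ? AllocateUnassignedDecision.NOT_TAKEN // no cached bytes anywhere: defer to the balanced allocator
            : AllocateUnassignedDecision.yes(best, null, null, true);
    }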

+ 1 - 1
server/src/main/java/org/elasticsearch/action/admin/cluster/reroute/ClusterRerouteResponse.java

@@ -44,7 +44,7 @@ public class ClusterRerouteResponse extends AcknowledgedResponse implements ToXC
         explanations = RoutingExplanations.readFrom(in);
     }
 
-    ClusterRerouteResponse(boolean acknowledged, ClusterState state, RoutingExplanations explanations) {
+    public ClusterRerouteResponse(boolean acknowledged, ClusterState state, RoutingExplanations explanations) {
         super(acknowledged);
         this.state = state;
         this.explanations = explanations;

+ 79 - 0
x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocationIntegTests.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.searchablesnapshots;
+
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xpack.searchablesnapshots.cache.CacheService;
+
+import java.util.List;
+
+import static org.elasticsearch.index.IndexSettings.INDEX_SOFT_DELETES_SETTING;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
+public class SearchableSnapshotAllocationIntegTests extends BaseSearchableSnapshotsIntegTestCase {
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal) {
+        return Settings.builder()
+            .put(super.nodeSettings(nodeOrdinal))
+            // ensure the cache is definitely used
+            .put(CacheService.SNAPSHOT_CACHE_SIZE_SETTING.getKey(), new ByteSizeValue(1L, ByteSizeUnit.GB))
+            .build();
+    }
+
+    public void testAllocatesToBestAvailableNodeOnRestart() throws Exception {
+        internalCluster().startMasterOnlyNode();
+        final String firstDataNode = internalCluster().startDataOnlyNode();
+        final String index = "test-idx";
+        createIndexWithContent(index, indexSettingsNoReplicas(1).put(INDEX_SOFT_DELETES_SETTING.getKey(), true).build());
+        final String repoName = "test-repo";
+        createRepository(repoName, "fs");
+        final String snapshotName = "test-snapshot";
+        createSnapshot(repoName, snapshotName, List.of(index));
+        assertAcked(client().admin().indices().prepareDelete(index));
+        final String restoredIndex = mountSnapshot(repoName, snapshotName, index, Settings.EMPTY);
+        ensureGreen(restoredIndex);
+        internalCluster().startDataOnlyNodes(randomIntBetween(1, 4));
+
+        setAllocation(EnableAllocationDecider.Allocation.NONE);
+
+        final CacheService cacheService = internalCluster().getInstance(CacheService.class, firstDataNode);
+        cacheService.synchronizeCache();
+        internalCluster().restartNode(firstDataNode);
+        ensureStableCluster(internalCluster().numDataAndMasterNodes());
+
+        setAllocation(EnableAllocationDecider.Allocation.ALL);
+        ensureGreen(restoredIndex);
+
+        final ClusterState state = client().admin().cluster().prepareState().get().getState();
+        assertEquals(
+            state.nodes().resolveNode(firstDataNode).getId(),
+            state.routingTable().index(restoredIndex).shard(0).primaryShard().currentNodeId()
+        );
+    }
+
+    private void setAllocation(EnableAllocationDecider.Allocation allocation) {
+        logger.info("--> setting allocation to [{}]", allocation);
+        assertAcked(
+            client().admin()
+                .cluster()
+                .prepareUpdateSettings()
+                .setTransientSettings(
+                    Settings.builder()
+                        .put(EnableAllocationDecider.CLUSTER_ROUTING_ALLOCATION_ENABLE_SETTING.getKey(), allocation.name())
+                        .build()
+                )
+                .get()
+        );
+    }
+}

+ 343 - 8
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocator.java

@@ -5,21 +5,49 @@
  */
 package org.elasticsearch.xpack.searchablesnapshots;
 
+import com.carrotsearch.hppc.cursors.ObjectCursor;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.admin.cluster.reroute.ClusterRerouteResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RoutingNode;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.cluster.routing.allocation.AllocateUnassignedDecision;
 import org.elasticsearch.cluster.routing.allocation.AllocationDecision;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
 import org.elasticsearch.cluster.routing.allocation.FailedShard;
+import org.elasticsearch.cluster.routing.allocation.NodeAllocationResult;
 import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
+import org.elasticsearch.cluster.routing.allocation.decider.Decision;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+import org.elasticsearch.gateway.AsyncShardFetch;
+import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.repositories.IndexId;
 import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.xpack.searchablesnapshots.action.cache.TransportSearchableSnapshotCacheStoresAction;
+import org.elasticsearch.xpack.searchablesnapshots.action.cache.TransportSearchableSnapshotCacheStoresAction.NodeCacheFilesMetadata;
+import org.elasticsearch.xpack.searchablesnapshots.action.cache.TransportSearchableSnapshotCacheStoresAction.NodesCacheFilesMetadata;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentMap;
 
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_INDEX_ID_SETTING;
 import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SNAPSHOT_INDEX_NAME_SETTING;
@@ -29,8 +57,30 @@ import static org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots.SN
 
 public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
 
+    private static final Logger logger = LogManager.getLogger(SearchableSnapshotAllocator.class);
+
+    private static final ActionListener<ClusterRerouteResponse> REROUTE_LISTENER = new ActionListener<>() {
+        @Override
+        public void onResponse(ClusterRerouteResponse clusterRerouteResponse) {
+            logger.trace("reroute succeeded after loading snapshot cache information");
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            logger.warn("reroute failed", e);
+        }
+    };
+
+    private final ConcurrentMap<ShardId, AsyncCacheStatusFetch> asyncFetchStore = ConcurrentCollections.newConcurrentMap();
+
     public static final String ALLOCATOR_NAME = "searchable_snapshot_allocator";
 
+    private final Client client;
+
+    public SearchableSnapshotAllocator(Client client) {
+        this.client = client;
+    }
+
     @Override
     public void beforeAllocation(RoutingAllocation allocation) {}
 
@@ -43,6 +93,7 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
         RoutingAllocation allocation,
         UnassignedAllocationHandler unassignedAllocationHandler
     ) {
+        // TODO: cancel and jump to better available allocations?
         if (shardRouting.primary()
             && (shardRouting.recoverySource().getType() == RecoverySource.Type.EXISTING_STORE
                 || shardRouting.recoverySource().getType() == RecoverySource.Type.EMPTY_STORE)) {
@@ -74,8 +125,17 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
 
         final AllocateUnassignedDecision allocateUnassignedDecision = decideAllocation(allocation, shardRouting);
 
-        if (allocateUnassignedDecision.isDecisionTaken() && allocateUnassignedDecision.getAllocationDecision() != AllocationDecision.YES) {
-            unassignedAllocationHandler.removeAndIgnore(allocateUnassignedDecision.getAllocationStatus(), allocation.changes());
+        if (allocateUnassignedDecision.isDecisionTaken()) {
+            if (allocateUnassignedDecision.getAllocationDecision() == AllocationDecision.YES) {
+                unassignedAllocationHandler.initialize(
+                    allocateUnassignedDecision.getTargetNode().getId(),
+                    allocateUnassignedDecision.getAllocationId(),
+                    allocation.snapshotShardSizeInfo().getShardSize(shardRouting, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE),
+                    allocation.changes()
+                );
+            } else {
+                unassignedAllocationHandler.removeAndIgnore(allocateUnassignedDecision.getAllocationStatus(), allocation.changes());
+            }
         }
     }
 
@@ -90,8 +150,60 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
             return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, null);
         }
 
-        // let BalancedShardsAllocator take care of allocating this shard
-        // TODO: once we have persistent cache, choose a node that has existing data
+        final AsyncShardFetch.FetchResult<NodeCacheFilesMetadata> fetchedCacheData = fetchData(shardRouting, allocation);
+        if (fetchedCacheData.hasData() == false) {
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.FETCHING_SHARD_DATA, null);
+        }
+
+        final boolean explain = allocation.debugDecision();
+        final MatchingNodes matchingNodes = findMatchingNodes(shardRouting, allocation, fetchedCacheData, explain);
+        assert explain == false || matchingNodes.nodeDecisions != null : "in explain mode, we must have individual node decisions";
+
+        // pre-check if the shard can be allocated to any node at all, so we don't go through the fetched cache sizes for nothing
+        // TODO: in the following logic, we do not account for existing cache size when handling disk space checks, should and can we
+        // reliably do this in a world of concurrent cache evictions or are we ok with the cache size just being a best effort hint
+        // here?
+        Tuple<Decision, Map<String, NodeAllocationResult>> result = canBeAllocatedToAtLeastOneNode(shardRouting, allocation);
+        Decision allocateDecision = result.v1();
+        if (allocateDecision.type() != Decision.Type.YES && (explain == false || asyncFetchStore.get(shardRouting.shardId()) == null)) {
+            // only return early if we are not in explain mode, or we are in explain mode but we have not
+            // yet attempted to fetch any shard data
+            logger.trace("{}: ignoring allocation, can't be allocated on any node", shardRouting);
+            return AllocateUnassignedDecision.no(
+                UnassignedInfo.AllocationStatus.fromDecision(allocateDecision.type()),
+                result.v2() != null ? new ArrayList<>(result.v2().values()) : null
+            );
+        }
+
+        List<NodeAllocationResult> nodeDecisions = augmentExplanationsWithStoreInfo(result.v2(), matchingNodes.nodeDecisions);
+        if (allocateDecision.type() != Decision.Type.YES) {
+            return AllocateUnassignedDecision.no(UnassignedInfo.AllocationStatus.fromDecision(allocateDecision.type()), nodeDecisions);
+        } else if (matchingNodes.getNodeWithHighestMatch() != null) {
+            RoutingNode nodeWithHighestMatch = allocation.routingNodes().node(matchingNodes.getNodeWithHighestMatch().getId());
+            // we only check on THROTTLE since we checked before on NO
+            Decision decision = allocation.deciders().canAllocate(shardRouting, nodeWithHighestMatch, allocation);
+            if (decision.type() == Decision.Type.THROTTLE) {
+                // TODO: does this make sense? Unlike with the store we could evict the cache concurrently and wait for nothing?
+                logger.debug(
+                    "[{}][{}]: throttling allocation [{}] to [{}] in order to reuse its unallocated persistent cache",
+                    shardRouting.index(),
+                    shardRouting.id(),
+                    shardRouting,
+                    nodeWithHighestMatch.node()
+                );
+                return AllocateUnassignedDecision.throttle(nodeDecisions);
+            } else {
+                logger.debug(
+                    "[{}][{}]: allocating [{}] to [{}] in order to reuse its persistent cache",
+                    shardRouting.index(),
+                    shardRouting.id(),
+                    shardRouting,
+                    nodeWithHighestMatch.node()
+                );
+                return AllocateUnassignedDecision.yes(nodeWithHighestMatch.node(), null, nodeDecisions, true);
+            }
+        }
+        // TODO: do we need handling of delayed allocation for leaving replicas here?
         return AllocateUnassignedDecision.NOT_TAKEN;
     }
 
@@ -103,16 +215,239 @@ public class SearchableSnapshotAllocator implements ExistingShardsAllocator {
     }
 
     @Override
-    public void cleanCaches() {}
+    public void cleanCaches() {
+        asyncFetchStore.clear();
+    }
 
     @Override
-    public void applyStartedShards(List<ShardRouting> startedShards, RoutingAllocation allocation) {}
+    public void applyStartedShards(List<ShardRouting> startedShards, RoutingAllocation allocation) {
+        for (ShardRouting startedShard : startedShards) {
+            asyncFetchStore.remove(startedShard.shardId());
+        }
+    }
 
     @Override
-    public void applyFailedShards(List<FailedShard> failedShards, RoutingAllocation allocation) {}
+    public void applyFailedShards(List<FailedShard> failedShards, RoutingAllocation allocation) {
+        for (FailedShard failedShard : failedShards) {
+            asyncFetchStore.remove(failedShard.getRoutingEntry().shardId());
+        }
+    }
 
     @Override
     public int getNumberOfInFlightFetches() {
-        return 0;
+        int count = 0;
+        for (AsyncCacheStatusFetch fetch : asyncFetchStore.values()) {
+            count += fetch.numberOfInFlightFetches();
+        }
+        return count;
+    }
+
+    private AsyncShardFetch.FetchResult<NodeCacheFilesMetadata> fetchData(ShardRouting shard, RoutingAllocation allocation) {
+        final ShardId shardId = shard.shardId();
+        final Settings indexSettings = allocation.metadata().index(shard.index()).getSettings();
+        final SnapshotId snapshotId = new SnapshotId(
+            SNAPSHOT_SNAPSHOT_NAME_SETTING.get(indexSettings),
+            SNAPSHOT_SNAPSHOT_ID_SETTING.get(indexSettings)
+        );
+        final AsyncCacheStatusFetch asyncFetch = asyncFetchStore.computeIfAbsent(shardId, sid -> new AsyncCacheStatusFetch());
+        final DiscoveryNodes nodes = allocation.nodes();
+        final DiscoveryNode[] dataNodes = asyncFetch.addFetches(nodes.getDataNodes().values().toArray(DiscoveryNode.class));
+        if (dataNodes.length > 0) {
+            client.execute(
+                TransportSearchableSnapshotCacheStoresAction.TYPE,
+                new TransportSearchableSnapshotCacheStoresAction.Request(snapshotId, shardId, dataNodes),
+                ActionListener.runAfter(new ActionListener<>() {
+                    @Override
+                    public void onResponse(NodesCacheFilesMetadata nodesCacheFilesMetadata) {
+                        final Map<DiscoveryNode, NodeCacheFilesMetadata> res = new HashMap<>(nodesCacheFilesMetadata.getNodesMap().size());
+                        for (Map.Entry<String, NodeCacheFilesMetadata> entry : nodesCacheFilesMetadata.getNodesMap().entrySet()) {
+                            res.put(nodes.get(entry.getKey()), entry.getValue());
+                        }
+                        asyncFetch.addData(res);
+                    }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        logger.warn("Failure when trying to fetch existing cache sizes", e);
+                        final Map<DiscoveryNode, NodeCacheFilesMetadata> res = new HashMap<>(dataNodes.length);
+                        for (DiscoveryNode dataNode : dataNodes) {
+                            res.put(dataNode, new NodeCacheFilesMetadata(dataNode, 0L));
+                        }
+                        asyncFetch.addData(res);
+                    }
+                }, () -> client.admin().cluster().prepareReroute().execute(REROUTE_LISTENER))
+            );
+        }
+        return new AsyncShardFetch.FetchResult<>(shardId, asyncFetch.data(), Collections.emptySet());
+    }
+
+    /**
+     * Takes the store info for nodes that have a shard store and adds them to the node decisions,
+     * leaving the node explanations untouched for those nodes that do not have any store information.
+     */
+    private static List<NodeAllocationResult> augmentExplanationsWithStoreInfo(
+        Map<String, NodeAllocationResult> nodeDecisions,
+        Map<String, NodeAllocationResult> withShardStores
+    ) {
+        if (nodeDecisions == null || withShardStores == null) {
+            return null;
+        }
+        List<NodeAllocationResult> augmented = new ArrayList<>();
+        for (Map.Entry<String, NodeAllocationResult> entry : nodeDecisions.entrySet()) {
+            if (withShardStores.containsKey(entry.getKey())) {
+                augmented.add(withShardStores.get(entry.getKey()));
+            } else {
+                augmented.add(entry.getValue());
+            }
+        }
+        return augmented;
+    }
+
+    /**
+     * Determines if the shard can be allocated on at least one node based on the allocation deciders.
+     *
+     * Returns the best allocation decision for allocating the shard on any node (i.e. YES if at least one
+     * node decided YES, THROTTLE if at least one node decided THROTTLE, and NO if none of the nodes decided
+     * YES or THROTTLE).  If in explain mode, also returns the node-level explanations as the second element
+     * in the returned tuple.
+     * TODO: dry this method up against ReplicaShardAllocator
+     */
+    private static Tuple<Decision, Map<String, NodeAllocationResult>> canBeAllocatedToAtLeastOneNode(
+        ShardRouting shard,
+        RoutingAllocation allocation
+    ) {
+        Decision madeDecision = Decision.NO;
+        final boolean explain = allocation.debugDecision();
+        Map<String, NodeAllocationResult> nodeDecisions = explain ? new HashMap<>() : null;
+        for (ObjectCursor<DiscoveryNode> cursor : allocation.nodes().getDataNodes().values()) {
+            RoutingNode node = allocation.routingNodes().node(cursor.value.getId());
+            if (node == null) {
+                continue;
+            }
+            // if we can't allocate it on a node, ignore it
+            Decision decision = allocation.deciders().canAllocate(shard, node, allocation);
+            if (decision.type() == Decision.Type.YES && madeDecision.type() != Decision.Type.YES) {
+                if (explain) {
+                    madeDecision = decision;
+                } else {
+                    return Tuple.tuple(decision, null);
+                }
+            } else if (madeDecision.type() == Decision.Type.NO && decision.type() == Decision.Type.THROTTLE) {
+                madeDecision = decision;
+            }
+            if (explain) {
+                nodeDecisions.put(node.nodeId(), new NodeAllocationResult(node.node(), null, decision));
+            }
+        }
+        return Tuple.tuple(madeDecision, nodeDecisions);
+    }
+
+    private MatchingNodes findMatchingNodes(
+        ShardRouting shard,
+        RoutingAllocation allocation,
+        AsyncShardFetch.FetchResult<NodeCacheFilesMetadata> data,
+        boolean explain
+    ) {
+        final Map<DiscoveryNode, Long> matchingNodesCacheSizes = new HashMap<>();
+        final Map<String, NodeAllocationResult> nodeDecisionsDebug = explain ? new HashMap<>() : null;
+        for (Map.Entry<DiscoveryNode, NodeCacheFilesMetadata> nodeStoreEntry : data.getData().entrySet()) {
+            DiscoveryNode discoNode = nodeStoreEntry.getKey();
+            NodeCacheFilesMetadata nodeCacheFilesMetadata = nodeStoreEntry.getValue();
+            // we don't have any existing cached bytes at all
+            if (nodeCacheFilesMetadata.bytesCached() == 0L) {
+                continue;
+            }
+
+            RoutingNode node = allocation.routingNodes().node(discoNode.getId());
+            if (node == null) {
+                continue;
+            }
+
+            // check if we can allocate on the node
+            Decision decision = allocation.deciders().canAllocate(shard, node, allocation);
+            Long matchingBytes = null;
+            if (explain) {
+                matchingBytes = nodeCacheFilesMetadata.bytesCached();
+                NodeAllocationResult.ShardStoreInfo shardStoreInfo = new NodeAllocationResult.ShardStoreInfo(matchingBytes);
+                nodeDecisionsDebug.put(node.nodeId(), new NodeAllocationResult(discoNode, shardStoreInfo, decision));
+            }
+
+            if (decision.type() == Decision.Type.NO) {
+                continue;
+            }
+
+            if (matchingBytes == null) {
+                matchingBytes = nodeCacheFilesMetadata.bytesCached();
+            }
+            matchingNodesCacheSizes.put(discoNode, matchingBytes);
+            if (logger.isTraceEnabled()) {
+                logger.trace(
+                    "{}: node [{}] has [{}/{}] bytes of re-usable cache data",
+                    shard,
+                    discoNode.getName(),
+                    new ByteSizeValue(matchingBytes),
+                    matchingBytes
+                );
+            }
+        }
+
+        return new MatchingNodes(matchingNodesCacheSizes, nodeDecisionsDebug);
+    }
+
+    private static final class AsyncCacheStatusFetch {
+
+        private final Set<DiscoveryNode> fetchingDataNodes = new HashSet<>();
+
+        private final Map<DiscoveryNode, NodeCacheFilesMetadata> data = new HashMap<>();
+
+        AsyncCacheStatusFetch() {}
+
+        synchronized DiscoveryNode[] addFetches(DiscoveryNode[] nodes) {
+            final Collection<DiscoveryNode> nodesToFetch = new ArrayList<>();
+            for (DiscoveryNode node : nodes) {
+                if (data.containsKey(node) == false && fetchingDataNodes.add(node)) {
+                    nodesToFetch.add(node);
+                }
+            }
+            return nodesToFetch.toArray(new DiscoveryNode[0]);
+        }
+
+        synchronized void addData(Map<DiscoveryNode, NodeCacheFilesMetadata> newData) {
+            data.putAll(newData);
+            fetchingDataNodes.removeAll(newData.keySet());
+        }
+
+        @Nullable
+        synchronized Map<DiscoveryNode, NodeCacheFilesMetadata> data() {
+            return fetchingDataNodes.size() > 0 ? null : Map.copyOf(data);
+        }
+
+        synchronized int numberOfInFlightFetches() {
+            return fetchingDataNodes.size();
+        }
+    }
+
+    private static final class MatchingNodes {
+        private final DiscoveryNode nodeWithHighestMatch;
+        @Nullable
+        private final Map<String, NodeAllocationResult> nodeDecisions;
+
+        MatchingNodes(Map<DiscoveryNode, Long> matchingNodes, @Nullable Map<String, NodeAllocationResult> nodeDecisions) {
+            this.nodeDecisions = nodeDecisions;
+            this.nodeWithHighestMatch = matchingNodes.entrySet()
+                .stream()
+                .filter(entry -> entry.getValue() > 0L)
+                .max(Map.Entry.comparingByValue())
+                .map(Map.Entry::getKey)
+                .orElse(null);
+        }
+
+        /**
+         * Returns the node with the highest number of bytes cached for the shard, or {@code null} if no node has any cached bytes for it.
+         */
+        @Nullable
+        public DiscoveryNode getNodeWithHighestMatch() {
+            return this.nodeWithHighestMatch;
+        }
     }
 }
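
One detail worth calling out in fetchData above: the response listener is wrapped in ActionListener.runAfter so that a follow-up reroute is requested whether the nodes-level call succeeds or fails, which is what turns a completed fetch into a fresh allocation round. A sketch of that pattern in isolation (toNodeMap and zeroSizesFor are hypothetical helpers standing in for the inline loops above):

    // The Runnable passed to runAfter executes after either onResponse or onFailure,
    // so a reroute is always scheduled once the in-flight fetch completes.
    ActionListener<NodesCacheFilesMetadata> listener = ActionListener.runAfter(
        ActionListener.wrap(
            response -> asyncFetch.addData(toNodeMap(response)),   // hypothetical helper
            e -> asyncFetch.addData(zeroSizesFor(dataNodes))       // hypothetical helper
        ),
        () -> client.admin().cluster().prepareReroute().execute(REROUTE_LISTENER)
    );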

+ 23 - 2
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshots.java

@@ -18,6 +18,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
 import org.elasticsearch.cluster.routing.allocation.decider.AllocationDecider;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
@@ -68,6 +69,7 @@ import org.elasticsearch.xpack.searchablesnapshots.action.SearchableSnapshotsSta
 import org.elasticsearch.xpack.searchablesnapshots.action.TransportClearSearchableSnapshotsCacheAction;
 import org.elasticsearch.xpack.searchablesnapshots.action.TransportMountSearchableSnapshotAction;
 import org.elasticsearch.xpack.searchablesnapshots.action.TransportSearchableSnapshotsStatsAction;
+import org.elasticsearch.xpack.searchablesnapshots.action.cache.TransportSearchableSnapshotCacheStoresAction;
 import org.elasticsearch.xpack.searchablesnapshots.cache.CacheService;
 import org.elasticsearch.xpack.searchablesnapshots.cache.PersistentCache;
 import org.elasticsearch.xpack.searchablesnapshots.rest.RestClearSearchableSnapshotsCacheAction;
@@ -170,6 +172,7 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
     private final SetOnce<BlobStoreCacheService> blobStoreCacheService = new SetOnce<>();
     private final SetOnce<CacheService> cacheService = new SetOnce<>();
     private final SetOnce<ThreadPool> threadPool = new SetOnce<>();
+    private final SetOnce<Client> client = new SetOnce<>();
     private final SetOnce<FailShardsOnInvalidLicenseClusterListener> failShardsListener = new SetOnce<>();
     private final Settings settings;
 
@@ -231,6 +234,8 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
         } else {
             PersistentCache.cleanUp(settings, nodeEnvironment);
         }
+        this.client.set(client);
+        components.add(new CacheServiceSupplier(cacheService.get()));
         return Collections.unmodifiableList(components);
     }
 
@@ -315,7 +320,8 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
             new ActionHandler<>(ClearSearchableSnapshotsCacheAction.INSTANCE, TransportClearSearchableSnapshotsCacheAction.class),
             new ActionHandler<>(MountSearchableSnapshotAction.INSTANCE, TransportMountSearchableSnapshotAction.class),
             new ActionHandler<>(XPackUsageFeatureAction.SEARCHABLE_SNAPSHOTS, SearchableSnapshotsUsageTransportAction.class),
-            new ActionHandler<>(XPackInfoFeatureAction.SEARCHABLE_SNAPSHOTS, SearchableSnapshotsInfoTransportAction.class)
+            new ActionHandler<>(XPackInfoFeatureAction.SEARCHABLE_SNAPSHOTS, SearchableSnapshotsInfoTransportAction.class),
+            new ActionHandler<>(TransportSearchableSnapshotCacheStoresAction.TYPE, TransportSearchableSnapshotCacheStoresAction.class)
         );
     }
 
@@ -337,7 +343,7 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
 
     @Override
     public Map<String, ExistingShardsAllocator> getExistingShardsAllocators() {
-        return Map.of(SearchableSnapshotAllocator.ALLOCATOR_NAME, new SearchableSnapshotAllocator());
+        return Map.of(SearchableSnapshotAllocator.ALLOCATOR_NAME, new SearchableSnapshotAllocator(client.get()));
     }
 
     // overridable by tests
@@ -481,4 +487,19 @@ public class SearchableSnapshots extends Plugin implements IndexStorePlugin, Eng
             throw new UncheckedIOException("Failed to build " + SNAPSHOT_BLOB_CACHE_INDEX + " index mappings", e);
         }
     }
+
+    public static final class CacheServiceSupplier implements Supplier<CacheService> {
+
+        @Nullable
+        private final CacheService cacheService;
+
+        CacheServiceSupplier(@Nullable CacheService cacheService) {
+            this.cacheService = cacheService;
+        }
+
+        @Override
+        public CacheService get() {
+            return cacheService;
+        }
+    }
 }
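
For context, an index routes its existing-shard allocation to this allocator through the index.allocation.existing_shards_allocator setting, which mounted searchable snapshot indices carry; the unit test further below builds exactly these settings. A hedged sketch of the relevant index settings (both keys appear elsewhere in this commit):

    // Index settings that select the plugin's allocator for existing shards, as
    // built in SearchableSnapshotAllocatorTests below.
    Settings indexSettings = Settings.builder()
        .put(ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey(),
             SearchableSnapshotAllocator.ALLOCATOR_NAME) // "searchable_snapshot_allocator"
        .put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), SearchableSnapshotsConstants.SNAPSHOT_DIRECTORY_FACTORY_KEY)
        .build();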

+ 187 - 0
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/action/cache/TransportSearchableSnapshotCacheStoresAction.java

@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.searchablesnapshots.action.cache;
+
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.FailedNodeException;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.nodes.BaseNodeResponse;
+import org.elasticsearch.action.support.nodes.BaseNodesRequest;
+import org.elasticsearch.action.support.nodes.BaseNodesResponse;
+import org.elasticsearch.action.support.nodes.TransportNodesAction;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportRequest;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshots;
+import org.elasticsearch.xpack.searchablesnapshots.cache.CacheService;
+
+import java.io.IOException;
+import java.util.List;
+
+public class TransportSearchableSnapshotCacheStoresAction extends TransportNodesAction<
+    TransportSearchableSnapshotCacheStoresAction.Request,
+    TransportSearchableSnapshotCacheStoresAction.NodesCacheFilesMetadata,
+    TransportSearchableSnapshotCacheStoresAction.NodeRequest,
+    TransportSearchableSnapshotCacheStoresAction.NodeCacheFilesMetadata> {
+
+    public static final String ACTION_NAME = "cluster:admin/xpack/searchable_snapshots/cache/store";
+
+    public static final ActionType<NodesCacheFilesMetadata> TYPE = new ActionType<>(ACTION_NAME, NodesCacheFilesMetadata::new);
+
+    private final CacheService cacheService;
+
+    @Inject
+    public TransportSearchableSnapshotCacheStoresAction(
+        ThreadPool threadPool,
+        ClusterService clusterService,
+        TransportService transportService,
+        SearchableSnapshots.CacheServiceSupplier cacheService,
+        ActionFilters actionFilters
+    ) {
+        super(
+            ACTION_NAME,
+            threadPool,
+            clusterService,
+            transportService,
+            actionFilters,
+            Request::new,
+            NodeRequest::new,
+            ThreadPool.Names.MANAGEMENT,
+            ThreadPool.Names.SAME,
+            NodeCacheFilesMetadata.class
+        );
+        this.cacheService = cacheService.get();
+    }
+
+    @Override
+    protected NodesCacheFilesMetadata newResponse(
+        Request request,
+        List<NodeCacheFilesMetadata> nodesCacheFilesMetadata,
+        List<FailedNodeException> failures
+    ) {
+        return new NodesCacheFilesMetadata(clusterService.getClusterName(), nodesCacheFilesMetadata, failures);
+    }
+
+    @Override
+    protected NodeRequest newNodeRequest(Request request) {
+        return new NodeRequest(request);
+    }
+
+    @Override
+    protected NodeCacheFilesMetadata newNodeResponse(StreamInput in) throws IOException {
+        return new NodeCacheFilesMetadata(in);
+    }
+
+    @Override
+    protected NodeCacheFilesMetadata nodeOperation(NodeRequest request, Task task) {
+        assert cacheService != null;
+        return new NodeCacheFilesMetadata(clusterService.localNode(), cacheService.getCachedSize(request.shardId, request.snapshotId));
+    }
+
+    public static final class Request extends BaseNodesRequest<Request> {
+
+        private final SnapshotId snapshotId;
+        private final ShardId shardId;
+
+        public Request(SnapshotId snapshotId, ShardId shardId, DiscoveryNode[] nodes) {
+            super(nodes);
+            this.snapshotId = snapshotId;
+            this.shardId = shardId;
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            snapshotId = new SnapshotId(in);
+            shardId = new ShardId(in);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            snapshotId.writeTo(out);
+            shardId.writeTo(out);
+        }
+    }
+
+    public static final class NodeRequest extends TransportRequest {
+
+        private final SnapshotId snapshotId;
+        private final ShardId shardId;
+
+        public NodeRequest(Request request) {
+            this.snapshotId = request.snapshotId;
+            this.shardId = request.shardId;
+        }
+
+        public NodeRequest(StreamInput in) throws IOException {
+            super(in);
+            this.snapshotId = new SnapshotId(in);
+            this.shardId = new ShardId(in);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            snapshotId.writeTo(out);
+            shardId.writeTo(out);
+        }
+    }
+
+    public static class NodeCacheFilesMetadata extends BaseNodeResponse {
+
+        private final long bytesCached;
+
+        public NodeCacheFilesMetadata(StreamInput in) throws IOException {
+            super(in);
+            bytesCached = in.readLong();
+        }
+
+        public NodeCacheFilesMetadata(DiscoveryNode node, long bytesCached) {
+            super(node);
+            this.bytesCached = bytesCached;
+        }
+
+        public long bytesCached() {
+            return bytesCached;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeLong(bytesCached);
+        }
+    }
+
+    public static class NodesCacheFilesMetadata extends BaseNodesResponse<NodeCacheFilesMetadata> {
+
+        public NodesCacheFilesMetadata(StreamInput in) throws IOException {
+            super(in);
+        }
+
+        public NodesCacheFilesMetadata(ClusterName clusterName, List<NodeCacheFilesMetadata> nodes, List<FailedNodeException> failures) {
+            super(clusterName, nodes, failures);
+        }
+
+        @Override
+        protected List<NodeCacheFilesMetadata> readNodesFrom(StreamInput in) throws IOException {
+            return in.readList(NodeCacheFilesMetadata::new);
+        }
+
+        @Override
+        protected void writeNodesTo(StreamOutput out, List<NodeCacheFilesMetadata> nodes) throws IOException {
+            out.writeList(nodes);
+        }
+    }
+}
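
Callers reach the new action through the client, exactly as the allocator's fetchData does; a minimal usage sketch, assuming a listener that handles the per-node byte counts:

    // Invoke the nodes-level cache-stores action; the response carries one
    // NodeCacheFilesMetadata (node + cached byte count) per queried data node.
    client.execute(
        TransportSearchableSnapshotCacheStoresAction.TYPE,
        new TransportSearchableSnapshotCacheStoresAction.Request(snapshotId, shardId, dataNodes),
        listener // ActionListener<NodesCacheFilesMetadata>, assumed to exist
    );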

+ 12 - 1
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/CacheService.java

@@ -283,6 +283,16 @@ public class CacheService extends AbstractLifecycleComponent {
         });
     }
 
+    /**
+     * Get the number of bytes cached for the given shard id in the given snapshot id.
+     * @param shardId    shard id
+     * @param snapshotId snapshot id
+     * @return number of bytes cached
+     */
+    public long getCachedSize(ShardId shardId, SnapshotId snapshotId) {
+        return persistentCache.getCacheSize(shardId, snapshotId);
+    }
+
     /**
      * Computes a new {@link CacheFile} instance using the specified cache file information (file length, file name, parent directory and
      * already available cache ranges) and associates it with the specified {@link CacheKey} in the cache. If the key is already
@@ -467,7 +477,8 @@ public class CacheService extends AbstractLifecycleComponent {
     * non-empty set of completed ranges this method also fsyncs the shard's snapshot cache directory, which is the parent directory of the
      * cache entry. Note that cache files might be evicted during the synchronization.
      */
-    protected void synchronizeCache() {
+    // public for tests only
+    public void synchronizeCache() {
         cacheSyncLock.lock();
         try {
             long count = 0L;

+ 53 - 0
x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/cache/PersistentCache.java

@@ -23,6 +23,14 @@ import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SerialMergeScheduler;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.AlreadyClosedException;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
@@ -65,6 +73,7 @@ import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.IntPredicate;
 
 import static java.util.Collections.synchronizedMap;
 import static java.util.Collections.unmodifiableList;
@@ -129,6 +138,50 @@ public class PersistentCache implements Closeable {
         getWriter(cacheFile).deleteCacheFile(cacheFile);
     }
 
+    public long getCacheSize(ShardId shardId, SnapshotId snapshotId) {
+        long aggregateSize = 0L;
+        for (CacheIndexWriter writer : writers) {
+            try (IndexReader indexReader = DirectoryReader.open(writer.indexWriter)) {
+                final IndexSearcher searcher = new IndexSearcher(indexReader);
+                searcher.setQueryCache(null);
+                final Weight weight = searcher.createWeight(
+                    new BooleanQuery.Builder().add(
+                        new TermQuery(new Term(SNAPSHOT_ID_FIELD, snapshotId.getUUID())),
+                        BooleanClause.Occur.MUST
+                    )
+                        .add(new TermQuery(new Term(SHARD_INDEX_ID_FIELD, shardId.getIndex().getUUID())), BooleanClause.Occur.MUST)
+                        .add(new TermQuery(new Term(SHARD_ID_FIELD, String.valueOf(shardId.getId()))), BooleanClause.Occur.MUST)
+                        .build(),
+                    ScoreMode.COMPLETE_NO_SCORES,
+                    0.0f
+                );
+                for (LeafReaderContext leafReaderContext : searcher.getIndexReader().leaves()) {
+                    final Scorer scorer = weight.scorer(leafReaderContext);
+                    if (scorer != null) {
+                        final Bits liveDocs = leafReaderContext.reader().getLiveDocs();
+                        final IntPredicate isLiveDoc = liveDocs == null ? i -> true : liveDocs::get;
+                        final DocIdSetIterator docIdSetIterator = scorer.iterator();
+                        while (docIdSetIterator.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+                            if (isLiveDoc.test(docIdSetIterator.docID())) {
+                                final Document document = leafReaderContext.reader().document(docIdSetIterator.docID());
+                                var ranges = buildCacheFileRanges(document);
+                                for (Tuple<Long, Long> range : ranges) {
+                                    aggregateSize += range.v2() - range.v1();
+                                }
+                            }
+                        }
+                    }
+                }
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+            if (aggregateSize > 0L) {
+                return aggregateSize;
+            }
+        }
+        return 0L;
+    }
+
     /**
      * This method repopulates the {@link CacheService} by looking at the files on the disk and for each file found, retrieves the latest
      * synchronized information and puts the cache file into the searchable snapshots cache.
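
The figure getCacheSize reports is simply the sum of the lengths of all completed ranges recorded in the matching live documents. A worked example of that summation, assuming half-open (start, end) range tuples as produced by buildCacheFileRanges:

    // Completed ranges [0, 100) and [200, 250) contribute 100 + 50 = 150 cached bytes.
    List<Tuple<Long, Long>> ranges = List.of(Tuple.tuple(0L, 100L), Tuple.tuple(200L, 250L));
    long aggregateSize = 0L;
    for (Tuple<Long, Long> range : ranges) {
        aggregateSize += range.v2() - range.v1();
    }
    assert aggregateSize == 150L;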

+ 179 - 0
x-pack/plugin/searchable-snapshots/src/test/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotAllocatorTests.java

@@ -0,0 +1,179 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.searchablesnapshots;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.admin.cluster.reroute.ClusterRerouteAction;
+import org.elasticsearch.action.admin.cluster.reroute.ClusterRerouteResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.coordination.DeterministicTaskQueue;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.RecoverySource;
+import org.elasticsearch.cluster.routing.RoutingNodes;
+import org.elasticsearch.cluster.routing.RoutingTable;
+import org.elasticsearch.cluster.routing.ShardRouting;
+import org.elasticsearch.cluster.routing.allocation.ExistingShardsAllocator;
+import org.elasticsearch.cluster.routing.allocation.RoutingAllocation;
+import org.elasticsearch.cluster.routing.allocation.RoutingExplanations;
+import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.ImmutableOpenMap;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexModule;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.repositories.IndexId;
+import org.elasticsearch.snapshots.Snapshot;
+import org.elasticsearch.snapshots.SnapshotId;
+import org.elasticsearch.snapshots.SnapshotShardSizeInfo;
+import org.elasticsearch.test.client.NoOpNodeClient;
+import org.elasticsearch.xpack.searchablesnapshots.action.cache.TransportSearchableSnapshotCacheStoresAction;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.node.Node.NODE_NAME_SETTING;
+
+public class SearchableSnapshotAllocatorTests extends ESAllocationTestCase {
+
+    public void testAllocateToNodeWithLargestCache() {
+        final ShardId shardId = new ShardId("test", "_na_", 0);
+        final List<DiscoveryNode> nodes = randomList(1, 10, () -> newNode("node-" + UUIDs.randomBase64UUID(random())));
+        final DiscoveryNode localNode = randomFrom(nodes);
+        final Settings localNodeSettings = Settings.builder().put(NODE_NAME_SETTING.getKey(), localNode.getName()).build();
+
+        final ClusterName clusterName = org.elasticsearch.cluster.ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY);
+
+        final DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue(localNodeSettings, random());
+
+        final Metadata metadata = Metadata.builder()
+            .put(
+                IndexMetadata.builder(shardId.getIndexName())
+                    .settings(
+                        settings(Version.CURRENT).put(
+                            ExistingShardsAllocator.EXISTING_SHARDS_ALLOCATOR_SETTING.getKey(),
+                            SearchableSnapshotAllocator.ALLOCATOR_NAME
+                        ).put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), SearchableSnapshotsConstants.SNAPSHOT_DIRECTORY_FACTORY_KEY)
+                    )
+                    .numberOfShards(1)
+                    .numberOfReplicas(0)
+                    .putInSyncAllocationIds(shardId.id(), Collections.emptySet())
+            )
+            .build();
+        final RoutingTable.Builder routingTableBuilder = RoutingTable.builder();
+        routingTableBuilder.addAsRestore(metadata.index(shardId.getIndex()), randomSnapshotSource(shardId));
+
+        final DiscoveryNodes.Builder nodesBuilder = DiscoveryNodes.builder();
+        for (DiscoveryNode node : nodes) {
+            nodesBuilder.add(node);
+        }
+        final DiscoveryNodes discoveryNodes = nodesBuilder.build();
+        final ClusterState state = ClusterState.builder(clusterName)
+            .metadata(metadata)
+            .routingTable(routingTableBuilder.build())
+            .nodes(discoveryNodes)
+            .build();
+        final long shardSize = randomNonNegativeLong();
+        final RoutingAllocation allocation = new RoutingAllocation(
+            yesAllocationDeciders(),
+            new RoutingNodes(state, false),
+            state,
+            null,
+            new SnapshotShardSizeInfo(ImmutableOpenMap.of()) {
+                @Override
+                public Long getShardSize(ShardRouting shardRouting) {
+                    return shardSize;
+                }
+            },
+            TimeUnit.MILLISECONDS.toNanos(deterministicTaskQueue.getCurrentTimeMillis())
+        );
+
+        final AtomicInteger reroutesTriggered = new AtomicInteger(0);
+
+        final Map<DiscoveryNode, Long> existingCacheSizes = nodes.stream()
+            .collect(Collectors.toUnmodifiableMap(Function.identity(), k -> randomBoolean() ? 0L : randomLongBetween(0, shardSize)));
+
+        final Client client = new NoOpNodeClient(deterministicTaskQueue.getThreadPool()) {
+
+            @SuppressWarnings("unchecked")
+            @Override
+            public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
+                ActionType<Response> action,
+                Request request,
+                ActionListener<Response> listener
+            ) {
+                if (action == ClusterRerouteAction.INSTANCE) {
+                    reroutesTriggered.incrementAndGet();
+                    listener.onResponse((Response) new ClusterRerouteResponse(true, state, new RoutingExplanations()));
+                } else if (action == TransportSearchableSnapshotCacheStoresAction.TYPE) {
+                    listener.onResponse(
+                        (Response) new TransportSearchableSnapshotCacheStoresAction.NodesCacheFilesMetadata(
+                            clusterName,
+                            existingCacheSizes.entrySet()
+                                .stream()
+                                .map(
+                                    entry -> new TransportSearchableSnapshotCacheStoresAction.NodeCacheFilesMetadata(
+                                        entry.getKey(),
+                                        entry.getValue()
+                                    )
+                                )
+                                .collect(Collectors.toList()),
+                            List.of()
+                        )
+                    );
+                } else {
+                    throw new AssertionError("Unexpected action [" + action + "]");
+                }
+            }
+        };
+
+        final SearchableSnapshotAllocator allocator = new SearchableSnapshotAllocator(client);
+        allocateAllUnassigned(allocation, allocator);
+
+        assertEquals(1, reroutesTriggered.get());
+        if (existingCacheSizes.values().stream().allMatch(size -> size == 0L)) {
+            assertFalse("If there are no existing caches the allocator should not take a decision", allocation.routingNodesChanged());
+        } else {
+            assertTrue(allocation.routingNodesChanged());
+            final long bestCacheSize = existingCacheSizes.values().stream().mapToLong(l -> l).max().orElseThrow();
+
+            final ShardRouting primaryRouting = allocation.routingNodes().assignedShards(shardId).get(0);
+            final String primaryNodeId = primaryRouting.currentNodeId();
+            final DiscoveryNode primaryNode = discoveryNodes.get(primaryNodeId);
+            assertEquals(bestCacheSize, (long) existingCacheSizes.get(primaryNode));
+        }
+    }
+
+    private static void allocateAllUnassigned(RoutingAllocation allocation, ExistingShardsAllocator allocator) {
+        final RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator();
+        while (iterator.hasNext()) {
+            allocator.allocateUnassigned(iterator.next(), allocation, iterator);
+        }
+    }
+
+    private static RecoverySource.SnapshotRecoverySource randomSnapshotSource(ShardId shardId) {
+        return new RecoverySource.SnapshotRecoverySource(
+            UUIDs.randomBase64UUID(random()),
+            new Snapshot("test-repo", new SnapshotId("test-snap", UUIDs.randomBase64UUID(random()))),
+            Version.CURRENT,
+            new IndexId(shardId.getIndexName(), UUIDs.randomBase64UUID(random()))
+        );
+    }
+}

+ 1 - 0
x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

@@ -157,6 +157,7 @@ public class Constants {
         "cluster:admin/xpack/rollup/start",
         "cluster:admin/xpack/rollup/stop",
         "cluster:admin/xpack/searchable_snapshots/cache/clear",
+        "cluster:admin/xpack/searchable_snapshots/cache/store",
         "cluster:admin/xpack/security/api_key/create",
         "cluster:admin/xpack/security/api_key/get",
         "cluster:admin/xpack/security/api_key/grant",