Pārlūkot izejas kodu

Add undesired shard count (#101426)

This change add undesired shard (ones that are allocated not on the desired
node) counts to the api output.
Ievgen Degtiarenko 1 gadu atpakaļ
vecāks
revīzija
70ecb12556

+ 5 - 0
docs/changelog/101426.yaml

@@ -0,0 +1,5 @@
+pr: 101426
+summary: Add undesired shard count
+area: Allocation
+type: enhancement
+issues: []

+ 25 - 1
docs/reference/cluster/get-desired-balance.asciidoc

@@ -6,7 +6,12 @@
 
 
 NOTE: {cloud-only}
 NOTE: {cloud-only}
 
 
-Exposes the desired balance and basic metrics.
+Exposes:
+* the desired balance computation and reconciliation stats
+* balancing stats such as distribution of shards, disk and ingest forecasts
+  across nodes and data tiers (based on the current cluster state)
+* routing table with each shard current and desired location
+* cluster info with nodes disk usages
 
 
 [[get-desired-balance-request]]
 [[get-desired-balance-request]]
 ==== {api-request-title}
 ==== {api-request-title}
@@ -33,6 +38,8 @@ The API returns the following result:
     "reconciliation_time_in_millis": 0
     "reconciliation_time_in_millis": 0
   },
   },
   "cluster_balance_stats" : {
   "cluster_balance_stats" : {
+    "shard_count": 37,
+    "undesired_shard_allocation_count": 0,
     "tiers": {
     "tiers": {
       "data_hot" : {
       "data_hot" : {
         "shard_count" : {
         "shard_count" : {
@@ -42,6 +49,13 @@ The API returns the following result:
           "average" : 2.3333333333333335,
           "average" : 2.3333333333333335,
           "std_dev" : 0.4714045207910317
           "std_dev" : 0.4714045207910317
         },
         },
+        "undesired_shard_allocation_count" : {
+          "total" : 0.0,
+          "min" : 0.0,
+          "max" : 0.0,
+          "average" : 0.0,
+          "std_dev" : 0.0
+        },
         "forecast_write_load" : {
         "forecast_write_load" : {
           "total" : 21.0,
           "total" : 21.0,
           "min" : 6.0,
           "min" : 6.0,
@@ -72,6 +86,13 @@ The API returns the following result:
           "average" : 1.0,
           "average" : 1.0,
           "std_dev" : 0.0
           "std_dev" : 0.0
         },
         },
+        "undesired_shard_allocation_count" : {
+          "total" : 0.0,
+          "min" : 0.0,
+          "max" : 0.0,
+          "average" : 0.0,
+          "std_dev" : 0.0
+        },
         "forecast_write_load" : {
         "forecast_write_load" : {
           "total" : 0.0,
           "total" : 0.0,
           "min" : 0.0,
           "min" : 0.0,
@@ -100,6 +121,7 @@ The API returns the following result:
         "node_id": "UPYt8VwWTt-IADAEbqpLxA",
         "node_id": "UPYt8VwWTt-IADAEbqpLxA",
         "roles": ["data_content"],
         "roles": ["data_content"],
         "shard_count": 10,
         "shard_count": 10,
+        "undesired_shard_allocation_count": 0,
         "forecast_write_load": 8.5,
         "forecast_write_load": 8.5,
         "forecast_disk_usage_bytes": 498435,
         "forecast_disk_usage_bytes": 498435,
         "actual_disk_usage_bytes": 498435
         "actual_disk_usage_bytes": 498435
@@ -108,6 +130,7 @@ The API returns the following result:
         "node_id": "bgC66tboTIeFQ0VgRGI4Gg",
         "node_id": "bgC66tboTIeFQ0VgRGI4Gg",
         "roles": ["data_content"],
         "roles": ["data_content"],
         "shard_count": 15,
         "shard_count": 15,
+        "undesired_shard_allocation_count": 0,
         "forecast_write_load": 3.25,
         "forecast_write_load": 3.25,
         "forecast_disk_usage_bytes": 384935,
         "forecast_disk_usage_bytes": 384935,
         "actual_disk_usage_bytes": 384935
         "actual_disk_usage_bytes": 384935
@@ -116,6 +139,7 @@ The API returns the following result:
         "node_id": "2x1VTuSOQdeguXPdN73yRw",
         "node_id": "2x1VTuSOQdeguXPdN73yRw",
         "roles": ["data_content"],
         "roles": ["data_content"],
         "shard_count": 12,
         "shard_count": 12,
+        "undesired_shard_allocation_count": 0,
         "forecast_write_load": 6.0,
         "forecast_write_load": 6.0,
         "forecast_disk_usage_bytes": 648766,
         "forecast_disk_usage_bytes": 648766,
         "actual_disk_usage_bytes": 648766
         "actual_disk_usage_bytes": 648766

+ 38 - 0
qa/smoke-test-multinode/src/yamlRestTest/resources/rest-api-spec/test/smoke_test_multinode/30_desired_balance.yml

@@ -148,3 +148,41 @@ setup:
       _internal.get_desired_balance: { }
       _internal.get_desired_balance: { }
 
 
   - is_true: 'cluster_info'
   - is_true: 'cluster_info'
+
+---
+"Test undesired_shard_allocation_count":
+
+  - skip:
+      version: " - 8.11.99"
+      reason: "undesired_shard_allocation_count added in in 8.12.0"
+
+  - do:
+      indices.create:
+        index: test
+        body:
+          settings:
+            number_of_shards: 1
+            number_of_replicas: 0
+
+  - do:
+      cluster.health:
+        index: test
+        wait_for_status: green
+
+  - do:
+      cluster.state: {}
+  - set: { nodes._arbitrary_key_ : node_id }
+  - set: { nodes.$node_id.name : node_name }
+
+  - do:
+      _internal.get_desired_balance: { }
+
+  - gte: { 'cluster_balance_stats.shard_count' : 0 }
+  - gte: { 'cluster_balance_stats.undesired_shard_allocation_count' : 0 }
+  - gte: { 'cluster_balance_stats.nodes.$node_name.undesired_shard_allocation_count' : 0 }
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.total'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.min'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.max'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.average'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.std_dev'

+ 38 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/cluster.desired_balance/10_basic.yml

@@ -183,3 +183,41 @@ setup:
 
 
   - do:
   - do:
       _internal.delete_desired_balance: { }
       _internal.delete_desired_balance: { }
+
+---
+"Test undesired_shard_allocation_count":
+
+  - skip:
+      version: " - 8.11.99"
+      reason: "undesired_shard_allocation_count added in in 8.12.0"
+
+  - do:
+      indices.create:
+        index: test
+        body:
+          settings:
+            number_of_shards: 1
+            number_of_replicas: 0
+
+  - do:
+      cluster.health:
+        index: test
+        wait_for_status: green
+
+  - do:
+      cluster.state: {}
+  - set: { nodes._arbitrary_key_ : node_id }
+  - set: { nodes.$node_id.name : node_name }
+
+  - do:
+      _internal.get_desired_balance: { }
+
+  - gte: { 'cluster_balance_stats.shard_count' : 0 }
+  - gte: { 'cluster_balance_stats.undesired_shard_allocation_count' : 0 }
+  - gte: { 'cluster_balance_stats.nodes.$node_name.undesired_shard_allocation_count' : 0 }
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.total'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.min'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.max'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.average'
+  - exists: 'cluster_balance_stats.tiers.data_content.undesired_shard_allocation_count.std_dev'

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -158,6 +158,7 @@ public class TransportVersions {
     public static final TransportVersion DSL_ERROR_STORE_INFORMATION_ENHANCED = def(8_527_00_0);
     public static final TransportVersion DSL_ERROR_STORE_INFORMATION_ENHANCED = def(8_527_00_0);
     public static final TransportVersion INVALID_BUCKET_PATH_EXCEPTION_INTRODUCED = def(8_528_00_0);
     public static final TransportVersion INVALID_BUCKET_PATH_EXCEPTION_INTRODUCED = def(8_528_00_0);
     public static final TransportVersion KNN_AS_QUERY_ADDED = def(8_529_00_0);
     public static final TransportVersion KNN_AS_QUERY_ADDED = def(8_529_00_0);
+    public static final TransportVersion UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED = def(8_530_00_0);
 
 
     /*
     /*
      * STOP! READ THIS FIRST! No, really,
      * STOP! READ THIS FIRST! No, really,

+ 1 - 1
server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetDesiredBalanceAction.java

@@ -95,7 +95,7 @@ public class TransportGetDesiredBalanceAction extends TransportMasterNodeReadAct
         listener.onResponse(
         listener.onResponse(
             new DesiredBalanceResponse(
             new DesiredBalanceResponse(
                 desiredBalanceShardsAllocator.getStats(),
                 desiredBalanceShardsAllocator.getStats(),
-                ClusterBalanceStats.createFrom(state, clusterInfo, writeLoadForecaster),
+                ClusterBalanceStats.createFrom(state, latestDesiredBalance, clusterInfo, writeLoadForecaster),
                 createRoutingTable(state, latestDesiredBalance),
                 createRoutingTable(state, latestDesiredBalance),
                 clusterInfo
                 clusterInfo
             )
             )

+ 69 - 9
server/src/main/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStats.java

@@ -31,15 +31,18 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.function.ToDoubleFunction;
 import java.util.function.ToDoubleFunction;
 
 
-public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<String, NodeBalanceStats> nodes)
-    implements
-        Writeable,
-        ToXContentObject {
+public record ClusterBalanceStats(
+    int shards,
+    int undesiredShardAllocations,
+    Map<String, TierBalanceStats> tiers,
+    Map<String, NodeBalanceStats> nodes
+) implements Writeable, ToXContentObject {
 
 
-    public static ClusterBalanceStats EMPTY = new ClusterBalanceStats(Map.of(), Map.of());
+    public static ClusterBalanceStats EMPTY = new ClusterBalanceStats(0, 0, Map.of(), Map.of());
 
 
     public static ClusterBalanceStats createFrom(
     public static ClusterBalanceStats createFrom(
         ClusterState clusterState,
         ClusterState clusterState,
+        DesiredBalance desiredBalance,
         ClusterInfo clusterInfo,
         ClusterInfo clusterInfo,
         WriteLoadForecaster writeLoadForecaster
         WriteLoadForecaster writeLoadForecaster
     ) {
     ) {
@@ -50,32 +53,60 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
             if (dataRoles.isEmpty()) {
             if (dataRoles.isEmpty()) {
                 continue;
                 continue;
             }
             }
-            var nodeStats = NodeBalanceStats.createFrom(routingNode, clusterState.metadata(), clusterInfo, writeLoadForecaster);
+            var nodeStats = NodeBalanceStats.createFrom(
+                routingNode,
+                clusterState.metadata(),
+                desiredBalance,
+                clusterInfo,
+                writeLoadForecaster
+            );
             nodes.put(routingNode.node().getName(), nodeStats);
             nodes.put(routingNode.node().getName(), nodeStats);
             for (DiscoveryNodeRole role : dataRoles) {
             for (DiscoveryNodeRole role : dataRoles) {
                 tierToNodeStats.computeIfAbsent(role.roleName(), ignored -> new ArrayList<>()).add(nodeStats);
                 tierToNodeStats.computeIfAbsent(role.roleName(), ignored -> new ArrayList<>()).add(nodeStats);
             }
             }
         }
         }
-        return new ClusterBalanceStats(Maps.transformValues(tierToNodeStats, TierBalanceStats::createFrom), nodes);
+        return new ClusterBalanceStats(
+            nodes.values().stream().mapToInt(NodeBalanceStats::shards).sum(),
+            nodes.values().stream().mapToInt(NodeBalanceStats::undesiredShardAllocations).sum(),
+            Maps.transformValues(tierToNodeStats, TierBalanceStats::createFrom),
+            nodes
+        );
     }
     }
 
 
     public static ClusterBalanceStats readFrom(StreamInput in) throws IOException {
     public static ClusterBalanceStats readFrom(StreamInput in) throws IOException {
-        return new ClusterBalanceStats(in.readImmutableMap(TierBalanceStats::readFrom), in.readImmutableMap(NodeBalanceStats::readFrom));
+        return new ClusterBalanceStats(
+            in.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED) ? in.readVInt() : -1,
+            in.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED) ? in.readVInt() : -1,
+            in.readImmutableMap(TierBalanceStats::readFrom),
+            in.readImmutableMap(NodeBalanceStats::readFrom)
+        );
     }
     }
 
 
     @Override
     @Override
     public void writeTo(StreamOutput out) throws IOException {
     public void writeTo(StreamOutput out) throws IOException {
+        if (out.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED)) {
+            out.writeVInt(shards);
+        }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED)) {
+            out.writeVInt(undesiredShardAllocations);
+        }
         out.writeMap(tiers, StreamOutput::writeWriteable);
         out.writeMap(tiers, StreamOutput::writeWriteable);
         out.writeMap(nodes, StreamOutput::writeWriteable);
         out.writeMap(nodes, StreamOutput::writeWriteable);
     }
     }
 
 
     @Override
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        return builder.startObject().field("tiers").map(tiers).field("nodes").map(nodes).endObject();
+        return builder.startObject()
+            .field("shard_count", shards)
+            .field("undesired_shard_allocation_count", undesiredShardAllocations)
+            .field("tiers", tiers)
+            .field("nodes", nodes)
+            .endObject();
     }
     }
 
 
     public record TierBalanceStats(
     public record TierBalanceStats(
         MetricStats shardCount,
         MetricStats shardCount,
+        MetricStats undesiredShardAllocations,
         MetricStats forecastWriteLoad,
         MetricStats forecastWriteLoad,
         MetricStats forecastShardSize,
         MetricStats forecastShardSize,
         MetricStats actualShardSize
         MetricStats actualShardSize
@@ -84,6 +115,7 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         private static TierBalanceStats createFrom(List<NodeBalanceStats> nodes) {
         private static TierBalanceStats createFrom(List<NodeBalanceStats> nodes) {
             return new TierBalanceStats(
             return new TierBalanceStats(
                 MetricStats.createFrom(nodes, it -> it.shards),
                 MetricStats.createFrom(nodes, it -> it.shards),
+                MetricStats.createFrom(nodes, it -> it.undesiredShardAllocations),
                 MetricStats.createFrom(nodes, it -> it.forecastWriteLoad),
                 MetricStats.createFrom(nodes, it -> it.forecastWriteLoad),
                 MetricStats.createFrom(nodes, it -> it.forecastShardSize),
                 MetricStats.createFrom(nodes, it -> it.forecastShardSize),
                 MetricStats.createFrom(nodes, it -> it.actualShardSize)
                 MetricStats.createFrom(nodes, it -> it.actualShardSize)
@@ -93,6 +125,9 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         public static TierBalanceStats readFrom(StreamInput in) throws IOException {
         public static TierBalanceStats readFrom(StreamInput in) throws IOException {
             return new TierBalanceStats(
             return new TierBalanceStats(
                 MetricStats.readFrom(in),
                 MetricStats.readFrom(in),
+                in.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED)
+                    ? MetricStats.readFrom(in)
+                    : new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                 MetricStats.readFrom(in),
                 MetricStats.readFrom(in),
                 MetricStats.readFrom(in),
                 MetricStats.readFrom(in),
                 MetricStats.readFrom(in)
                 MetricStats.readFrom(in)
@@ -102,6 +137,9 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         @Override
         @Override
         public void writeTo(StreamOutput out) throws IOException {
         public void writeTo(StreamOutput out) throws IOException {
             shardCount.writeTo(out);
             shardCount.writeTo(out);
+            if (out.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED)) {
+                undesiredShardAllocations.writeTo(out);
+            }
             forecastWriteLoad.writeTo(out);
             forecastWriteLoad.writeTo(out);
             forecastShardSize.writeTo(out);
             forecastShardSize.writeTo(out);
             actualShardSize.writeTo(out);
             actualShardSize.writeTo(out);
@@ -111,6 +149,7 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             return builder.startObject()
             return builder.startObject()
                 .field("shard_count", shardCount)
                 .field("shard_count", shardCount)
+                .field("undesired_shard_allocation_count", undesiredShardAllocations)
                 .field("forecast_write_load", forecastWriteLoad)
                 .field("forecast_write_load", forecastWriteLoad)
                 .field("forecast_disk_usage", forecastShardSize)
                 .field("forecast_disk_usage", forecastShardSize)
                 .field("actual_disk_usage", actualShardSize)
                 .field("actual_disk_usage", actualShardSize)
@@ -172,6 +211,7 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         String nodeId,
         String nodeId,
         List<String> roles,
         List<String> roles,
         int shards,
         int shards,
+        int undesiredShardAllocations,
         double forecastWriteLoad,
         double forecastWriteLoad,
         long forecastShardSize,
         long forecastShardSize,
         long actualShardSize
         long actualShardSize
@@ -182,9 +222,11 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
         private static NodeBalanceStats createFrom(
         private static NodeBalanceStats createFrom(
             RoutingNode routingNode,
             RoutingNode routingNode,
             Metadata metadata,
             Metadata metadata,
+            DesiredBalance desiredBalance,
             ClusterInfo clusterInfo,
             ClusterInfo clusterInfo,
             WriteLoadForecaster writeLoadForecaster
             WriteLoadForecaster writeLoadForecaster
         ) {
         ) {
+            int undesired = 0;
             double forecastWriteLoad = 0.0;
             double forecastWriteLoad = 0.0;
             long forecastShardSize = 0L;
             long forecastShardSize = 0L;
             long actualShardSize = 0L;
             long actualShardSize = 0L;
@@ -196,23 +238,37 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
                 forecastWriteLoad += writeLoadForecaster.getForecastedWriteLoad(indexMetadata).orElse(0.0);
                 forecastWriteLoad += writeLoadForecaster.getForecastedWriteLoad(indexMetadata).orElse(0.0);
                 forecastShardSize += indexMetadata.getForecastedShardSizeInBytes().orElse(shardSize);
                 forecastShardSize += indexMetadata.getForecastedShardSizeInBytes().orElse(shardSize);
                 actualShardSize += shardSize;
                 actualShardSize += shardSize;
+                if (isDesiredShardAllocation(shardRouting, desiredBalance) == false) {
+                    undesired++;
+                }
             }
             }
 
 
             return new NodeBalanceStats(
             return new NodeBalanceStats(
                 routingNode.nodeId(),
                 routingNode.nodeId(),
                 routingNode.node().getRoles().stream().map(DiscoveryNodeRole::roleName).toList(),
                 routingNode.node().getRoles().stream().map(DiscoveryNodeRole::roleName).toList(),
                 routingNode.size(),
                 routingNode.size(),
+                undesired,
                 forecastWriteLoad,
                 forecastWriteLoad,
                 forecastShardSize,
                 forecastShardSize,
                 actualShardSize
                 actualShardSize
             );
             );
         }
         }
 
 
+        private static boolean isDesiredShardAllocation(ShardRouting shardRouting, DesiredBalance desiredBalance) {
+            if (shardRouting.relocating()) {
+                // relocating out shards are temporarily accepted
+                return true;
+            }
+            var assignment = desiredBalance.getAssignment(shardRouting.shardId());
+            return assignment != null && assignment.nodeIds().contains(shardRouting.currentNodeId());
+        }
+
         public static NodeBalanceStats readFrom(StreamInput in) throws IOException {
         public static NodeBalanceStats readFrom(StreamInput in) throws IOException {
             return new NodeBalanceStats(
             return new NodeBalanceStats(
                 in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readString() : UNKNOWN_NODE_ID,
                 in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readString() : UNKNOWN_NODE_ID,
                 in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readStringCollectionAsList() : List.of(),
                 in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readStringCollectionAsList() : List.of(),
                 in.readInt(),
                 in.readInt(),
+                in.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED) ? in.readVInt() : -1,
                 in.readDouble(),
                 in.readDouble(),
                 in.readLong(),
                 in.readLong(),
                 in.readLong()
                 in.readLong()
@@ -228,6 +284,9 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
                 out.writeStringCollection(roles);
                 out.writeStringCollection(roles);
             }
             }
             out.writeInt(shards);
             out.writeInt(shards);
+            if (out.getTransportVersion().onOrAfter(TransportVersions.UNDESIRED_SHARD_ALLOCATIONS_COUNT_ADDED)) {
+                out.writeVInt(undesiredShardAllocations);
+            }
             out.writeDouble(forecastWriteLoad);
             out.writeDouble(forecastWriteLoad);
             out.writeLong(forecastShardSize);
             out.writeLong(forecastShardSize);
             out.writeLong(actualShardSize);
             out.writeLong(actualShardSize);
@@ -241,6 +300,7 @@ public record ClusterBalanceStats(Map<String, TierBalanceStats> tiers, Map<Strin
             }
             }
             return builder.field("roles", roles)
             return builder.field("roles", roles)
                 .field("shard_count", shards)
                 .field("shard_count", shards)
+                .field("undesired_shard_allocation_count", undesiredShardAllocations)
                 .field("forecast_write_load", forecastWriteLoad)
                 .field("forecast_write_load", forecastWriteLoad)
                 .humanReadableField("forecast_disk_usage_bytes", "forecast_disk_usage", ByteSizeValue.ofBytes(forecastShardSize))
                 .humanReadableField("forecast_disk_usage_bytes", "forecast_disk_usage", ByteSizeValue.ofBytes(forecastShardSize))
                 .humanReadableField("actual_disk_usage_bytes", "actual_disk_usage", ByteSizeValue.ofBytes(actualShardSize))
                 .humanReadableField("actual_disk_usage_bytes", "actual_disk_usage", ByteSizeValue.ofBytes(actualShardSize))

+ 40 - 9
server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/DesiredBalanceResponseTests.java

@@ -65,6 +65,8 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
 
 
     private ClusterBalanceStats randomClusterBalanceStats() {
     private ClusterBalanceStats randomClusterBalanceStats() {
         return new ClusterBalanceStats(
         return new ClusterBalanceStats(
+            randomNonNegativeInt(),
+            randomNonNegativeInt(),
             randomBoolean()
             randomBoolean()
                 ? Map.of(DiscoveryNodeRole.DATA_CONTENT_NODE_ROLE.roleName(), randomTierBalanceStats())
                 ? Map.of(DiscoveryNodeRole.DATA_CONTENT_NODE_ROLE.roleName(), randomTierBalanceStats())
                 : randomSubsetOf(
                 : randomSubsetOf(
@@ -81,21 +83,27 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
 
 
     private ClusterBalanceStats.TierBalanceStats randomTierBalanceStats() {
     private ClusterBalanceStats.TierBalanceStats randomTierBalanceStats() {
         return new ClusterBalanceStats.TierBalanceStats(
         return new ClusterBalanceStats.TierBalanceStats(
-            new ClusterBalanceStats.MetricStats(randomDouble(), randomDouble(), randomDouble(), randomDouble(), randomDouble()),
-            new ClusterBalanceStats.MetricStats(randomDouble(), randomDouble(), randomDouble(), randomDouble(), randomDouble()),
-            new ClusterBalanceStats.MetricStats(randomDouble(), randomDouble(), randomDouble(), randomDouble(), randomDouble()),
-            new ClusterBalanceStats.MetricStats(randomDouble(), randomDouble(), randomDouble(), randomDouble(), randomDouble())
+            randomMetricStats(),
+            randomMetricStats(),
+            randomMetricStats(),
+            randomMetricStats(),
+            randomMetricStats()
         );
         );
     }
     }
 
 
+    private ClusterBalanceStats.MetricStats randomMetricStats() {
+        return new ClusterBalanceStats.MetricStats(randomDouble(), randomDouble(), randomDouble(), randomDouble(), randomDouble());
+    }
+
     private ClusterBalanceStats.NodeBalanceStats randomNodeBalanceStats() {
     private ClusterBalanceStats.NodeBalanceStats randomNodeBalanceStats() {
         return new ClusterBalanceStats.NodeBalanceStats(
         return new ClusterBalanceStats.NodeBalanceStats(
             randomAlphaOfLength(10),
             randomAlphaOfLength(10),
             List.of(randomFrom("data_content", "data_hot", "data_warm", "data_cold")),
             List.of(randomFrom("data_content", "data_hot", "data_warm", "data_cold")),
-            randomIntBetween(0, Integer.MAX_VALUE),
+            randomNonNegativeInt(),
+            randomNonNegativeInt(),
             randomDouble(),
             randomDouble(),
-            randomLongBetween(0, Long.MAX_VALUE),
-            randomLongBetween(0, Long.MAX_VALUE)
+            randomNonNegativeLong(),
+            randomNonNegativeLong()
         );
         );
     }
     }
 
 
@@ -203,8 +211,13 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
 
 
         // cluster balance stats
         // cluster balance stats
         Map<String, Object> clusterBalanceStats = (Map<String, Object>) json.get("cluster_balance_stats");
         Map<String, Object> clusterBalanceStats = (Map<String, Object>) json.get("cluster_balance_stats");
-        assertThat(clusterBalanceStats.keySet(), containsInAnyOrder("tiers", "nodes"));
+        assertThat(clusterBalanceStats.keySet(), containsInAnyOrder("shard_count", "undesired_shard_allocation_count", "tiers", "nodes"));
 
 
+        assertEquals(clusterBalanceStats.get("shard_count"), response.getClusterBalanceStats().shards());
+        assertEquals(
+            clusterBalanceStats.get("undesired_shard_allocation_count"),
+            response.getClusterBalanceStats().undesiredShardAllocations()
+        );
         // tier balance stats
         // tier balance stats
         Map<String, Object> tiers = (Map<String, Object>) clusterBalanceStats.get("tiers");
         Map<String, Object> tiers = (Map<String, Object>) clusterBalanceStats.get("tiers");
         assertEquals(tiers.keySet(), response.getClusterBalanceStats().tiers().keySet());
         assertEquals(tiers.keySet(), response.getClusterBalanceStats().tiers().keySet());
@@ -212,7 +225,13 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
             Map<String, Object> tierStats = (Map<String, Object>) tiers.get(entry.getKey());
             Map<String, Object> tierStats = (Map<String, Object>) tiers.get(entry.getKey());
             assertThat(
             assertThat(
                 tierStats.keySet(),
                 tierStats.keySet(),
-                containsInAnyOrder("shard_count", "forecast_write_load", "forecast_disk_usage", "actual_disk_usage")
+                containsInAnyOrder(
+                    "shard_count",
+                    "undesired_shard_allocation_count",
+                    "forecast_write_load",
+                    "forecast_disk_usage",
+                    "actual_disk_usage"
+                )
             );
             );
 
 
             Map<String, Object> shardCountStats = (Map<String, Object>) tierStats.get("shard_count");
             Map<String, Object> shardCountStats = (Map<String, Object>) tierStats.get("shard_count");
@@ -223,6 +242,16 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
             assertEquals(shardCountStats.get("max"), entry.getValue().shardCount().max());
             assertEquals(shardCountStats.get("max"), entry.getValue().shardCount().max());
             assertEquals(shardCountStats.get("std_dev"), entry.getValue().shardCount().stdDev());
             assertEquals(shardCountStats.get("std_dev"), entry.getValue().shardCount().stdDev());
 
 
+            Map<String, Object> undesiredShardAllocationCountStats = (Map<String, Object>) tierStats.get(
+                "undesired_shard_allocation_count"
+            );
+            assertThat(undesiredShardAllocationCountStats.keySet(), containsInAnyOrder("total", "average", "min", "max", "std_dev"));
+            assertEquals(undesiredShardAllocationCountStats.get("total"), entry.getValue().undesiredShardAllocations().total());
+            assertEquals(undesiredShardAllocationCountStats.get("average"), entry.getValue().undesiredShardAllocations().average());
+            assertEquals(undesiredShardAllocationCountStats.get("min"), entry.getValue().undesiredShardAllocations().min());
+            assertEquals(undesiredShardAllocationCountStats.get("max"), entry.getValue().undesiredShardAllocations().max());
+            assertEquals(undesiredShardAllocationCountStats.get("std_dev"), entry.getValue().undesiredShardAllocations().stdDev());
+
             Map<String, Object> forecastWriteLoadStats = (Map<String, Object>) tierStats.get("forecast_write_load");
             Map<String, Object> forecastWriteLoadStats = (Map<String, Object>) tierStats.get("forecast_write_load");
             assertThat(forecastWriteLoadStats.keySet(), containsInAnyOrder("total", "average", "min", "max", "std_dev"));
             assertThat(forecastWriteLoadStats.keySet(), containsInAnyOrder("total", "average", "min", "max", "std_dev"));
             assertEquals(forecastWriteLoadStats.get("total"), entry.getValue().forecastWriteLoad().total());
             assertEquals(forecastWriteLoadStats.get("total"), entry.getValue().forecastWriteLoad().total());
@@ -258,6 +287,7 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
                     "node_id",
                     "node_id",
                     "roles",
                     "roles",
                     "shard_count",
                     "shard_count",
+                    "undesired_shard_allocation_count",
                     "forecast_write_load",
                     "forecast_write_load",
                     "forecast_disk_usage_bytes",
                     "forecast_disk_usage_bytes",
                     "actual_disk_usage_bytes"
                     "actual_disk_usage_bytes"
@@ -266,6 +296,7 @@ public class DesiredBalanceResponseTests extends AbstractWireSerializingTestCase
             assertEquals(nodesStats.get("node_id"), entry.getValue().nodeId());
             assertEquals(nodesStats.get("node_id"), entry.getValue().nodeId());
             assertEquals(nodesStats.get("roles"), entry.getValue().roles());
             assertEquals(nodesStats.get("roles"), entry.getValue().roles());
             assertEquals(nodesStats.get("shard_count"), entry.getValue().shards());
             assertEquals(nodesStats.get("shard_count"), entry.getValue().shards());
+            assertEquals(nodesStats.get("undesired_shard_allocation_count"), entry.getValue().undesiredShardAllocations());
             assertEquals(nodesStats.get("forecast_write_load"), entry.getValue().forecastWriteLoad());
             assertEquals(nodesStats.get("forecast_write_load"), entry.getValue().forecastWriteLoad());
             assertEquals(nodesStats.get("forecast_disk_usage_bytes"), entry.getValue().forecastShardSize());
             assertEquals(nodesStats.get("forecast_disk_usage_bytes"), entry.getValue().forecastShardSize());
             assertEquals(nodesStats.get("actual_disk_usage_bytes"), entry.getValue().actualShardSize());
             assertEquals(nodesStats.get("actual_disk_usage_bytes"), entry.getValue().actualShardSize());

+ 63 - 19
server/src/test/java/org/elasticsearch/cluster/routing/allocation/allocator/ClusterBalanceStatsTests.java

@@ -25,6 +25,7 @@ import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardId;
 
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
@@ -58,25 +59,33 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
             List.of(indexSizes("index-1", 1L, 1L), indexSizes("index-2", 2L, 2L), indexSizes("index-3", 3L, 3L))
             List.of(indexSizes("index-1", 1L, 1L), indexSizes("index-2", 2L, 2L), indexSizes("index-3", 3L, 3L))
         );
         );
 
 
-        var stats = ClusterBalanceStats.createFrom(clusterState, clusterInfo, TEST_WRITE_LOAD_FORECASTER);
+        var stats = ClusterBalanceStats.createFrom(
+            clusterState,
+            createDesiredBalance(clusterState),
+            clusterInfo,
+            TEST_WRITE_LOAD_FORECASTER
+        );
 
 
         assertThat(
         assertThat(
             stats,
             stats,
             equalTo(
             equalTo(
                 new ClusterBalanceStats(
                 new ClusterBalanceStats(
+                    6,
+                    0,
                     Map.of(
                     Map.of(
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
                             new MetricStats(6.0, 2.0, 2.0, 2.0, 0.0),
                             new MetricStats(6.0, 2.0, 2.0, 2.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0)),
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0)),
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0))
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0))
                         )
                         )
                     ),
                     ),
                     Map.ofEntries(
                     Map.ofEntries(
-                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0.0, 4L, 4L)),
-                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0.0, 3L, 3L)),
-                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0.0, 5L, 5L))
+                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 0.0, 4L, 4L)),
+                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 0.0, 3L, 3L)),
+                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 0.0, 5L, 5L))
                     )
                     )
                 )
                 )
             )
             )
@@ -102,25 +111,33 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
             List.of(indexSizes("index-1", 1L, 1L), indexSizes("index-2", 2L, 2L), indexSizes("index-3", 3L, 3L))
             List.of(indexSizes("index-1", 1L, 1L), indexSizes("index-2", 2L, 2L), indexSizes("index-3", 3L, 3L))
         );
         );
 
 
-        var stats = ClusterBalanceStats.createFrom(clusterState, clusterInfo, TEST_WRITE_LOAD_FORECASTER);
+        var stats = ClusterBalanceStats.createFrom(
+            clusterState,
+            createDesiredBalance(clusterState),
+            clusterInfo,
+            TEST_WRITE_LOAD_FORECASTER
+        );
 
 
         assertThat(
         assertThat(
             stats,
             stats,
             equalTo(
             equalTo(
                 new ClusterBalanceStats(
                 new ClusterBalanceStats(
+                    6,
+                    0,
                     Map.of(
                     Map.of(
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
                             new MetricStats(6.0, 2.0, 2.0, 2.0, 0.0),
                             new MetricStats(6.0, 2.0, 2.0, 2.0, 0.0),
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(12.0, 3.5, 4.5, 4.0, stdDev(3.5, 4.0, 4.5)),
                             new MetricStats(12.0, 3.5, 4.5, 4.0, stdDev(3.5, 4.0, 4.5)),
                             new MetricStats(36.0, 10.0, 14.0, 12.0, stdDev(10.0, 12.0, 14.0)),
                             new MetricStats(36.0, 10.0, 14.0, 12.0, stdDev(10.0, 12.0, 14.0)),
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0))
                             new MetricStats(12.0, 3.0, 5.0, 4.0, stdDev(3.0, 5.0, 4.0))
                         )
                         )
                     ),
                     ),
                     Map.ofEntries(
                     Map.ofEntries(
-                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 3.5, 14L, 4L)),
-                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 4.0, 12L, 3L)),
-                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 4.5, 10L, 5L))
+                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 3.5, 14L, 4L)),
+                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 4.0, 12L, 3L)),
+                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 2, 0, 4.5, 10L, 5L))
                     )
                     )
                 )
                 )
             )
             )
@@ -157,7 +174,12 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
             )
             )
         );
         );
 
 
-        var stats = ClusterBalanceStats.createFrom(clusterState, clusterInfo, TEST_WRITE_LOAD_FORECASTER);
+        var stats = ClusterBalanceStats.createFrom(
+            clusterState,
+            createDesiredBalance(clusterState),
+            clusterInfo,
+            TEST_WRITE_LOAD_FORECASTER
+        );
 
 
         var hotRoleNames = List.of(DATA_CONTENT_NODE_ROLE.roleName(), DATA_HOT_NODE_ROLE.roleName());
         var hotRoleNames = List.of(DATA_CONTENT_NODE_ROLE.roleName(), DATA_HOT_NODE_ROLE.roleName());
         var warmRoleNames = List.of(DATA_WARM_NODE_ROLE.roleName());
         var warmRoleNames = List.of(DATA_WARM_NODE_ROLE.roleName());
@@ -165,10 +187,13 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
             stats,
             stats,
             equalTo(
             equalTo(
                 new ClusterBalanceStats(
                 new ClusterBalanceStats(
+                    10,
+                    0,
                     Map.of(
                     Map.of(
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
                             new MetricStats(7.0, 2.0, 3.0, 7.0 / 3, stdDev(3.0, 2.0, 2.0)),
                             new MetricStats(7.0, 2.0, 3.0, 7.0 / 3, stdDev(3.0, 2.0, 2.0)),
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(21.0, 6.0, 8.5, 7.0, stdDev(6.0, 8.5, 6.5)),
                             new MetricStats(21.0, 6.0, 8.5, 7.0, stdDev(6.0, 8.5, 6.5)),
                             new MetricStats(36.0, 10.0, 16.0, 12.0, stdDev(10.0, 10.0, 16.0)),
                             new MetricStats(36.0, 10.0, 16.0, 12.0, stdDev(10.0, 10.0, 16.0)),
                             new MetricStats(34.0, 9.0, 15.0, 34.0 / 3, stdDev(9.0, 10.0, 15.0))
                             new MetricStats(34.0, 9.0, 15.0, 34.0 / 3, stdDev(9.0, 10.0, 15.0))
@@ -176,6 +201,7 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
                         DATA_HOT_NODE_ROLE.roleName(),
                         DATA_HOT_NODE_ROLE.roleName(),
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
                             new MetricStats(7.0, 2.0, 3.0, 7.0 / 3, stdDev(3.0, 2.0, 2.0)),
                             new MetricStats(7.0, 2.0, 3.0, 7.0 / 3, stdDev(3.0, 2.0, 2.0)),
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(21.0, 6.0, 8.5, 7.0, stdDev(6.0, 8.5, 6.5)),
                             new MetricStats(21.0, 6.0, 8.5, 7.0, stdDev(6.0, 8.5, 6.5)),
                             new MetricStats(36.0, 10.0, 16.0, 12.0, stdDev(10.0, 10.0, 16.0)),
                             new MetricStats(36.0, 10.0, 16.0, 12.0, stdDev(10.0, 10.0, 16.0)),
                             new MetricStats(34.0, 9.0, 15.0, 34.0 / 3, stdDev(9.0, 10.0, 15.0))
                             new MetricStats(34.0, 9.0, 15.0, 34.0 / 3, stdDev(9.0, 10.0, 15.0))
@@ -184,17 +210,18 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
                             new MetricStats(3.0, 1.0, 1.0, 1.0, 0.0),
                             new MetricStats(3.0, 1.0, 1.0, 1.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(42.0, 12.0, 18.0, 14.0, stdDev(12.0, 12.0, 18.0)),
                             new MetricStats(42.0, 12.0, 18.0, 14.0, stdDev(12.0, 12.0, 18.0)),
                             new MetricStats(42.0, 12.0, 18.0, 14.0, stdDev(12.0, 12.0, 18.0))
                             new MetricStats(42.0, 12.0, 18.0, 14.0, stdDev(12.0, 12.0, 18.0))
                         )
                         )
                     ),
                     ),
                     Map.ofEntries(
                     Map.ofEntries(
-                        Map.entry("node-hot-1", new NodeBalanceStats("node-hot-1", hotRoleNames, 3, 8.5, 16L, 15L)),
-                        Map.entry("node-hot-2", new NodeBalanceStats("node-hot-2", hotRoleNames, 2, 6.0, 10L, 9L)),
-                        Map.entry("node-hot-3", new NodeBalanceStats("node-hot-3", hotRoleNames, 2, 6.5, 10L, 10L)),
-                        Map.entry("node-warm-1", new NodeBalanceStats("node-warm-1", warmRoleNames, 1, 0.0, 12L, 12L)),
-                        Map.entry("node-warm-2", new NodeBalanceStats("node-warm-2", warmRoleNames, 1, 0.0, 12L, 12L)),
-                        Map.entry("node-warm-3", new NodeBalanceStats("node-warm-3", warmRoleNames, 1, 0.0, 18L, 18L))
+                        Map.entry("node-hot-1", new NodeBalanceStats("node-hot-1", hotRoleNames, 3, 0, 8.5, 16L, 15L)),
+                        Map.entry("node-hot-2", new NodeBalanceStats("node-hot-2", hotRoleNames, 2, 0, 6.0, 10L, 9L)),
+                        Map.entry("node-hot-3", new NodeBalanceStats("node-hot-3", hotRoleNames, 2, 0, 6.5, 10L, 10L)),
+                        Map.entry("node-warm-1", new NodeBalanceStats("node-warm-1", warmRoleNames, 1, 0, 0.0, 12L, 12L)),
+                        Map.entry("node-warm-2", new NodeBalanceStats("node-warm-2", warmRoleNames, 1, 0, 0.0, 12L, 12L)),
+                        Map.entry("node-warm-3", new NodeBalanceStats("node-warm-3", warmRoleNames, 1, 0, 0.0, 18L, 18L))
                     )
                     )
                 )
                 )
             )
             )
@@ -213,15 +240,18 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
         );
         );
         var clusterInfo = createClusterInfo(List.of());
         var clusterInfo = createClusterInfo(List.of());
 
 
-        var stats = ClusterBalanceStats.createFrom(clusterState, clusterInfo, TEST_WRITE_LOAD_FORECASTER);
+        var stats = ClusterBalanceStats.createFrom(clusterState, null, clusterInfo, TEST_WRITE_LOAD_FORECASTER);
 
 
         assertThat(
         assertThat(
             stats,
             stats,
             equalTo(
             equalTo(
                 new ClusterBalanceStats(
                 new ClusterBalanceStats(
+                    0,
+                    0,
                     Map.of(
                     Map.of(
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         DATA_CONTENT_NODE_ROLE.roleName(),
                         new ClusterBalanceStats.TierBalanceStats(
                         new ClusterBalanceStats.TierBalanceStats(
+                            new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
                             new MetricStats(0.0, 0.0, 0.0, 0.0, 0.0),
@@ -229,9 +259,9 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
                         )
                         )
                     ),
                     ),
                     Map.ofEntries(
                     Map.ofEntries(
-                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0.0, 0L, 0L)),
-                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0.0, 0L, 0L)),
-                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0.0, 0L, 0L))
+                        Map.entry("node-1", new NodeBalanceStats("node-1", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0, 0.0, 0L, 0L)),
+                        Map.entry("node-2", new NodeBalanceStats("node-2", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0, 0.0, 0L, 0L)),
+                        Map.entry("node-3", new NodeBalanceStats("node-3", List.of(DATA_CONTENT_NODE_ROLE.roleName()), 0, 0, 0.0, 0L, 0L))
                     )
                     )
                 )
                 )
             )
             )
@@ -269,6 +299,20 @@ public class ClusterBalanceStatsTests extends ESAllocationTestCase {
             .build();
             .build();
     }
     }
 
 
+    private static DesiredBalance createDesiredBalance(ClusterState state) {
+        var assignments = new HashMap<ShardId, ShardAssignment>();
+        for (var indexRoutingTable : state.getRoutingTable()) {
+            for (int i = 0; i < indexRoutingTable.size(); i++) {
+                var indexShardRoutingTable = indexRoutingTable.shard(i);
+                assignments.put(
+                    indexShardRoutingTable.shardId(),
+                    new ShardAssignment(Set.of(indexShardRoutingTable.primaryShard().currentNodeId()), 1, 0, 0)
+                );
+            }
+        }
+        return new DesiredBalance(1, assignments);
+    }
+
     private static Tuple<IndexMetadata.Builder, String[]> startedIndex(
     private static Tuple<IndexMetadata.Builder, String[]> startedIndex(
         String indexName,
         String indexName,
         @Nullable Double indexWriteLoadForecast,
         @Nullable Double indexWriteLoadForecast,