Browse Source

Forecast average shard size during rollovers (#91561)

Francisco Fernández Castaño 2 years ago
parent
commit
d891b1fb1b
15 changed files with 573 additions and 192 deletions
  1. 5 0
      docs/changelog/91561.yaml
  2. 124 9
      modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java
  3. 8 7
      modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java
  4. 40 7
      server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java
  5. 4 4
      server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
  6. 80 44
      server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java
  7. 172 0
      server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataStats.java
  8. 1 36
      server/src/main/java/org/elasticsearch/cluster/metadata/IndexWriteLoad.java
  9. 38 19
      server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java
  10. 18 16
      server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataStatsSerializationTests.java
  11. 19 39
      server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataStatsTests.java
  12. 10 7
      server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java
  13. 46 0
      server/src/test/java/org/elasticsearch/cluster/metadata/IndexWriteLoadTests.java
  14. 5 2
      x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java
  15. 3 2
      x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java

+ 5 - 0
docs/changelog/91561.yaml

@@ -0,0 +1,5 @@
+pr: 91561
+summary: Forecast average shard size during rollovers
+area: Allocation
+type: enhancement
+issues: []

+ 124 - 9
modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java

@@ -63,6 +63,8 @@ import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.DataStreamAction;
 import org.elasticsearch.cluster.metadata.DataStreamAlias;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.Template;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
@@ -78,7 +80,6 @@ import org.elasticsearch.index.mapper.DataStreamTimestampFieldMapper;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.mapper.MapperParsingException;
 import org.elasticsearch.index.query.TermQueryBuilder;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.index.shard.IndexingStats;
 import org.elasticsearch.indices.InvalidAliasNameException;
 import org.elasticsearch.indices.InvalidIndexNameException;
@@ -103,6 +104,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalLong;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.CyclicBarrier;
@@ -2010,7 +2012,7 @@ public class DataStreamIT extends ESIntegTestCase {
         assertEquals(searchResponse.getTotalShards(), 4);
     }
 
-    public void testWriteIndexWriteLoadIsStoredAfterRollover() throws Exception {
+    public void testWriteIndexWriteLoadAndAvgShardSizeIsStoredAfterRollover() throws Exception {
         final String dataStreamName = "logs-es";
         final int numberOfShards = randomIntBetween(1, 5);
         final int numberOfReplicas = randomIntBetween(0, 1);
@@ -2047,21 +2049,26 @@ public class DataStreamIT extends ESIntegTestCase {
 
         for (Index index : dataStream.getIndices()) {
             final IndexMetadata indexMetadata = clusterState.metadata().index(index);
-            final IndexWriteLoad indexWriteLoad = indexMetadata.getWriteLoad();
+            final IndexMetadataStats metadataStats = indexMetadata.getStats();
 
             if (index.equals(dataStream.getWriteIndex()) == false) {
-                assertThat(indexWriteLoad, is(notNullValue()));
+                assertThat(metadataStats, is(notNullValue()));
+
+                final var averageShardSize = metadataStats.averageShardSize();
+                assertThat(averageShardSize.getAverageSizeInBytes(), is(greaterThan(0L)));
+
+                final IndexWriteLoad indexWriteLoad = metadataStats.writeLoad();
                 for (int shardId = 0; shardId < numberOfShards; shardId++) {
                     assertThat(indexWriteLoad.getWriteLoadForShard(shardId).getAsDouble(), is(greaterThanOrEqualTo(0.0)));
                     assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).getAsLong(), is(greaterThan(0L)));
                 }
             } else {
-                assertThat(indexWriteLoad, is(nullValue()));
+                assertThat(metadataStats, is(nullValue()));
             }
         }
     }
 
-    public void testWriteLoadIsStoredInABestEffort() throws Exception {
+    public void testWriteLoadAndAvgShardSizeIsStoredInABestEffort() throws Exception {
         // This test simulates the scenario where some nodes fail to respond
         // to the IndicesStatsRequest and therefore only a partial view of the
         // write-index write-load is stored during rollover.
@@ -2115,10 +2122,12 @@ public class DataStreamIT extends ESIntegTestCase {
 
         for (Index index : dataStream.getIndices()) {
             final IndexMetadata indexMetadata = clusterState.metadata().index(index);
-            final IndexWriteLoad indexWriteLoad = indexMetadata.getWriteLoad();
+            final IndexMetadataStats metadataStats = indexMetadata.getStats();
 
             if (index.equals(dataStream.getWriteIndex()) == false) {
-                assertThat(indexWriteLoad, is(notNullValue()));
+                assertThat(metadataStats, is(notNullValue()));
+
+                final IndexWriteLoad indexWriteLoad = metadataStats.writeLoad();
                 // All stats request performed against nodes holding the shard 0 failed
                 assertThat(indexWriteLoad.getWriteLoadForShard(0).isPresent(), is(false));
                 assertThat(indexWriteLoad.getUptimeInMillisForShard(0).isPresent(), is(false));
@@ -2126,10 +2135,116 @@ public class DataStreamIT extends ESIntegTestCase {
                 // At least one of the shard 1 copies responded with stats
                 assertThat(indexWriteLoad.getWriteLoadForShard(1).getAsDouble(), is(greaterThanOrEqualTo(0.0)));
                 assertThat(indexWriteLoad.getUptimeInMillisForShard(1).getAsLong(), is(greaterThan(0L)));
+
+                final var averageShardSize = metadataStats.averageShardSize();
+                assertThat(averageShardSize.numberOfShards(), is(equalTo(1)));
+
+                assertThat(averageShardSize.getAverageSizeInBytes(), is(greaterThan(0L)));
             } else {
-                assertThat(indexWriteLoad, is(nullValue()));
+                assertThat(metadataStats, is(nullValue()));
+            }
+        }
+    }
+
+    public void testNoShardSizeIsForecastedWhenAllShardStatRequestsFail() throws Exception {
+        final String dataOnlyNode = internalCluster().startDataOnlyNode();
+        final String dataStreamName = "logs-es";
+
+        final var indexSettings = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put("index.routing.allocation.require._name", dataOnlyNode)
+            .build();
+        DataStreamIT.putComposableIndexTemplate("my-template", null, List.of("logs-*"), indexSettings, null);
+        final var createDataStreamRequest = new CreateDataStreamAction.Request(dataStreamName);
+        assertAcked(client().execute(CreateDataStreamAction.INSTANCE, createDataStreamRequest).actionGet());
+
+        for (int i = 0; i < 10; i++) {
+            indexDocs(dataStreamName, randomIntBetween(100, 200));
+        }
+
+        final ClusterState clusterStateBeforeRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
+        final DataStream dataStreamBeforeRollover = clusterStateBeforeRollover.getMetadata().dataStreams().get(dataStreamName);
+        final String assignedShardNodeId = clusterStateBeforeRollover.routingTable()
+            .index(dataStreamBeforeRollover.getWriteIndex())
+            .shard(0)
+            .primaryShard()
+            .currentNodeId();
+
+        final String nodeName = clusterStateBeforeRollover.nodes().resolveNode(assignedShardNodeId).getName();
+        final MockTransportService transportService = (MockTransportService) internalCluster().getInstance(
+            TransportService.class,
+            nodeName
+        );
+        transportService.addRequestHandlingBehavior(
+            IndicesStatsAction.NAME + "[n]",
+            (handler, request, channel, task) -> channel.sendResponse(new RuntimeException("Unable to get stats"))
+        );
+
+        assertAcked(client().admin().indices().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());
+
+        final ClusterState clusterState = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
+        final DataStream dataStream = clusterState.getMetadata().dataStreams().get(dataStreamName);
+        final IndexMetadata currentWriteIndexMetadata = clusterState.metadata().getIndexSafe(dataStream.getWriteIndex());
+
+        // When all shard stats request fail, we cannot forecast the shard size
+        assertThat(currentWriteIndexMetadata.getForecastedShardSizeInBytes().isEmpty(), is(equalTo(true)));
+    }
+
+    public void testShardSizeIsForecastedDuringRollover() throws Exception {
+        final String dataStreamName = "logs-es";
+        final int numberOfShards = randomIntBetween(1, 5);
+        final int numberOfReplicas = randomIntBetween(0, 1);
+        final var indexSettings = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numberOfReplicas)
+            .build();
+        DataStreamIT.putComposableIndexTemplate("my-template", null, List.of("logs-*"), indexSettings, null);
+        final var request = new CreateDataStreamAction.Request(dataStreamName);
+        assertAcked(client().execute(CreateDataStreamAction.INSTANCE, request).actionGet());
+
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 10; j++) {
+                indexDocs(dataStreamName, randomIntBetween(100, 200));
             }
+
+            // Ensure that we get a stable size to compare against the expected size
+            assertThat(
+                client().admin().indices().prepareForceMerge().setFlush(true).setMaxNumSegments(1).get().getSuccessfulShards(),
+                is(greaterThanOrEqualTo(numberOfShards))
+            );
+
+            assertAcked(client().admin().indices().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());
         }
+
+        final ClusterState clusterState = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
+        final DataStream dataStream = clusterState.getMetadata().dataStreams().get(dataStreamName);
+
+        final List<String> dataStreamReadIndices = dataStream.getIndices()
+            .stream()
+            .filter(index -> index.equals(dataStream.getWriteIndex()) == false)
+            .map(Index::getName)
+            .toList();
+
+        final IndicesStatsResponse indicesStatsResponse = client().admin()
+            .indices()
+            .prepareStats(dataStreamReadIndices.toArray(new String[dataStreamReadIndices.size()]))
+            .setStore(true)
+            .get();
+        long expectedTotalSizeInBytes = 0;
+        int shardCount = 0;
+        for (ShardStats shard : indicesStatsResponse.getShards()) {
+            if (shard.getShardRouting().primary() == false) {
+                continue;
+            }
+            expectedTotalSizeInBytes += shard.getStats().getDocs().getTotalSizeInBytes();
+            shardCount++;
+        }
+
+        final IndexMetadata writeIndexMetadata = clusterState.metadata().index(dataStream.getWriteIndex());
+        final OptionalLong forecastedShardSizeInBytes = writeIndexMetadata.getForecastedShardSizeInBytes();
+        assertThat(forecastedShardSizeInBytes.isPresent(), is(equalTo(true)));
+        assertThat(forecastedShardSizeInBytes.getAsLong(), is(equalTo(expectedTotalSizeInBytes / shardCount)));
     }
 
     static void putComposableIndexTemplate(

+ 8 - 7
modules/data-streams/src/test/java/org/elasticsearch/datastreams/MetadataDataStreamRolloverServiceTests.java

@@ -20,6 +20,8 @@ import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.DataStreamTestHelper;
 import org.elasticsearch.cluster.metadata.IndexAbstraction;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.Template;
 import org.elasticsearch.common.settings.Settings;
@@ -27,7 +29,6 @@ import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexMode;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.MapperTestUtils;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -102,7 +103,7 @@ public class MetadataDataStreamRolloverServiceTests extends ESTestCase {
             MaxDocsCondition condition = new MaxDocsCondition(randomNonNegativeLong());
             List<Condition<?>> metConditions = Collections.singletonList(condition);
             CreateIndexRequest createIndexRequest = new CreateIndexRequest("_na_");
-            IndexWriteLoad indexWriteLoad = IndexWriteLoad.builder(1).build();
+            IndexMetadataStats indexStats = new IndexMetadataStats(IndexWriteLoad.builder(1).build(), 10, 10);
 
             long before = testThreadPool.absoluteTimeInMillis();
             MetadataRolloverService.RolloverResult rolloverResult = rolloverService.rolloverClusterState(
@@ -114,7 +115,7 @@ public class MetadataDataStreamRolloverServiceTests extends ESTestCase {
                 now,
                 randomBoolean(),
                 false,
-                indexWriteLoad
+                indexStats
             );
             long after = testThreadPool.absoluteTimeInMillis();
 
@@ -142,16 +143,16 @@ public class MetadataDataStreamRolloverServiceTests extends ESTestCase {
             IndexMetadata im = rolloverMetadata.index(rolloverMetadata.dataStreams().get(dataStreamName).getIndices().get(0));
             Instant startTime1 = IndexSettings.TIME_SERIES_START_TIME.get(im.getSettings());
             Instant endTime1 = IndexSettings.TIME_SERIES_END_TIME.get(im.getSettings());
-            IndexWriteLoad indexWriteLoad1 = im.getWriteLoad();
+            IndexMetadataStats indexStats1 = im.getStats();
             im = rolloverMetadata.index(rolloverMetadata.dataStreams().get(dataStreamName).getIndices().get(1));
             Instant startTime2 = IndexSettings.TIME_SERIES_START_TIME.get(im.getSettings());
             Instant endTime2 = IndexSettings.TIME_SERIES_END_TIME.get(im.getSettings());
-            IndexWriteLoad indexWriteLoad2 = im.getWriteLoad();
+            IndexMetadataStats indexStats2 = im.getStats();
             assertThat(startTime1.isBefore(endTime1), is(true));
             assertThat(endTime1, equalTo(startTime2));
             assertThat(endTime2.isAfter(endTime1), is(true));
-            assertThat(indexWriteLoad1, is(equalTo(indexWriteLoad)));
-            assertThat(indexWriteLoad2, is(nullValue()));
+            assertThat(indexStats1, is(equalTo(indexStats)));
+            assertThat(indexStats2, is(nullValue()));
         } finally {
             testThreadPool.shutdown();
         }

+ 40 - 7
server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java

@@ -18,6 +18,7 @@ import org.elasticsearch.cluster.metadata.ComposableIndexTemplate;
 import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.IndexAbstraction;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
@@ -31,7 +32,6 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.Index;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.indices.SystemDataStreamDescriptor;
 import org.elasticsearch.indices.SystemIndices;
 import org.elasticsearch.snapshots.SnapshotInProgressException;
@@ -43,6 +43,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.regex.Pattern;
 
 import static org.elasticsearch.cluster.metadata.IndexAbstraction.Type.ALIAS;
@@ -105,7 +106,7 @@ public class MetadataRolloverService {
         Instant now,
         boolean silent,
         boolean onlyValidate,
-        @Nullable IndexWriteLoad sourceIndexWriteLoad
+        @Nullable IndexMetadataStats sourceIndexStats
     ) throws Exception {
         validate(currentState.metadata(), rolloverTarget, newIndexName, createIndexRequest);
         final IndexAbstraction indexAbstraction = currentState.metadata().getIndicesLookup().get(rolloverTarget);
@@ -129,7 +130,7 @@ public class MetadataRolloverService {
                 now,
                 silent,
                 onlyValidate,
-                sourceIndexWriteLoad
+                sourceIndexStats
             );
             default ->
                 // the validate method above prevents this case
@@ -240,7 +241,7 @@ public class MetadataRolloverService {
         Instant now,
         boolean silent,
         boolean onlyValidate,
-        @Nullable IndexWriteLoad sourceIndexWriteLoad
+        @Nullable IndexMetadataStats sourceIndexStats
     ) throws Exception {
 
         if (SnapshotsService.snapshottingDataStreams(currentState, Collections.singleton(dataStream.getName())).isEmpty() == false) {
@@ -302,18 +303,50 @@ public class MetadataRolloverService {
 
         Metadata.Builder metadataBuilder = Metadata.builder(newState.metadata())
             .put(
-                IndexMetadata.builder(newState.metadata().index(originalWriteIndex))
-                    .indexWriteLoad(sourceIndexWriteLoad)
-                    .putRolloverInfo(rolloverInfo)
+                IndexMetadata.builder(newState.metadata().index(originalWriteIndex)).stats(sourceIndexStats).putRolloverInfo(rolloverInfo)
             );
 
         metadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex(dataStreamName, metadataBuilder);
+        metadataBuilder = withShardSizeForecastForWriteIndex(dataStreamName, metadataBuilder);
 
         newState = ClusterState.builder(newState).metadata(metadataBuilder).build();
 
         return new RolloverResult(newWriteIndexName, originalWriteIndex.getName(), newState);
     }
 
+    public Metadata.Builder withShardSizeForecastForWriteIndex(String dataStreamName, Metadata.Builder metadata) {
+        final DataStream dataStream = metadata.dataStream(dataStreamName);
+
+        if (dataStream == null) {
+            return metadata;
+        }
+
+        final List<IndexMetadataStats> indicesStats = dataStream.getIndices()
+            .stream()
+            .map(metadata::getSafe)
+            .map(IndexMetadata::getStats)
+            .filter(Objects::nonNull)
+            .toList();
+
+        long totalSizeInBytes = 0;
+        int shardCount = 0;
+        for (IndexMetadataStats stats : indicesStats) {
+            var averageShardSize = stats.averageShardSize();
+            totalSizeInBytes += averageShardSize.totalSizeInBytes();
+            shardCount += averageShardSize.numberOfShards();
+        }
+
+        if (shardCount == 0) {
+            return metadata;
+        }
+
+        long shardSizeInBytesForecast = totalSizeInBytes / shardCount;
+        final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex());
+        metadata.put(IndexMetadata.builder(writeIndex).shardSizeInBytesForecast(shardSizeInBytesForecast).build(), false);
+
+        return metadata;
+    }
+
     static String generateRolloverIndexName(String sourceIndexName) {
         String resolvedName = IndexNameExpressionResolver.resolveDateMathExpression(sourceIndexName);
         final boolean isDateMath = sourceIndexName.equals(resolvedName) == false;

+ 4 - 4
server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java

@@ -29,6 +29,7 @@ import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
 import org.elasticsearch.cluster.metadata.IndexAbstraction;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.routing.allocation.AllocationService;
@@ -40,7 +41,6 @@ import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.shard.DocsStats;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -320,8 +320,8 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
                     .getIndicesLookup()
                     .get(rolloverRequest.getRolloverTarget());
 
-                final IndexWriteLoad sourceIndexWriteLoad = rolloverTargetAbstraction.getType() == IndexAbstraction.Type.DATA_STREAM
-                    ? IndexWriteLoad.fromStats(rolloverSourceIndex, rolloverTask.statsResponse())
+                final IndexMetadataStats sourceIndexStats = rolloverTargetAbstraction.getType() == IndexAbstraction.Type.DATA_STREAM
+                    ? IndexMetadataStats.fromStatsResponse(rolloverSourceIndex, rolloverTask.statsResponse())
                     : null;
 
                 // Perform the actual rollover
@@ -334,7 +334,7 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
                     Instant.now(),
                     false,
                     false,
-                    sourceIndexWriteLoad
+                    sourceIndexStats
                 );
                 results.add(rolloverResult);
                 logger.trace("rollover result [{}]", rolloverResult);

+ 80 - 44
server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadata.java

@@ -43,7 +43,6 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.seqno.SequenceNumbers;
 import org.elasticsearch.index.shard.IndexLongFieldRange;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.index.shard.ShardLongFieldRange;
 import org.elasticsearch.rest.RestStatus;
@@ -70,6 +69,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.OptionalDouble;
+import java.util.OptionalLong;
 import java.util.Set;
 import java.util.function.Function;
 
@@ -522,15 +522,17 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
     static final String KEY_SYSTEM = "system";
     static final String KEY_TIMESTAMP_RANGE = "timestamp_range";
     public static final String KEY_PRIMARY_TERMS = "primary_terms";
-    public static final String KEY_WRITE_LOAD = "write_load";
+    public static final String KEY_STATS = "stats";
 
     public static final String KEY_WRITE_LOAD_FORECAST = "write_load_forecast";
 
+    public static final String KEY_SHARD_SIZE_FORECAST = "shard_size_forecast";
+
     public static final String INDEX_STATE_FILE_PREFIX = "state-";
 
     static final Version SYSTEM_INDEX_FLAG_ADDED = Version.V_7_10_0;
 
-    static final Version WRITE_LOAD_ADDED = Version.V_8_6_0;
+    static final Version STATS_AND_FORECAST_ADDED = Version.V_8_6_0;
 
     private final int routingNumShards;
     private final int routingFactor;
@@ -610,9 +612,11 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
     @Nullable
     private final Instant timeSeriesEnd;
     @Nullable
-    private final IndexWriteLoad writeLoad;
+    private final IndexMetadataStats stats;
     @Nullable
     private final Double writeLoadForecast;
+    @Nullable
+    private final Long shardSizeInBytesForecast;
 
     private IndexMetadata(
         final Index index,
@@ -656,8 +660,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
         @Nullable final Instant timeSeriesStart,
         @Nullable final Instant timeSeriesEnd,
         final Version indexCompatibilityVersion,
-        @Nullable final IndexWriteLoad writeLoad,
-        @Nullable final Double writeLoadForecast
+        @Nullable final IndexMetadataStats stats,
+        @Nullable final Double writeLoadForecast,
+        @Nullable Long shardSizeInBytesForecast
     ) {
         this.index = index;
         this.version = version;
@@ -708,8 +713,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
         this.indexMode = indexMode;
         this.timeSeriesStart = timeSeriesStart;
         this.timeSeriesEnd = timeSeriesEnd;
-        this.writeLoad = writeLoad;
+        this.stats = stats;
         this.writeLoadForecast = writeLoadForecast;
+        this.shardSizeInBytesForecast = shardSizeInBytesForecast;
         assert numberOfShards * routingFactor == routingNumShards : routingNumShards + " must be a multiple of " + numberOfShards;
     }
 
@@ -759,8 +765,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.timeSeriesStart,
             this.timeSeriesEnd,
             this.indexCompatibilityVersion,
-            this.writeLoad,
-            this.writeLoadForecast
+            this.stats,
+            this.writeLoadForecast,
+            this.shardSizeInBytesForecast
         );
     }
 
@@ -816,8 +823,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.timeSeriesStart,
             this.timeSeriesEnd,
             this.indexCompatibilityVersion,
-            this.writeLoad,
-            this.writeLoadForecast
+            this.stats,
+            this.writeLoadForecast,
+            this.shardSizeInBytesForecast
         );
     }
 
@@ -871,8 +879,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.timeSeriesStart,
             this.timeSeriesEnd,
             this.indexCompatibilityVersion,
-            this.writeLoad,
-            this.writeLoadForecast
+            this.stats,
+            this.writeLoadForecast,
+            this.shardSizeInBytesForecast
         );
     }
 
@@ -926,8 +935,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.timeSeriesStart,
             this.timeSeriesEnd,
             this.indexCompatibilityVersion,
-            this.writeLoad,
-            this.writeLoadForecast
+            this.stats,
+            this.writeLoadForecast,
+            this.shardSizeInBytesForecast
         );
     }
 
@@ -977,8 +987,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.timeSeriesStart,
             this.timeSeriesEnd,
             this.indexCompatibilityVersion,
-            this.writeLoad,
-            this.writeLoadForecast
+            this.stats,
+            this.writeLoadForecast,
+            this.shardSizeInBytesForecast
         );
     }
 
@@ -1170,14 +1181,18 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
     }
 
     @Nullable
-    public IndexWriteLoad getWriteLoad() {
-        return writeLoad;
+    public IndexMetadataStats getStats() {
+        return stats;
     }
 
     public OptionalDouble getForecastedWriteLoad() {
         return writeLoadForecast == null ? OptionalDouble.empty() : OptionalDouble.of(writeLoadForecast);
     }
 
+    public OptionalLong getForecastedShardSizeInBytes() {
+        return shardSizeInBytesForecast == null ? OptionalLong.empty() : OptionalLong.of(shardSizeInBytesForecast);
+    }
+
     public static final String INDEX_RESIZE_SOURCE_UUID_KEY = "index.resize.source.uuid";
     public static final String INDEX_RESIZE_SOURCE_NAME_KEY = "index.resize.source.name";
     public static final Setting<String> INDEX_RESIZE_SOURCE_UUID = Setting.simpleString(INDEX_RESIZE_SOURCE_UUID_KEY);
@@ -1412,8 +1427,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
         private final Diff<ImmutableOpenMap<String, RolloverInfo>> rolloverInfos;
         private final boolean isSystem;
         private final IndexLongFieldRange timestampRange;
-        private final IndexWriteLoad indexWriteLoad;
+        private final IndexMetadataStats stats;
         private final Double indexWriteLoadForecast;
+        private final Long shardSizeInBytesForecast;
 
         IndexMetadataDiff(IndexMetadata before, IndexMetadata after) {
             index = after.index.getName();
@@ -1447,8 +1463,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             rolloverInfos = DiffableUtils.diff(before.rolloverInfos, after.rolloverInfos, DiffableUtils.getStringKeySerializer());
             isSystem = after.isSystem;
             timestampRange = after.timestampRange;
-            indexWriteLoad = after.writeLoad;
+            stats = after.stats;
             indexWriteLoadForecast = after.writeLoadForecast;
+            shardSizeInBytesForecast = after.shardSizeInBytesForecast;
         }
 
         private static final DiffableUtils.DiffableValueReader<String, AliasMetadata> ALIAS_METADATA_DIFF_VALUE_READER =
@@ -1499,12 +1516,14 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                 isSystem = false;
             }
             timestampRange = IndexLongFieldRange.readFrom(in);
-            if (in.getVersion().onOrAfter(WRITE_LOAD_ADDED)) {
-                indexWriteLoad = in.readOptionalWriteable(IndexWriteLoad::new);
+            if (in.getVersion().onOrAfter(STATS_AND_FORECAST_ADDED)) {
+                stats = in.readOptionalWriteable(IndexMetadataStats::new);
                 indexWriteLoadForecast = in.readOptionalDouble();
+                shardSizeInBytesForecast = in.readOptionalLong();
             } else {
-                indexWriteLoad = null;
+                stats = null;
                 indexWriteLoadForecast = null;
+                shardSizeInBytesForecast = null;
             }
         }
 
@@ -1536,9 +1555,10 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                 out.writeBoolean(isSystem);
             }
             timestampRange.writeTo(out);
-            if (out.getVersion().onOrAfter(WRITE_LOAD_ADDED)) {
-                out.writeOptionalWriteable(indexWriteLoad);
+            if (out.getVersion().onOrAfter(STATS_AND_FORECAST_ADDED)) {
+                out.writeOptionalWriteable(stats);
                 out.writeOptionalDouble(indexWriteLoadForecast);
+                out.writeOptionalLong(shardSizeInBytesForecast);
             }
         }
 
@@ -1566,8 +1586,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             builder.rolloverInfos.putAllFromMap(rolloverInfos.apply(part.rolloverInfos));
             builder.system(isSystem);
             builder.timestampRange(timestampRange);
-            builder.indexWriteLoad(indexWriteLoad);
+            builder.stats(stats);
             builder.indexWriteLoadForecast(indexWriteLoadForecast);
+            builder.shardSizeInBytesForecast(shardSizeInBytesForecast);
             return builder.build();
         }
     }
@@ -1630,9 +1651,10 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
         }
         builder.timestampRange(IndexLongFieldRange.readFrom(in));
 
-        if (in.getVersion().onOrAfter(WRITE_LOAD_ADDED)) {
-            builder.indexWriteLoad(in.readOptionalWriteable(IndexWriteLoad::new));
+        if (in.getVersion().onOrAfter(STATS_AND_FORECAST_ADDED)) {
+            builder.stats(in.readOptionalWriteable(IndexMetadataStats::new));
             builder.indexWriteLoadForecast(in.readOptionalDouble());
+            builder.shardSizeInBytesForecast(in.readOptionalLong());
         }
         return builder.build();
     }
@@ -1675,9 +1697,10 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             out.writeBoolean(isSystem);
         }
         timestampRange.writeTo(out);
-        if (out.getVersion().onOrAfter(WRITE_LOAD_ADDED)) {
-            out.writeOptionalWriteable(writeLoad);
+        if (out.getVersion().onOrAfter(STATS_AND_FORECAST_ADDED)) {
+            out.writeOptionalWriteable(stats);
             out.writeOptionalDouble(writeLoadForecast);
+            out.writeOptionalLong(shardSizeInBytesForecast);
         }
     }
 
@@ -1725,8 +1748,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
         private boolean isSystem;
         private IndexLongFieldRange timestampRange = IndexLongFieldRange.NO_SHARDS;
         private LifecycleExecutionState lifecycleExecutionState = LifecycleExecutionState.EMPTY_STATE;
-        private IndexWriteLoad indexWriteLoad = null;
+        private IndexMetadataStats stats = null;
         private Double indexWriteLoadForecast = null;
+        private Long shardSizeInBytesForecast = null;
 
         public Builder(String index) {
             this.index = index;
@@ -1755,8 +1779,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             this.isSystem = indexMetadata.isSystem;
             this.timestampRange = indexMetadata.timestampRange;
             this.lifecycleExecutionState = indexMetadata.lifecycleExecutionState;
-            this.indexWriteLoad = indexMetadata.writeLoad;
+            this.stats = indexMetadata.stats;
             this.indexWriteLoadForecast = indexMetadata.writeLoadForecast;
+            this.shardSizeInBytesForecast = indexMetadata.shardSizeInBytesForecast;
         }
 
         public Builder index(String index) {
@@ -1971,8 +1996,8 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             return this;
         }
 
-        public Builder indexWriteLoad(IndexWriteLoad indexWriteLoad) {
-            this.indexWriteLoad = indexWriteLoad;
+        public Builder stats(IndexMetadataStats stats) {
+            this.stats = stats;
             return this;
         }
 
@@ -1981,6 +2006,11 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             return this;
         }
 
+        public Builder shardSizeInBytesForecast(Long shardSizeInBytesForecast) {
+            this.shardSizeInBytesForecast = shardSizeInBytesForecast;
+            return this;
+        }
+
         public IndexMetadata build() {
             /*
              * We expect that the metadata has been properly built to set the number of shards and the number of replicas, and do not rely
@@ -2096,11 +2126,11 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                 lifecycleExecutionState = LifecycleExecutionState.EMPTY_STATE;
             }
 
-            if (indexWriteLoad != null && indexWriteLoad.numberOfShards() != numberOfShards) {
+            if (stats != null && stats.writeLoad().numberOfShards() != numberOfShards) {
                 assert false;
                 throw new IllegalArgumentException(
                     "The number of write load shards ["
-                        + indexWriteLoad.numberOfShards()
+                        + stats.writeLoad().numberOfShards()
                         + "] is different than the number of index shards ["
                         + numberOfShards
                         + "]"
@@ -2159,8 +2189,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                 isTsdb ? IndexSettings.TIME_SERIES_START_TIME.get(settings) : null,
                 isTsdb ? IndexSettings.TIME_SERIES_END_TIME.get(settings) : null,
                 SETTING_INDEX_VERSION_COMPATIBILITY.get(settings),
-                indexWriteLoad,
-                indexWriteLoadForecast
+                stats,
+                indexWriteLoadForecast,
+                shardSizeInBytesForecast
             );
         }
 
@@ -2272,9 +2303,9 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
             indexMetadata.timestampRange.toXContent(builder, params);
             builder.endObject();
 
-            if (indexMetadata.writeLoad != null) {
-                builder.startObject(KEY_WRITE_LOAD);
-                indexMetadata.writeLoad.toXContent(builder, params);
+            if (indexMetadata.stats != null) {
+                builder.startObject(KEY_STATS);
+                indexMetadata.stats.toXContent(builder, params);
                 builder.endObject();
             }
 
@@ -2282,6 +2313,10 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                 builder.field(KEY_WRITE_LOAD_FORECAST, indexMetadata.writeLoadForecast);
             }
 
+            if (indexMetadata.shardSizeInBytesForecast != null) {
+                builder.field(KEY_SHARD_SIZE_FORECAST, indexMetadata.shardSizeInBytesForecast);
+            }
+
             builder.endObject();
         }
 
@@ -2356,8 +2391,8 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                         case KEY_TIMESTAMP_RANGE:
                             builder.timestampRange(IndexLongFieldRange.fromXContent(parser));
                             break;
-                        case KEY_WRITE_LOAD:
-                            builder.indexWriteLoad(IndexWriteLoad.fromXContent(parser));
+                        case KEY_STATS:
+                            builder.stats(IndexMetadataStats.fromXContent(parser));
                             break;
                         default:
                             // assume it's custom index metadata
@@ -2416,6 +2451,7 @@ public class IndexMetadata implements Diffable<IndexMetadata>, ToXContentFragmen
                             builder.putMapping(mappingsByHash.get(parser.text()));
                         }
                         case KEY_WRITE_LOAD_FORECAST -> builder.indexWriteLoadForecast(parser.doubleValue());
+                        case KEY_SHARD_SIZE_FORECAST -> builder.shardSizeInBytesForecast(parser.longValue());
                         default -> throw new IllegalArgumentException("Unexpected field [" + currentFieldName + "]");
                     }
                 } else {

+ 172 - 0
server/src/main/java/org/elasticsearch/cluster/metadata/IndexMetadataStats.java

@@ -0,0 +1,172 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.action.admin.indices.stats.IndexShardStats;
+import org.elasticsearch.action.admin.indices.stats.IndexStats;
+import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+public record IndexMetadataStats(IndexWriteLoad indexWriteLoad, AverageShardSize averageShardSize)
+    implements
+        Writeable,
+        ToXContentFragment {
+
+    public static final ParseField WRITE_LOAD_FIELD = new ParseField("write_load");
+    public static final ParseField AVERAGE_SIZE_FIELD = new ParseField("avg_size");
+
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<IndexMetadataStats, Void> PARSER = new ConstructingObjectParser<>(
+        "index_metadata_stats_parser",
+        false,
+        (args, unused) -> new IndexMetadataStats((IndexWriteLoad) args[0], (AverageShardSize) args[1])
+    );
+
+    static {
+        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> IndexWriteLoad.fromXContent(p), WRITE_LOAD_FIELD);
+        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> AverageShardSize.fromXContent(p), AVERAGE_SIZE_FIELD);
+    }
+
+    // Visible for testing
+    public IndexMetadataStats(IndexWriteLoad indexWriteLoad, long totalSizeInBytes, int numberOfShards) {
+        this(indexWriteLoad, new AverageShardSize(totalSizeInBytes, numberOfShards));
+    }
+
+    public IndexMetadataStats(StreamInput in) throws IOException {
+        this(new IndexWriteLoad(in), new AverageShardSize(in));
+    }
+
+    public IndexMetadataStats {
+        Objects.requireNonNull(indexWriteLoad, "Expected a non null index write load");
+        Objects.requireNonNull(averageShardSize, "Expected a non null average shard size");
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        indexWriteLoad.writeTo(out);
+        averageShardSize.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(WRITE_LOAD_FIELD.getPreferredName());
+        indexWriteLoad.toXContent(builder, params);
+        builder.endObject();
+
+        builder.startObject(AVERAGE_SIZE_FIELD.getPreferredName());
+        averageShardSize.toXContent(builder, params);
+        builder.endObject();
+        return builder;
+    }
+
+    public static IndexMetadataStats fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    @Nullable
+    public static IndexMetadataStats fromStatsResponse(IndexMetadata indexMetadata, @Nullable IndicesStatsResponse indicesStatsResponse) {
+        if (indicesStatsResponse == null) {
+            return null;
+        }
+        final IndexStats indexStats = indicesStatsResponse.getIndex(indexMetadata.getIndex().getName());
+        if (indexStats == null) {
+            return null;
+        }
+
+        long totalSizeInBytes = 0;
+        int shardsTookIntoAccountForSizeAvg = 0;
+        final int numberOfShards = indexMetadata.getNumberOfShards();
+        final var indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards);
+        final var indexShards = indexStats.getIndexShards();
+        for (IndexShardStats indexShardsStats : indexShards.values()) {
+            final var shardStats = Arrays.stream(indexShardsStats.getShards())
+                .filter(stats -> stats.getShardRouting().primary())
+                .findFirst()
+                // Fallback to a replica if for some reason we couldn't find the primary stats
+                .orElse(indexShardsStats.getAt(0));
+            final var commonStats = shardStats.getStats();
+            final var indexingShardStats = commonStats.getIndexing().getTotal();
+            indexWriteLoadBuilder.withShardWriteLoad(
+                shardStats.getShardRouting().id(),
+                indexingShardStats.getWriteLoad(),
+                indexingShardStats.getTotalActiveTimeInMillis()
+            );
+            totalSizeInBytes += commonStats.getDocs().getTotalSizeInBytes();
+            shardsTookIntoAccountForSizeAvg++;
+        }
+
+        return new IndexMetadataStats(
+            indexWriteLoadBuilder.build(),
+            new AverageShardSize(totalSizeInBytes, shardsTookIntoAccountForSizeAvg)
+        );
+    }
+
+    public IndexWriteLoad writeLoad() {
+        return indexWriteLoad;
+    }
+
+    public record AverageShardSize(long totalSizeInBytes, int numberOfShards) implements Writeable, ToXContentFragment {
+
+        public static final ParseField TOTAL_SIZE_IN_BYTES_FIELD = new ParseField("total_size_in_bytes");
+        public static final ParseField SHARD_COUNT_FIELD = new ParseField("shard_count");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<AverageShardSize, Void> PARSER = new ConstructingObjectParser<>(
+            "average_shard_size",
+            false,
+            (args, unused) -> new AverageShardSize((long) args[0], (int) args[1])
+        );
+
+        static {
+            PARSER.declareLong(ConstructingObjectParser.constructorArg(), TOTAL_SIZE_IN_BYTES_FIELD);
+            PARSER.declareInt(ConstructingObjectParser.constructorArg(), SHARD_COUNT_FIELD);
+        }
+
+        public AverageShardSize {
+            assert numberOfShards > 0;
+        }
+
+        AverageShardSize(StreamInput in) throws IOException {
+            this(in.readLong(), in.readInt());
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeLong(totalSizeInBytes);
+            out.writeInt(numberOfShards);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(TOTAL_SIZE_IN_BYTES_FIELD.getPreferredName(), totalSizeInBytes);
+            builder.field(SHARD_COUNT_FIELD.getPreferredName(), numberOfShards);
+            return builder;
+        }
+
+        static AverageShardSize fromXContent(XContentParser parser) throws IOException {
+            return PARSER.parse(parser, null);
+        }
+
+        public long getAverageSizeInBytes() {
+            return totalSizeInBytes / numberOfShards;
+        }
+    }
+}

+ 1 - 36
server/src/main/java/org/elasticsearch/index/shard/IndexWriteLoad.java → server/src/main/java/org/elasticsearch/cluster/metadata/IndexWriteLoad.java

@@ -6,16 +6,11 @@
  * Side Public License, v 1.
  */
 
-package org.elasticsearch.index.shard;
+package org.elasticsearch.cluster.metadata;
 
-import org.elasticsearch.action.admin.indices.stats.IndexShardStats;
-import org.elasticsearch.action.admin.indices.stats.IndexStats;
-import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
-import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentFragment;
@@ -69,36 +64,6 @@ public class IndexWriteLoad implements Writeable, ToXContentFragment {
         );
     }
 
-    @Nullable
-    public static IndexWriteLoad fromStats(IndexMetadata indexMetadata, @Nullable IndicesStatsResponse indicesStatsResponse) {
-        if (indicesStatsResponse == null) {
-            return null;
-        }
-
-        final IndexStats indexStats = indicesStatsResponse.getIndex(indexMetadata.getIndex().getName());
-        if (indexStats == null) {
-            return null;
-        }
-
-        final int numberOfShards = indexMetadata.getNumberOfShards();
-        final var indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards);
-        final var indexShards = indexStats.getIndexShards();
-        for (IndexShardStats indexShardsStats : indexShards.values()) {
-            final var shardStats = Arrays.stream(indexShardsStats.getShards())
-                .filter(stats -> stats.getShardRouting().primary())
-                .findFirst()
-                // Fallback to a replica if for some reason we couldn't find the primary stats
-                .orElse(indexShardsStats.getAt(0));
-            final var indexingShardStats = shardStats.getStats().getIndexing().getTotal();
-            indexWriteLoadBuilder.withShardWriteLoad(
-                shardStats.getShardRouting().id(),
-                indexingShardStats.getWriteLoad(),
-                indexingShardStats.getTotalActiveTimeInMillis()
-            );
-        }
-        return indexWriteLoadBuilder.build();
-    }
-
     private final double[] shardWriteLoad;
     private final long[] shardUptimeInMillis;
 

+ 38 - 19
server/src/test/java/org/elasticsearch/cluster/ClusterStateTests.java

@@ -15,7 +15,9 @@ import org.elasticsearch.cluster.block.ClusterBlocks;
 import org.elasticsearch.cluster.coordination.CoordinationMetadata;
 import org.elasticsearch.cluster.metadata.AliasMetadata;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
 import org.elasticsearch.cluster.metadata.IndexTemplateMetadata;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
 import org.elasticsearch.cluster.metadata.MappingMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -32,7 +34,6 @@ import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.gateway.GatewayService;
 import org.elasticsearch.index.Index;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
@@ -283,9 +284,15 @@ public class ClusterStateTests extends ESTestCase {
                     "timestamp_range": {
                       "shards": []
                     },
-                    "write_load": {
-                      "loads": [-1.0],
-                      "uptimes": [-1]
+                    "stats": {
+                        "write_load": {
+                          "loads": [-1.0],
+                          "uptimes": [-1]
+                        },
+                        "avg_size": {
+                            "total_size_in_bytes": 120,
+                            "shard_count": 1
+                        }
                     },
                     "write_load_forecast" : 8.0
                   }
@@ -500,13 +507,19 @@ public class ClusterStateTests extends ESTestCase {
                     "timestamp_range" : {
                       "shards" : [ ]
                     },
-                    "write_load" : {
-                      "loads" : [
-                        -1.0
-                      ],
-                      "uptimes" : [
-                        -1
-                      ]
+                    "stats" : {
+                      "write_load" : {
+                        "loads" : [
+                          -1.0
+                        ],
+                        "uptimes" : [
+                          -1
+                        ]
+                      },
+                      "avg_size" : {
+                        "total_size_in_bytes" : 120,
+                        "shard_count" : 1
+                      }
                     },
                     "write_load_forecast" : 8.0
                   }
@@ -728,13 +741,19 @@ public class ClusterStateTests extends ESTestCase {
                     "timestamp_range" : {
                       "shards" : [ ]
                     },
-                    "write_load" : {
-                      "loads" : [
-                        -1.0
-                      ],
-                      "uptimes" : [
-                        -1
-                      ]
+                    "stats" : {
+                      "write_load" : {
+                        "loads" : [
+                          -1.0
+                        ],
+                        "uptimes" : [
+                          -1
+                        ]
+                      },
+                      "avg_size" : {
+                        "total_size_in_bytes" : 120,
+                        "shard_count" : 1
+                      }
                     },
                     "write_load_forecast" : 8.0
                   }
@@ -925,7 +944,7 @@ public class ClusterStateTests extends ESTestCase {
             })
             .numberOfReplicas(2)
             .putRolloverInfo(new RolloverInfo("rolloveAlias", new ArrayList<>(), 1L))
-            .indexWriteLoad(IndexWriteLoad.builder(1).build())
+            .stats(new IndexMetadataStats(IndexWriteLoad.builder(1).build(), 120, 1))
             .indexWriteLoadForecast(8.0)
             .build();
 

+ 18 - 16
server/src/test/java/org/elasticsearch/index/shard/IndexWriteLoadSerializationTests.java → server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataStatsSerializationTests.java

@@ -6,7 +6,7 @@
  * Side Public License, v 1.
  */
 
-package org.elasticsearch.index.shard;
+package org.elasticsearch.cluster.metadata;
 
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.test.AbstractXContentSerializingTestCase;
@@ -14,47 +14,49 @@ import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 
-public class IndexWriteLoadSerializationTests extends AbstractXContentSerializingTestCase<IndexWriteLoad> {
+public class IndexMetadataStatsSerializationTests extends AbstractXContentSerializingTestCase<IndexMetadataStats> {
 
     @Override
-    protected IndexWriteLoad doParseInstance(XContentParser parser) throws IOException {
-        return IndexWriteLoad.fromXContent(parser);
+    protected IndexMetadataStats doParseInstance(XContentParser parser) throws IOException {
+        return IndexMetadataStats.fromXContent(parser);
     }
 
     @Override
-    protected Writeable.Reader<IndexWriteLoad> instanceReader() {
-        return IndexWriteLoad::new;
+    protected Writeable.Reader<IndexMetadataStats> instanceReader() {
+        return IndexMetadataStats::new;
     }
 
     @Override
-    protected IndexWriteLoad createTestInstance() {
+    protected IndexMetadataStats createTestInstance() {
         final int numberOfShards = randomIntBetween(1, 10);
         final var indexWriteLoad = IndexWriteLoad.builder(numberOfShards);
         for (int i = 0; i < numberOfShards; i++) {
             indexWriteLoad.withShardWriteLoad(i, randomDoubleBetween(1, 10, true), randomLongBetween(1, 1000));
         }
-        return indexWriteLoad.build();
+        return new IndexMetadataStats(indexWriteLoad.build(), randomLongBetween(1024, 10240), randomIntBetween(1, 4));
     }
 
     @Override
-    protected IndexWriteLoad mutateInstance(IndexWriteLoad instance) throws IOException {
+    protected IndexMetadataStats mutateInstance(IndexMetadataStats originalStats) throws IOException {
+        final IndexWriteLoad originalWriteLoad = originalStats.writeLoad();
+
         final int newNumberOfShards;
-        if (instance.numberOfShards() > 1 && randomBoolean()) {
-            newNumberOfShards = randomIntBetween(1, instance.numberOfShards() - 1);
+        if (originalWriteLoad.numberOfShards() > 1 && randomBoolean()) {
+            newNumberOfShards = randomIntBetween(1, originalWriteLoad.numberOfShards() - 1);
         } else {
-            newNumberOfShards = instance.numberOfShards() + randomIntBetween(1, 5);
+            newNumberOfShards = originalWriteLoad.numberOfShards() + randomIntBetween(1, 5);
         }
         final var indexWriteLoad = IndexWriteLoad.builder(newNumberOfShards);
         for (int i = 0; i < newNumberOfShards; i++) {
-            boolean existingShard = i < instance.numberOfShards();
+            boolean existingShard = i < originalWriteLoad.numberOfShards();
             double shardLoad = existingShard && randomBoolean()
-                ? instance.getWriteLoadForShard(i).getAsDouble()
+                ? originalWriteLoad.getWriteLoadForShard(i).getAsDouble()
                 : randomDoubleBetween(0, 128, true);
             long uptimeInMillis = existingShard && randomBoolean()
-                ? instance.getUptimeInMillisForShard(i).getAsLong()
+                ? originalWriteLoad.getUptimeInMillisForShard(i).getAsLong()
                 : randomNonNegativeLong();
             indexWriteLoad.withShardWriteLoad(i, shardLoad, uptimeInMillis);
         }
-        return indexWriteLoad.build();
+        return new IndexMetadataStats(indexWriteLoad.build(), randomLongBetween(1024, 10240), randomIntBetween(1, 4));
     }
 }

+ 19 - 39
server/src/test/java/org/elasticsearch/index/shard/IndexWriteLoadTests.java → server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataStatsTests.java

@@ -6,7 +6,7 @@
  * Side Public License, v 1.
  */
 
-package org.elasticsearch.index.shard;
+package org.elasticsearch.cluster.metadata;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.action.admin.indices.stats.CommonStats;
@@ -15,7 +15,6 @@ import org.elasticsearch.action.admin.indices.stats.IndexShardStats;
 import org.elasticsearch.action.admin.indices.stats.IndexStats;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsResponse;
 import org.elasticsearch.action.admin.indices.stats.ShardStats;
-import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.routing.RecoverySource;
 import org.elasticsearch.cluster.routing.ShardRouting;
 import org.elasticsearch.cluster.routing.ShardRoutingHelper;
@@ -23,6 +22,9 @@ import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.shard.DocsStats;
+import org.elasticsearch.index.shard.IndexingStats;
+import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.Map;
@@ -33,37 +35,7 @@ import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
-public class IndexWriteLoadTests extends ESTestCase {
-
-    public void testGetWriteLoadForShardAndGetUptimeInMillisForShard() {
-        final int numberOfPopulatedShards = 10;
-        final int numberOfShards = randomIntBetween(numberOfPopulatedShards, 20);
-        final IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards);
-
-        final double[] populatedShardWriteLoads = new double[numberOfPopulatedShards];
-        final long[] populatedShardUptimes = new long[numberOfPopulatedShards];
-        for (int shardId = 0; shardId < numberOfPopulatedShards; shardId++) {
-            double writeLoad = randomDoubleBetween(1, 128, true);
-            long uptimeInMillis = randomNonNegativeLong();
-            populatedShardWriteLoads[shardId] = writeLoad;
-            populatedShardUptimes[shardId] = uptimeInMillis;
-            indexWriteLoadBuilder.withShardWriteLoad(shardId, writeLoad, uptimeInMillis);
-        }
-
-        final IndexWriteLoad indexWriteLoad = indexWriteLoadBuilder.build();
-        for (int shardId = 0; shardId < numberOfShards; shardId++) {
-            if (shardId < numberOfPopulatedShards) {
-                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).isPresent(), is(equalTo(true)));
-                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).getAsDouble(), is(equalTo(populatedShardWriteLoads[shardId])));
-                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).isPresent(), is(equalTo(true)));
-                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).getAsLong(), is(equalTo(populatedShardUptimes[shardId])));
-            } else {
-                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).isPresent(), is(false));
-                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).isPresent(), is(false));
-            }
-        }
-    }
-
+public class IndexMetadataStatsTests extends ESTestCase {
     public void testFromStatsCreation() {
         final String indexName = "idx";
         final IndexMetadata indexMetadata = IndexMetadata.builder(indexName)
@@ -83,23 +55,25 @@ public class IndexWriteLoadTests extends ESTestCase {
         final IndexShardStats indexShard0Stats = new IndexShardStats(
             new ShardId(indexName, "__na__", 0),
             new ShardStats[] {
-                createShardStats(indexName, 0, true, TimeValue.timeValueMillis(2048).nanos(), TimeValue.timeValueMillis(1024).nanos()),
-                createShardStats(indexName, 0, false, TimeValue.timeValueMillis(2048).nanos(), TimeValue.timeValueMillis(512).nanos()) }
+                createShardStats(indexName, 0, true, TimeValue.timeValueMillis(2048).nanos(), TimeValue.timeValueMillis(1024).nanos(), 15),
+                createShardStats(indexName, 0, false, TimeValue.timeValueMillis(2048).nanos(), TimeValue.timeValueMillis(512).nanos(), 16) }
         );
 
         // Shard 1 only has a replica available
         final IndexShardStats indexShard1Stats = new IndexShardStats(
             new ShardId(indexName, "__na__", 1),
             new ShardStats[] {
-                createShardStats(indexName, 1, false, TimeValue.timeValueMillis(4096).nanos(), TimeValue.timeValueMillis(512).nanos()) }
+                createShardStats(indexName, 1, false, TimeValue.timeValueMillis(4096).nanos(), TimeValue.timeValueMillis(512).nanos(), 30) }
         );
         // Shard 2 was not available
 
         when(response.getIndex(indexName)).thenReturn(indexStats);
         when(indexStats.getIndexShards()).thenReturn(Map.of(0, indexShard0Stats, 1, indexShard1Stats));
 
+        final IndexMetadataStats indexMetadataStats = IndexMetadataStats.fromStatsResponse(indexMetadata, response);
+
         // Shard 0 uses the results from the primary
-        final IndexWriteLoad indexWriteLoadFromStats = IndexWriteLoad.fromStats(indexMetadata, response);
+        final IndexWriteLoad indexWriteLoadFromStats = indexMetadataStats.writeLoad();
         assertThat(indexWriteLoadFromStats.getWriteLoadForShard(0).isPresent(), is(equalTo(true)));
         assertThat(indexWriteLoadFromStats.getWriteLoadForShard(0).getAsDouble(), is(equalTo(2.0)));
         assertThat(indexWriteLoadFromStats.getUptimeInMillisForShard(0).isPresent(), is(equalTo(true)));
@@ -114,7 +88,11 @@ public class IndexWriteLoadTests extends ESTestCase {
         assertThat(indexWriteLoadFromStats.getWriteLoadForShard(2).isPresent(), is(equalTo(false)));
         assertThat(indexWriteLoadFromStats.getUptimeInMillisForShard(2).isPresent(), is(equalTo(false)));
 
-        assertThat(IndexWriteLoad.fromStats(indexMetadata, null), is(nullValue()));
+        final long averageShardSize = indexMetadataStats.averageShardSize().getAverageSizeInBytes();
+        // (shard_0 = 15 + shard_1 = 30) / 2
+        assertThat(averageShardSize, is(equalTo(22L)));
+
+        assertThat(IndexMetadataStats.fromStatsResponse(indexMetadata, null), is(nullValue()));
     }
 
     private ShardStats createShardStats(
@@ -122,7 +100,8 @@ public class IndexWriteLoadTests extends ESTestCase {
         int shard,
         boolean primary,
         long totalIndexingTimeSinceShardStartedInNanos,
-        long totalActiveTimeInNanos
+        long totalActiveTimeInNanos,
+        long sizeInBytes
     ) {
         RecoverySource recoverySource = primary
             ? RecoverySource.EmptyStoreRecoverySource.INSTANCE
@@ -137,6 +116,7 @@ public class IndexWriteLoadTests extends ESTestCase {
         shardRouting = ShardRoutingHelper.moveToStarted(shardRouting);
 
         final CommonStats commonStats = new CommonStats(CommonStatsFlags.ALL);
+        commonStats.getDocs().add(new DocsStats(1, 0, sizeInBytes));
         commonStats.getIndexing()
             .getTotal()
             .add(

+ 10 - 7
server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java

@@ -27,7 +27,6 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.IndicesModule;
 import org.elasticsearch.test.ESTestCase;
@@ -75,8 +74,9 @@ public class IndexMetadataTests extends ESTestCase {
         Map<String, String> customMap = new HashMap<>();
         customMap.put(randomAlphaOfLength(5), randomAlphaOfLength(10));
         customMap.put(randomAlphaOfLength(10), randomAlphaOfLength(15));
-        IndexWriteLoad indexWriteLoad = randomBoolean() ? randomWriteLoad(numShard) : null;
+        IndexMetadataStats indexStats = randomBoolean() ? randomIndexStats(numShard) : null;
         Double indexWriteLoadForecast = randomBoolean() ? randomDoubleBetween(0.0, 128, true) : null;
+        Long shardSizeInBytesForecast = randomBoolean() ? randomLongBetween(1024, 10240) : null;
         IndexMetadata metadata = IndexMetadata.builder("foo")
             .settings(
                 Settings.builder()
@@ -103,8 +103,9 @@ public class IndexMetadataTests extends ESTestCase {
                     randomNonNegativeLong()
                 )
             )
-            .indexWriteLoad(indexWriteLoad)
+            .stats(indexStats)
             .indexWriteLoadForecast(indexWriteLoadForecast)
+            .shardSizeInBytesForecast(shardSizeInBytesForecast)
             .build();
         assertEquals(system, metadata.isSystem());
 
@@ -133,8 +134,9 @@ public class IndexMetadataTests extends ESTestCase {
         Map<String, DiffableStringMap> expectedCustom = Map.of("my_custom", new DiffableStringMap(customMap));
         assertEquals(metadata.getCustomData(), expectedCustom);
         assertEquals(metadata.getCustomData(), fromXContentMeta.getCustomData());
-        assertEquals(metadata.getWriteLoad(), fromXContentMeta.getWriteLoad());
+        assertEquals(metadata.getStats(), fromXContentMeta.getStats());
         assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad());
+        assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes());
 
         final BytesStreamOutput out = new BytesStreamOutput();
         metadata.writeTo(out);
@@ -155,8 +157,9 @@ public class IndexMetadataTests extends ESTestCase {
             assertEquals(deserialized.getCustomData(), expectedCustom);
             assertEquals(metadata.getCustomData(), deserialized.getCustomData());
             assertEquals(metadata.isSystem(), deserialized.isSystem());
-            assertEquals(metadata.getWriteLoad(), deserialized.getWriteLoad());
+            assertEquals(metadata.getStats(), deserialized.getStats());
             assertEquals(metadata.getForecastedWriteLoad(), fromXContentMeta.getForecastedWriteLoad());
+            assertEquals(metadata.getForecastedShardSizeInBytes(), fromXContentMeta.getForecastedShardSizeInBytes());
         }
     }
 
@@ -515,12 +518,12 @@ public class IndexMetadataTests extends ESTestCase {
             .build();
     }
 
-    private IndexWriteLoad randomWriteLoad(int numberOfShards) {
+    private IndexMetadataStats randomIndexStats(int numberOfShards) {
         IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards);
         int numberOfPopulatedWriteLoads = randomIntBetween(0, numberOfShards);
         for (int i = 0; i < numberOfPopulatedWriteLoads; i++) {
             indexWriteLoadBuilder.withShardWriteLoad(i, randomDoubleBetween(0.0, 128.0, true), randomNonNegativeLong());
         }
-        return indexWriteLoadBuilder.build();
+        return new IndexMetadataStats(indexWriteLoadBuilder.build(), randomLongBetween(100, 1024), randomIntBetween(1, 2));
     }
 }

+ 46 - 0
server/src/test/java/org/elasticsearch/cluster/metadata/IndexWriteLoadTests.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.test.ESTestCase;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public class IndexWriteLoadTests extends ESTestCase {
+
+    public void testGetWriteLoadForShardAndGetUptimeInMillisForShard() {
+        final int numberOfPopulatedShards = 10;
+        final int numberOfShards = randomIntBetween(numberOfPopulatedShards, 20);
+        final IndexWriteLoad.Builder indexWriteLoadBuilder = IndexWriteLoad.builder(numberOfShards);
+
+        final double[] populatedShardWriteLoads = new double[numberOfPopulatedShards];
+        final long[] populatedShardUptimes = new long[numberOfPopulatedShards];
+        for (int shardId = 0; shardId < numberOfPopulatedShards; shardId++) {
+            double writeLoad = randomDoubleBetween(1, 128, true);
+            long uptimeInMillis = randomNonNegativeLong();
+            populatedShardWriteLoads[shardId] = writeLoad;
+            populatedShardUptimes[shardId] = uptimeInMillis;
+            indexWriteLoadBuilder.withShardWriteLoad(shardId, writeLoad, uptimeInMillis);
+        }
+
+        final IndexWriteLoad indexWriteLoad = indexWriteLoadBuilder.build();
+        for (int shardId = 0; shardId < numberOfShards; shardId++) {
+            if (shardId < numberOfPopulatedShards) {
+                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).isPresent(), is(equalTo(true)));
+                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).getAsDouble(), is(equalTo(populatedShardWriteLoads[shardId])));
+                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).isPresent(), is(equalTo(true)));
+                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).getAsLong(), is(equalTo(populatedShardUptimes[shardId])));
+            } else {
+                assertThat(indexWriteLoad.getWriteLoadForShard(shardId).isPresent(), is(false));
+                assertThat(indexWriteLoad.getUptimeInMillisForShard(shardId).isPresent(), is(false));
+            }
+        }
+    }
+}

+ 5 - 2
x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java

@@ -9,6 +9,8 @@ package org.elasticsearch.xpack.writeloadforecaster;
 
 import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.routing.allocation.WriteLoadForecaster;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -17,7 +19,6 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.SuppressForbidden;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.List;
@@ -74,7 +75,9 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
         final List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange = getIndicesWithinMaxAgeRange(dataStream, metadata).stream()
             .filter(index -> index.equals(dataStream.getWriteIndex()) == false)
             .map(metadata::getSafe)
-            .map(IndexMetadata::getWriteLoad)
+            .map(IndexMetadata::getStats)
+            .filter(Objects::nonNull)
+            .map(IndexMetadataStats::writeLoad)
             .filter(Objects::nonNull)
             .toList();
 

+ 3 - 2
x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java

@@ -10,13 +10,14 @@ package org.elasticsearch.xpack.writeloadforecaster;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.routing.allocation.WriteLoadForecaster;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.IndexMode;
-import org.elasticsearch.index.shard.IndexWriteLoad;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -373,7 +374,7 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
                     .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
                     .build()
             )
-            .indexWriteLoad(indexWriteLoad)
+            .stats(indexWriteLoad == null ? null : new IndexMetadataStats(indexWriteLoad, 1, 1))
             .creationDate(createdAt)
             .build();
     }