
Add service for computing the optimal number of shards for data streams (#105498)

This adds the `DataStreamAutoShardingService`, which computes the
optimal number of shards for a data stream and returns a recommendation
as to when to apply it (a time interval we call the cooldown, which is
zero when the auto sharding recommendation can be applied immediately).

This also introduces a `DataStreamAutoShardingEvent` object that is
stored in the data stream metadata to record the last auto sharding
event applied to a data stream. Its cluster state representation looks
like this:

```
"auto_sharding": {
 "trigger_index_name": ".ds-logs-nginx-2024.02.12-000002",
 "target_number_of_shards": 3,
 "event_timestamp": 1707739707954
}
```

The auto sharding service is not used in this PR, so the auto sharding
event will not be stored in the data stream metadata, but the required
infrastructure to configure it is in place.
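
For context, no caller exists yet in this PR; a follow-up consumer might use the service roughly like the sketch below. The `autoShardingService` field and the `applyOnNextRollover`/`scheduleRecheck` hooks are hypothetical, only the `AutoShardingResult` API is from this change:

```java
// Hypothetical consumer sketch: act on a recommendation only once its cooldown has lapsed.
AutoShardingResult result = autoShardingService.calculate(clusterState, dataStream, writeIndexLoad);
switch (result.type()) {
    case INCREASE_SHARDS, DECREASE_SHARDS ->
        // cooldown is guaranteed to be TimeValue.ZERO here, so the target can be
        // applied right away, e.g. on the next rollover of the data stream
        applyOnNextRollover(result.targetNumberOfShards());   // hypothetical hook
    case COOLDOWN_PREVENTED_INCREASE, COOLDOWN_PREVENTED_DECREASE ->
        // re-evaluate after the remaining cooldown has elapsed
        scheduleRecheck(result.coolDownRemaining());          // hypothetical hook
    case NO_CHANGE_REQUIRED, NOT_APPLICABLE -> { }
}
```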
Andrei Dan · 1 year ago · commit 882b92ab60
25 changed files with 1681 additions and 119 deletions
  1. +2 -1    modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java
  2. +2 -1    modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java
  3. +2 -1    modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java
  4. +4 -2    modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java
  5. +2 -1    modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java
  6. +1 -0    server/src/main/java/module-info.java
  7. +1 -0    server/src/main/java/org/elasticsearch/TransportVersions.java
  8. +17 -0   server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java
  9. +59 -0   server/src/main/java/org/elasticsearch/action/datastreams/autosharding/AutoShardingResult.java
  10. +21 -0  server/src/main/java/org/elasticsearch/action/datastreams/autosharding/AutoShardingType.java
  11. +415 -0 server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
  12. +88 -18 server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java
  13. +84 -0  server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAutoShardingEvent.java
  14. +2 -1   server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java
  15. +4 -2   server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java
  16. +2 -1   server/src/main/java/org/elasticsearch/snapshots/RestoreService.java
  17. +771 -0 server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java
  18. +62 -0  server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamAutoShardingEventTests.java
  19. +117 -7 server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java
  20. +2 -1   server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java
  21. +10 -1  test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java
  22. +4 -2   x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportPutFollowAction.java
  23. +2 -1   x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/action/DataStreamLifecycleUsageTransportActionIT.java
  24. +7 -20  x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java
  25. +0 -59  x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java

+ 2 - 1
modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamIT.java

@@ -1790,7 +1790,8 @@ public class DataStreamIT extends ESIntegTestCase {
                         original.getIndexMode(),
                         original.getLifecycle(),
                         original.isFailureStore(),
-                        original.getFailureIndices()
+                        original.getFailureIndices(),
+                        null
                     );
                     brokenDataStreamHolder.set(broken);
                     return ClusterState.builder(currentState)

+ 2 - 1
modules/data-streams/src/test/java/org/elasticsearch/datastreams/DataStreamIndexSettingsProviderTests.java

@@ -314,7 +314,8 @@ public class DataStreamIndexSettingsProviderTests extends ESTestCase {
                 IndexMode.TIME_SERIES,
                 ds.getLifecycle(),
                 ds.isFailureStore(),
-                ds.getFailureIndices()
+                ds.getFailureIndices(),
+                null
             )
         );
         Metadata metadata = mb.build();

+ 2 - 1
modules/data-streams/src/test/java/org/elasticsearch/datastreams/UpdateTimeSeriesRangeServiceTests.java

@@ -153,7 +153,8 @@ public class UpdateTimeSeriesRangeServiceTests extends ESTestCase {
                     d.getIndexMode(),
                     d.getLifecycle(),
                     d.isFailureStore(),
-                    d.getFailureIndices()
+                    d.getFailureIndices(),
+                    null
                 )
             )
             .build();

+ 4 - 2
modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java

@@ -89,7 +89,8 @@ public class GetDataStreamsResponseTests extends AbstractWireSerializingTestCase
                 IndexMode.STANDARD,
                 new DataStreamLifecycle(),
                 true,
-                failureStores
+                failureStores,
+                null
             );
 
             String ilmPolicyName = "rollover-30days";
@@ -198,7 +199,8 @@ public class GetDataStreamsResponseTests extends AbstractWireSerializingTestCase
                 IndexMode.STANDARD,
                 new DataStreamLifecycle(null, null, false),
                 true,
-                failureStores
+                failureStores,
+                null
             );
 
             String ilmPolicyName = "rollover-30days";

+ 2 - 1
modules/data-streams/src/test/java/org/elasticsearch/datastreams/lifecycle/DataStreamLifecycleServiceTests.java

@@ -295,7 +295,8 @@ public class DataStreamLifecycleServiceTests extends ESTestCase {
                 dataStream.getIndexMode(),
                 DataStreamLifecycle.newBuilder().dataRetention(0L).build(),
                 dataStream.isFailureStore(),
-                dataStream.getFailureIndices()
+                dataStream.getFailureIndices(),
+                null
             )
         );
         clusterState = ClusterState.builder(clusterState).metadata(builder).build();

+ 1 - 0
server/src/main/java/module-info.java

@@ -381,6 +381,7 @@ module org.elasticsearch.server {
     opens org.elasticsearch.common.logging to org.apache.logging.log4j.core;
 
     exports org.elasticsearch.action.datastreams.lifecycle;
+    exports org.elasticsearch.action.datastreams.autosharding;
     exports org.elasticsearch.action.downsample;
     exports org.elasticsearch.plugins.internal
         to

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -135,6 +135,7 @@ public class TransportVersions {
     public static final TransportVersion ML_MODEL_IN_SERVICE_SETTINGS = def(8_595_00_0);
     public static final TransportVersion RANDOM_AGG_SHARD_SEED = def(8_596_00_0);
     public static final TransportVersion ESQL_TIMINGS = def(8_597_00_0);
+    public static final TransportVersion DATA_STREAM_AUTO_SHARDING_EVENT = def(8_598_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 17 - 0
server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java

@@ -18,6 +18,7 @@ import org.elasticsearch.action.support.master.MasterNodeReadRequest;
 import org.elasticsearch.cluster.SimpleDiffable;
 import org.elasticsearch.cluster.health.ClusterHealthStatus;
 import org.elasticsearch.cluster.metadata.DataStream;
+import org.elasticsearch.cluster.metadata.DataStreamAutoShardingEvent;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -37,6 +38,7 @@ import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.TransportVersions.V_8_11_X;
+import static org.elasticsearch.cluster.metadata.DataStream.AUTO_SHARDING_FIELD;
 
 public class GetDataStreamAction extends ActionType<GetDataStreamAction.Response> {
 
@@ -179,6 +181,10 @@ public class GetDataStreamAction extends ActionType<GetDataStreamAction.Response
             public static final ParseField TEMPORAL_RANGES = new ParseField("temporal_ranges");
             public static final ParseField TEMPORAL_RANGE_START = new ParseField("start");
             public static final ParseField TEMPORAL_RANGE_END = new ParseField("end");
+            public static final ParseField TIME_SINCE_LAST_AUTO_SHARD_EVENT = new ParseField("time_since_last_auto_shard_event");
+            public static final ParseField TIME_SINCE_LAST_AUTO_SHARD_EVENT_MILLIS = new ParseField(
+                "time_since_last_auto_shard_event_millis"
+            );
 
             private final DataStream dataStream;
             private final ClusterHealthStatus dataStreamStatus;
@@ -348,6 +354,17 @@ public class GetDataStreamAction extends ActionType<GetDataStreamAction.Response
                 if (DataStream.isFailureStoreEnabled()) {
                     builder.field(DataStream.FAILURE_STORE_FIELD.getPreferredName(), dataStream.isFailureStore());
                 }
+                if (dataStream.getAutoShardingEvent() != null) {
+                    DataStreamAutoShardingEvent autoShardingEvent = dataStream.getAutoShardingEvent();
+                    builder.startObject(AUTO_SHARDING_FIELD.getPreferredName());
+                    autoShardingEvent.toXContent(builder, params);
+                    builder.humanReadableField(
+                        TIME_SINCE_LAST_AUTO_SHARD_EVENT_MILLIS.getPreferredName(),
+                        TIME_SINCE_LAST_AUTO_SHARD_EVENT.getPreferredName(),
+                        autoShardingEvent.getTimeSinceLastAutoShardingEvent(System::currentTimeMillis)
+                    );
+                    builder.endObject();
+                }
                 if (timeSeries != null) {
                     builder.startObject(TIME_SERIES.getPreferredName());
                     builder.startArray(TEMPORAL_RANGES.getPreferredName());
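
Based on the serialization above, a `GET _data_stream` response for a stream with a recorded auto sharding event should contain a fragment shaped roughly like this (values illustrative; the human-readable `time_since_last_auto_shard_event` variant only appears with `?human=true`):

```json
"auto_sharding": {
  "trigger_index_name": ".ds-logs-nginx-2024.02.12-000002",
  "target_number_of_shards": 3,
  "event_time_millis": 1707739707954,
  "time_since_last_auto_shard_event_millis": 42000
}
```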

+ 59 - 0
server/src/main/java/org/elasticsearch/action/datastreams/autosharding/AutoShardingResult.java

@@ -0,0 +1,59 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.datastreams.autosharding;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+
+import java.util.Arrays;
+
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.COOLDOWN_PREVENTED_DECREASE;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.COOLDOWN_PREVENTED_INCREASE;
+
+/**
+ * Represents an auto sharding recommendation. It includes the current and target number of shards together with a remaining cooldown
+ * period that needs to lapse before the current recommendation should be applied.
+ * <p>
+ * If auto sharding is not applicable for a data stream (e.g. due to
+ * {@link DataStreamAutoShardingService#DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING}) the target number of shards will be -1 and the
+ * remaining cooldown will be {@link TimeValue#MAX_VALUE}.
+ */
+public record AutoShardingResult(
+    AutoShardingType type,
+    int currentNumberOfShards,
+    int targetNumberOfShards,
+    TimeValue coolDownRemaining,
+    @Nullable Double writeLoad
+) {
+
+    static final String COOLDOWN_PREVENTING_TYPES = Arrays.toString(
+        new AutoShardingType[] { COOLDOWN_PREVENTED_DECREASE, COOLDOWN_PREVENTED_INCREASE }
+    );
+
+    public AutoShardingResult {
+        if (type.equals(AutoShardingType.INCREASE_SHARDS) || type.equals(AutoShardingType.DECREASE_SHARDS)) {
+            if (coolDownRemaining.equals(TimeValue.ZERO) == false) {
+                throw new IllegalArgumentException(
+                    "The increase/decrease shards events must have a cooldown period of zero. Use one of ["
+                        + COOLDOWN_PREVENTING_TYPES
+                        + "] types instead"
+                );
+            }
+        }
+    }
+
+    public static final AutoShardingResult NOT_APPLICABLE_RESULT = new AutoShardingResult(
+        AutoShardingType.NOT_APPLICABLE,
+        -1,
+        -1,
+        TimeValue.MAX_VALUE,
+        null
+    );
+
+}
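
To illustrate the invariant the compact constructor enforces (the values below are made up):

```java
// OK: an increase recommendation must report a zero cooldown
AutoShardingResult ok = new AutoShardingResult(AutoShardingType.INCREASE_SHARDS, 1, 3, TimeValue.ZERO, 2.0);

// Throws IllegalArgumentException: a pending cooldown must be reported via the
// COOLDOWN_PREVENTED_INCREASE type instead
AutoShardingResult rejected = new AutoShardingResult(
    AutoShardingType.INCREASE_SHARDS, 1, 3, TimeValue.timeValueSeconds(30), 2.0
);
```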

+ 21 - 0
server/src/main/java/org/elasticsearch/action/datastreams/autosharding/AutoShardingType.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.datastreams.autosharding;
+
+/**
+ * Represents the type of recommendation the auto sharding service provided.
+ */
+public enum AutoShardingType {
+    INCREASE_SHARDS,
+    DECREASE_SHARDS,
+    COOLDOWN_PREVENTED_INCREASE,
+    COOLDOWN_PREVENTED_DECREASE,
+    NO_CHANGE_REQUIRED,
+    NOT_APPLICABLE
+}

+ 415 - 0
server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java

@@ -0,0 +1,415 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.datastreams.autosharding;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.DataStream;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.regex.Regex;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.Index;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.OptionalDouble;
+import java.util.function.Function;
+import java.util.function.LongSupplier;
+
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingResult.NOT_APPLICABLE_RESULT;
+
+/**
+ * Calculates the optimal number of shards the data stream write index should have based on the indexing load.
+ */
+public class DataStreamAutoShardingService {
+
+    private static final Logger logger = LogManager.getLogger(DataStreamAutoShardingService.class);
+    public static final String DATA_STREAMS_AUTO_SHARDING_ENABLED = "data_streams.auto_sharding.enabled";
+
+    public static final NodeFeature DATA_STREAM_AUTO_SHARDING_FEATURE = new NodeFeature("data_stream.auto_sharding");
+
+    public static final Setting<List<String>> DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING = Setting.listSetting(
+        "data_streams.auto_sharding.excludes",
+        List.of("*"),
+        Function.identity(),
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
+    /**
+     * Represents the minimum amount of time between two scaling events if the next event will increase the number of shards.
+     * We've chosen a default of 4.5 minutes, just below the data stream lifecycle poll interval, so we can increase shards with
+     * every DSL run, but we don't want it to be lower (or zero) as the data stream lifecycle might run more often than the poll
+     * interval in case of a master failover.
+     */
+    public static final Setting<TimeValue> DATA_STREAMS_AUTO_SHARDING_INCREASE_SHARDS_COOLDOWN = Setting.timeSetting(
+        "data_streams.auto_sharding.increase_shards.cooldown",
+        TimeValue.timeValueSeconds(270),
+        TimeValue.timeValueSeconds(0),
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
+    /**
+     * Represents the minimum amount of time between two scaling events if the next event will reduce the number of shards.
+     */
+    public static final Setting<TimeValue> DATA_STREAMS_AUTO_SHARDING_DECREASE_SHARDS_COOLDOWN = Setting.timeSetting(
+        "data_streams.auto_sharding.decrease_shards.cooldown",
+        TimeValue.timeValueDays(3),
+        TimeValue.timeValueSeconds(0),
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
+    /**
+     * Represents the minimum number of write threads we expect a node to have in the environments where auto sharding will be enabled.
+     */
+    public static final Setting<Integer> CLUSTER_AUTO_SHARDING_MIN_WRITE_THREADS = Setting.intSetting(
+        "cluster.auto_sharding.min_write_threads",
+        2,
+        1,
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+
+    /**
+     * Represents the maximum number of write threads we expect a node to have in the environments where auto sharding will be enabled.
+     */
+    public static final Setting<Integer> CLUSTER_AUTO_SHARDING_MAX_WRITE_THREADS = Setting.intSetting(
+        "cluster.auto_sharding.max_write_threads",
+        32,
+        1,
+        Setting.Property.Dynamic,
+        Setting.Property.NodeScope
+    );
+    private final ClusterService clusterService;
+    private final boolean isAutoShardingEnabled;
+    private final FeatureService featureService;
+    private final LongSupplier nowSupplier;
+    private volatile TimeValue increaseShardsCooldown;
+    private volatile TimeValue reduceShardsCooldown;
+    private volatile int minWriteThreads;
+    private volatile int maxWriteThreads;
+    private volatile List<String> dataStreamExcludePatterns;
+
+    public DataStreamAutoShardingService(
+        Settings settings,
+        ClusterService clusterService,
+        FeatureService featureService,
+        LongSupplier nowSupplier
+    ) {
+        this.clusterService = clusterService;
+        this.isAutoShardingEnabled = settings.getAsBoolean(DATA_STREAMS_AUTO_SHARDING_ENABLED, false);
+        this.increaseShardsCooldown = DATA_STREAMS_AUTO_SHARDING_INCREASE_SHARDS_COOLDOWN.get(settings);
+        this.reduceShardsCooldown = DATA_STREAMS_AUTO_SHARDING_DECREASE_SHARDS_COOLDOWN.get(settings);
+        this.minWriteThreads = CLUSTER_AUTO_SHARDING_MIN_WRITE_THREADS.get(settings);
+        this.maxWriteThreads = CLUSTER_AUTO_SHARDING_MAX_WRITE_THREADS.get(settings);
+        this.dataStreamExcludePatterns = DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING.get(settings);
+        this.featureService = featureService;
+        this.nowSupplier = nowSupplier;
+    }
+
+    public void init() {
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(DATA_STREAMS_AUTO_SHARDING_INCREASE_SHARDS_COOLDOWN, this::updateIncreaseShardsCooldown);
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(DATA_STREAMS_AUTO_SHARDING_DECREASE_SHARDS_COOLDOWN, this::updateReduceShardsCooldown);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(CLUSTER_AUTO_SHARDING_MIN_WRITE_THREADS, this::updateMinWriteThreads);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(CLUSTER_AUTO_SHARDING_MAX_WRITE_THREADS, this::updateMaxWriteThreads);
+        clusterService.getClusterSettings()
+            .addSettingsUpdateConsumer(DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING, this::updateDataStreamExcludePatterns);
+    }
+
+    /**
+     * Computes the optimal number of shards for the provided data stream according to the write index's indexing load (the heuristics
+     * for increasing the number of shards use only the write index load, whilst the heuristics for decreasing the number of shards
+     * _might_ also use it).
+     * The result type indicates the recommendation of the auto sharding service:
+     * - not applicable if the data stream is excluded from auto sharding as configured by
+     * {@link #DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING}, if the auto sharding functionality is disabled according to
+     * {@link #DATA_STREAMS_AUTO_SHARDING_ENABLED}, or if the cluster doesn't have the feature available
+     * - increase the number of shards if the optimal number of shards it deems necessary for the provided data stream is GT the current
+     * number of shards
+     * - decrease the number of shards if the optimal number of shards it deems necessary for the provided data stream is LT the current
+     * number of shards
+     *
+     * If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO.
+     * If the auto sharding service thinks the number of shards must change but cannot recommend a change because the cooldown
+     * period has not lapsed, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or
+     * {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and, as the target number of shards,
+     * the number of shards that should be configured for the data stream once the remaining cooldown lapses.
+     *
+     * The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE.
+     *
+     * The NO_CHANGE_REQUIRED type will always report a cooldown period of TimeValue.ZERO (as there will be no new auto sharding event).
+     */
+    public AutoShardingResult calculate(ClusterState state, DataStream dataStream, @Nullable Double writeIndexLoad) {
+        Metadata metadata = state.metadata();
+        if (isAutoShardingEnabled == false) {
+            logger.debug("Data stream auto sharding service is not enabled.");
+            return NOT_APPLICABLE_RESULT;
+        }
+
+        if (featureService.clusterHasFeature(state, DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE) == false) {
+            logger.debug(
+                "Data stream auto sharding service cannot compute the optimal number of shards for data stream [{}] because the cluster "
+                    + "doesn't have the auto sharding feature",
+                dataStream.getName()
+            );
+            return NOT_APPLICABLE_RESULT;
+        }
+
+        if (dataStreamExcludePatterns.stream().anyMatch(pattern -> Regex.simpleMatch(pattern, dataStream.getName()))) {
+            logger.debug(
+                "Data stream [{}] is excluded from auto sharding via the [{}] setting",
+                dataStream.getName(),
+                DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING.getKey()
+            );
+            return NOT_APPLICABLE_RESULT;
+        }
+
+        if (writeIndexLoad == null) {
+            logger.debug(
+                "Data stream auto sharding service cannot compute the optimal number of shards for data stream [{}] as the write index "
+                    + "load is not available",
+                dataStream.getName()
+            );
+            return NOT_APPLICABLE_RESULT;
+        }
+        return innerCalculate(metadata, dataStream, writeIndexLoad, nowSupplier);
+    }
+
+    private AutoShardingResult innerCalculate(Metadata metadata, DataStream dataStream, double writeIndexLoad, LongSupplier nowSupplier) {
+        // increasing the number of shards is calculated solely based on the index load of the write index
+        IndexMetadata writeIndex = metadata.index(dataStream.getWriteIndex());
+        assert writeIndex != null : "the data stream write index must exist in the provided cluster metadata";
+        AutoShardingResult increaseShardsResult = getIncreaseShardsResult(dataStream, writeIndexLoad, nowSupplier, writeIndex);
+        return Objects.requireNonNullElseGet(
+            increaseShardsResult,
+            () -> getDecreaseShardsResult(
+                metadata,
+                dataStream,
+                writeIndexLoad,
+                nowSupplier,
+                writeIndex,
+                getRemainingDecreaseShardsCooldown(metadata, dataStream)
+            )
+        );
+
+    }
+
+    @Nullable
+    private AutoShardingResult getIncreaseShardsResult(
+        DataStream dataStream,
+        double writeIndexLoad,
+        LongSupplier nowSupplier,
+        IndexMetadata writeIndex
+    ) {
+        // increasing the number of shards is calculated solely based on the index load of the write index
+        long optimalShardCount = computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, writeIndexLoad);
+        if (optimalShardCount > writeIndex.getNumberOfShards()) {
+            TimeValue timeSinceLastAutoShardingEvent = dataStream.getAutoShardingEvent() != null
+                ? dataStream.getAutoShardingEvent().getTimeSinceLastAutoShardingEvent(nowSupplier)
+                : TimeValue.MAX_VALUE;
+
+            TimeValue coolDownRemaining = TimeValue.timeValueMillis(
+                Math.max(0L, increaseShardsCooldown.millis() - timeSinceLastAutoShardingEvent.millis())
+            );
+            logger.debug(
+                "data stream autosharding service recommends increasing the number of shards from [{}] to [{}] after [{}] cooldown for "
+                    + "data stream [{}]",
+                writeIndex.getNumberOfShards(),
+                optimalShardCount,
+                coolDownRemaining,
+                dataStream.getName()
+            );
+            return new AutoShardingResult(
+                coolDownRemaining.equals(TimeValue.ZERO) ? AutoShardingType.INCREASE_SHARDS : AutoShardingType.COOLDOWN_PREVENTED_INCREASE,
+                writeIndex.getNumberOfShards(),
+                Math.toIntExact(optimalShardCount),
+                coolDownRemaining,
+                writeIndexLoad
+            );
+        }
+        return null;
+    }
+
+    /**
+     * Calculates the amount of time remaining before we can consider reducing the number of shards.
+     * The reference for the remaining time math is either the time since the last auto sharding event (if available) or otherwise
+     * the creation date of the oldest index in the data stream.
+     */
+    private TimeValue getRemainingDecreaseShardsCooldown(Metadata metadata, DataStream dataStream) {
+        Index oldestBackingIndex = dataStream.getIndices().get(0);
+        IndexMetadata oldestIndexMeta = metadata.getIndexSafe(oldestBackingIndex);
+
+        return dataStream.getAutoShardingEvent() == null
+            // without a pre-existing auto sharding event we wait until the oldest index has been created longer than the decrease_shards
+            // cool down period "ago" so we don't immediately reduce the number of shards after a data stream is created
+            ? TimeValue.timeValueMillis(
+                Math.max(0L, oldestIndexMeta.getCreationDate() + reduceShardsCooldown.millis() - nowSupplier.getAsLong())
+            )
+            : TimeValue.timeValueMillis(
+                Math.max(
+                    0L,
+                    reduceShardsCooldown.millis() - dataStream.getAutoShardingEvent()
+                        .getTimeSinceLastAutoShardingEvent(nowSupplier)
+                        .millis()
+                )
+            );
+    }
+
+    private AutoShardingResult getDecreaseShardsResult(
+        Metadata metadata,
+        DataStream dataStream,
+        double writeIndexLoad,
+        LongSupplier nowSupplier,
+        IndexMetadata writeIndex,
+        TimeValue remainingReduceShardsCooldown
+    ) {
+        double maxIndexLoadWithinCoolingPeriod = getMaxIndexLoadWithinCoolingPeriod(
+            metadata,
+            dataStream,
+            writeIndexLoad,
+            reduceShardsCooldown,
+            nowSupplier
+        );
+
+        logger.trace(
+            "calculating the optimal number of shards for a potential decrease in number of shards for data stream [{}] with the"
+                + " max indexing load [{}] over the decrease shards cool down period",
+            dataStream.getName(),
+            maxIndexLoadWithinCoolingPeriod
+        );
+        long optimalShardCount = computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, maxIndexLoadWithinCoolingPeriod);
+        if (optimalShardCount < writeIndex.getNumberOfShards()) {
+            logger.debug(
+                "data stream autosharding service recommends decreasing the number of shards from [{}] to [{}] after [{}] cooldown for "
+                    + "data stream [{}]",
+                writeIndex.getNumberOfShards(),
+                optimalShardCount,
+                remainingReduceShardsCooldown,
+                dataStream.getName()
+            );
+
+            // we should reduce the number of shards
+            return new AutoShardingResult(
+                remainingReduceShardsCooldown.equals(TimeValue.ZERO)
+                    ? AutoShardingType.DECREASE_SHARDS
+                    : AutoShardingType.COOLDOWN_PREVENTED_DECREASE,
+                writeIndex.getNumberOfShards(),
+                Math.toIntExact(optimalShardCount),
+                remainingReduceShardsCooldown,
+                maxIndexLoadWithinCoolingPeriod
+            );
+        }
+
+        logger.trace(
+            "data stream autosharding service recommends maintaining the number of shards [{}] for data stream [{}]",
+            writeIndex.getNumberOfShards(),
+            dataStream.getName()
+        );
+        return new AutoShardingResult(
+            AutoShardingType.NO_CHANGE_REQUIRED,
+            writeIndex.getNumberOfShards(),
+            writeIndex.getNumberOfShards(),
+            TimeValue.ZERO,
+            maxIndexLoadWithinCoolingPeriod
+        );
+    }
+
+    // Visible for testing
+    static long computeOptimalNumberOfShards(int minNumberWriteThreads, int maxNumberWriteThreads, double indexingLoad) {
+        return Math.max(
+            Math.min(roundUp(indexingLoad / (minNumberWriteThreads / 2.0)), 3),
+            roundUp(indexingLoad / (maxNumberWriteThreads / 2.0))
+        );
+    }
+
+    private static long roundUp(double value) {
+        return (long) Math.ceil(value);
+    }
+
+    // Visible for testing
+    /**
+     * Calculates the maximum write index load observed for the provided data stream across all the backing indices that were created
+     * during the provided {@param coolingPeriod} (note: to cover the entire cooling period, the backing index created before the cooling
+     * period is also considered).
+     */
+    static double getMaxIndexLoadWithinCoolingPeriod(
+        Metadata metadata,
+        DataStream dataStream,
+        double writeIndexLoad,
+        TimeValue coolingPeriod,
+        LongSupplier nowSupplier
+    ) {
+        // for reducing the number of shards we look at more than just the write index
+        List<IndexWriteLoad> writeLoadsWithinCoolingPeriod = DataStream.getIndicesWithinMaxAgeRange(
+            dataStream,
+            metadata::getIndexSafe,
+            coolingPeriod,
+            nowSupplier
+        )
+            .stream()
+            .filter(index -> index.equals(dataStream.getWriteIndex()) == false)
+            .map(metadata::index)
+            .filter(Objects::nonNull)
+            .map(IndexMetadata::getStats)
+            .filter(Objects::nonNull)
+            .map(IndexMetadataStats::writeLoad)
+            .filter(Objects::nonNull)
+            .toList();
+
+        // assume the current write index load is the highest observed and look back to find the actual maximum
+        double maxIndexLoadWithinCoolingPeriod = writeIndexLoad;
+        for (IndexWriteLoad writeLoad : writeLoadsWithinCoolingPeriod) {
+            double totalIndexLoad = 0;
+            for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
+                final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
+                totalIndexLoad += writeLoadForShard.orElse(0);
+            }
+
+            if (totalIndexLoad > maxIndexLoadWithinCoolingPeriod) {
+                maxIndexLoadWithinCoolingPeriod = totalIndexLoad;
+            }
+        }
+        return maxIndexLoadWithinCoolingPeriod;
+    }
+
+    void updateIncreaseShardsCooldown(TimeValue scaleUpCooldown) {
+        this.increaseShardsCooldown = scaleUpCooldown;
+    }
+
+    void updateReduceShardsCooldown(TimeValue scaleDownCooldown) {
+        this.reduceShardsCooldown = scaleDownCooldown;
+    }
+
+    void updateMinWriteThreads(int minNumberWriteThreads) {
+        this.minWriteThreads = minNumberWriteThreads;
+    }
+
+    void updateMaxWriteThreads(int maxNumberWriteThreads) {
+        this.maxWriteThreads = maxNumberWriteThreads;
+    }
+
+    private void updateDataStreamExcludePatterns(List<String> newExcludePatterns) {
+        this.dataStreamExcludePatterns = newExcludePatterns;
+    }
+}
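
As a worked example of the formula in `computeOptimalNumberOfShards` above, assuming the default write thread bounds of 2 and 32 (this just restates the shipped formula, it is not new behavior):

```java
// optimal = max( min(ceil(load / (minThreads / 2.0)), 3),
//                ceil(load / (maxThreads / 2.0)) )
// With minThreads = 2 and maxThreads = 32 the divisors are 1.0 and 16.0:
//   load 0.5  -> max(min(1, 3), 1)  = 1
//   load 2.5  -> max(min(3, 3), 1)  = 3
//   load 10.0 -> max(min(10, 3), 1) = 3   (low-load streams are capped at 3 shards)
//   load 49.0 -> max(3, ceil(49 / 16) = 4) = 4   (beyond the cap, roughly one shard per 16 load units)
assert DataStreamAutoShardingService.computeOptimalNumberOfShards(2, 32, 0.5) == 1;
assert DataStreamAutoShardingService.computeOptimalNumberOfShards(2, 32, 2.5) == 3;
assert DataStreamAutoShardingService.computeOptimalNumberOfShards(2, 32, 49.0) == 4;
```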

+ 88 - 18
server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java

@@ -70,6 +70,7 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
 
     public static final FeatureFlag FAILURE_STORE_FEATURE_FLAG = new FeatureFlag("failure_store");
     public static final TransportVersion ADDED_FAILURE_STORE_TRANSPORT_VERSION = TransportVersions.V_8_12_0;
+    public static final TransportVersion ADDED_AUTO_SHARDING_EVENT_VERSION = TransportVersions.DATA_STREAM_AUTO_SHARDING_EVENT;
 
     public static boolean isFailureStoreEnabled() {
         return FAILURE_STORE_FEATURE_FLAG.isEnabled();
@@ -113,6 +114,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
     private final boolean rolloverOnWrite;
     private final boolean failureStore;
     private final List<Index> failureIndices;
+    @Nullable
+    private final DataStreamAutoShardingEvent autoShardingEvent;
 
     public DataStream(
         String name,
@@ -126,7 +129,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         IndexMode indexMode,
         DataStreamLifecycle lifecycle,
         boolean failureStore,
-        List<Index> failureIndices
+        List<Index> failureIndices,
+        @Nullable DataStreamAutoShardingEvent autoShardingEvent
     ) {
         this(
             name,
@@ -142,7 +146,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             lifecycle,
             failureStore,
             failureIndices,
-            false
+            false,
+            autoShardingEvent
         );
     }
 
@@ -159,7 +164,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         DataStreamLifecycle lifecycle,
         boolean failureStore,
         List<Index> failureIndices,
-        boolean rolloverOnWrite
+        boolean rolloverOnWrite,
+        @Nullable DataStreamAutoShardingEvent autoShardingEvent
     ) {
         this(
             name,
@@ -175,7 +181,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             lifecycle,
             failureStore,
             failureIndices,
-            rolloverOnWrite
+            rolloverOnWrite,
+            autoShardingEvent
         );
     }
 
@@ -194,7 +201,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         DataStreamLifecycle lifecycle,
         boolean failureStore,
         List<Index> failureIndices,
-        boolean rolloverOnWrite
+        boolean rolloverOnWrite,
+        @Nullable DataStreamAutoShardingEvent autoShardingEvent
     ) {
         this.name = name;
         this.indices = List.copyOf(indices);
@@ -213,6 +221,7 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         this.failureIndices = failureIndices;
         assert assertConsistent(this.indices);
         this.rolloverOnWrite = rolloverOnWrite;
+        this.autoShardingEvent = autoShardingEvent;
     }
 
     // mainly available for testing
@@ -227,7 +236,7 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         boolean allowCustomRouting,
         IndexMode indexMode
     ) {
-        this(name, indices, generation, metadata, hidden, replicated, system, allowCustomRouting, indexMode, null, false, List.of());
+        this(name, indices, generation, metadata, hidden, replicated, system, allowCustomRouting, indexMode, null, false, List.of(), null);
     }
 
     private static boolean assertConsistent(List<Index> indices) {
@@ -412,6 +421,13 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         return lifecycle;
     }
 
+    /**
+     * Returns the latest auto sharding event that happened for this data stream
+     */
+    public DataStreamAutoShardingEvent getAutoShardingEvent() {
+        return autoShardingEvent;
+    }
+
     /**
      * Performs a rollover on a {@code DataStream} instance and returns a new instance containing
      * the updated list of backing indices and incremented generation.
@@ -456,7 +472,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -534,7 +551,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -579,7 +597,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -639,7 +658,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -658,7 +678,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             lifecycle,
             failureStore,
             failureIndices,
-            rolloverOnWrite
+            rolloverOnWrite,
+            autoShardingEvent
         );
     }
 
@@ -694,7 +715,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -909,7 +931,10 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             in.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X) ? in.readOptionalWriteable(DataStreamLifecycle::new) : null,
             in.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION) ? in.readBoolean() : false,
             in.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION) ? readIndices(in) : List.of(),
-            in.getTransportVersion().onOrAfter(TransportVersions.LAZY_ROLLOVER_ADDED) ? in.readBoolean() : false
+            in.getTransportVersion().onOrAfter(TransportVersions.LAZY_ROLLOVER_ADDED) ? in.readBoolean() : false,
+            in.getTransportVersion().onOrAfter(DataStream.ADDED_AUTO_SHARDING_EVENT_VERSION)
+                ? in.readOptionalWriteable(DataStreamAutoShardingEvent::new)
+                : null
         );
     }
 
@@ -953,6 +978,9 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         if (out.getTransportVersion().onOrAfter(TransportVersions.LAZY_ROLLOVER_ADDED)) {
             out.writeBoolean(rolloverOnWrite);
         }
+        if (out.getTransportVersion().onOrAfter(DataStream.ADDED_AUTO_SHARDING_EVENT_VERSION)) {
+            out.writeOptionalWriteable(autoShardingEvent);
+        }
     }
 
     public static final ParseField NAME_FIELD = new ParseField("name");
@@ -969,13 +997,14 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
     public static final ParseField FAILURE_STORE_FIELD = new ParseField("failure_store");
     public static final ParseField FAILURE_INDICES_FIELD = new ParseField("failure_indices");
     public static final ParseField ROLLOVER_ON_WRITE_FIELD = new ParseField("rollover_on_write");
+    public static final ParseField AUTO_SHARDING_FIELD = new ParseField("auto_sharding");
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<DataStream, Void> PARSER = new ConstructingObjectParser<>("data_stream", args -> {
         // Fields behind a feature flag need to be parsed last otherwise the parser will fail when the feature flag is disabled.
         // Until the feature flag is removed we keep them separately to be mindful of this.
-        boolean failureStoreEnabled = DataStream.isFailureStoreEnabled() && args[11] != null && (boolean) args[11];
-        List<Index> failureStoreIndices = DataStream.isFailureStoreEnabled() && args[12] != null ? (List<Index>) args[12] : List.of();
+        boolean failureStoreEnabled = DataStream.isFailureStoreEnabled() && args[12] != null && (boolean) args[12];
+        List<Index> failureStoreIndices = DataStream.isFailureStoreEnabled() && args[13] != null ? (List<Index>) args[13] : List.of();
         return new DataStream(
             (String) args[0],
             (List<Index>) args[1],
@@ -989,7 +1018,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             (DataStreamLifecycle) args[9],
             failureStoreEnabled,
             failureStoreIndices,
-            args[10] != null && (boolean) args[10]
+            args[10] != null && (boolean) args[10],
+            (DataStreamAutoShardingEvent) args[11]
         );
     });
 
@@ -1013,6 +1043,11 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), INDEX_MODE);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> DataStreamLifecycle.fromXContent(p), LIFECYCLE);
         PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), ROLLOVER_ON_WRITE_FIELD);
+        PARSER.declareObject(
+            ConstructingObjectParser.optionalConstructorArg(),
+            (p, c) -> DataStreamAutoShardingEvent.fromXContent(p),
+            AUTO_SHARDING_FIELD
+        );
         // The fields behind the feature flag should always be last.
         if (DataStream.isFailureStoreEnabled()) {
             PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), FAILURE_STORE_FIELD);
@@ -1067,6 +1102,11 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             lifecycle.toXContent(builder, params, rolloverConfiguration);
         }
         builder.field(ROLLOVER_ON_WRITE_FIELD.getPreferredName(), rolloverOnWrite);
+        if (autoShardingEvent != null) {
+            builder.startObject(AUTO_SHARDING_FIELD.getPreferredName());
+            autoShardingEvent.toXContent(builder, params);
+            builder.endObject();
+        }
         builder.endObject();
         return builder;
     }
@@ -1088,7 +1128,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             && Objects.equals(lifecycle, that.lifecycle)
             && failureStore == that.failureStore
             && failureIndices.equals(that.failureIndices)
-            && rolloverOnWrite == that.rolloverOnWrite;
+            && rolloverOnWrite == that.rolloverOnWrite
+            && Objects.equals(autoShardingEvent, that.autoShardingEvent);
     }
 
     @Override
@@ -1106,7 +1147,8 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
             lifecycle,
             failureStore,
             failureIndices,
-            rolloverOnWrite
+            rolloverOnWrite,
+            autoShardingEvent
         );
     }
 
@@ -1169,6 +1211,34 @@ public final class DataStream implements SimpleDiffable<DataStream>, ToXContentO
         "strict_date_optional_time_nanos||strict_date_optional_time||epoch_millis"
     );
 
+    /**
+     * Returns the indices created within the {@param maxIndexAge} interval. Note that this strives to cover
+     * the entire {@param maxIndexAge} interval so one backing index created before the specified age will also
+     * be returned.
+     */
+    public static List<Index> getIndicesWithinMaxAgeRange(
+        DataStream dataStream,
+        Function<Index, IndexMetadata> indexProvider,
+        TimeValue maxIndexAge,
+        LongSupplier nowSupplier
+    ) {
+        final List<Index> dataStreamIndices = dataStream.getIndices();
+        final long currentTimeMillis = nowSupplier.getAsLong();
+        // Consider at least 1 index (including the write index) for cases where rollovers happen less often than maxIndexAge
+        int firstIndexWithinAgeRange = Math.max(dataStreamIndices.size() - 2, 0);
+        for (int i = 0; i < dataStreamIndices.size(); i++) {
+            Index index = dataStreamIndices.get(i);
+            final IndexMetadata indexMetadata = indexProvider.apply(index);
+            final long indexAge = currentTimeMillis - indexMetadata.getCreationDate();
+            if (indexAge < maxIndexAge.getMillis()) {
+                // We need to consider the previous index too in order to cover the entire max-index-age range.
+                firstIndexWithinAgeRange = i == 0 ? 0 : i - 1;
+                break;
+            }
+        }
+        return dataStreamIndices.subList(firstIndexWithinAgeRange, dataStreamIndices.size());
+    }
+
     private static Instant getTimeStampFromRaw(Object rawTimestamp) {
         try {
             if (rawTimestamp instanceof Long lTimestamp) {
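
To illustrate the new `getIndicesWithinMaxAgeRange` helper (the index names and creation times below are made up): the first index returned is the one created just before the max-age window, so the whole window is covered.

```java
// maxIndexAge = 3 days, "now" = day 0; backing indices by creation date:
//   .ds-logs-000001  created day -7
//   .ds-logs-000002  created day -5   <- included: created before the window, but it
//                                        was the active index when the window started
//   .ds-logs-000003  created day -2
//   .ds-logs-000004  created day -1   (write index)
// getIndicesWithinMaxAgeRange returns [000002, 000003, 000004]
```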

+ 84 - 0
server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAutoShardingEvent.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.cluster.Diff;
+import org.elasticsearch.cluster.SimpleDiffable;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.function.LongSupplier;
+
+/**
+ * Represents the last auto sharding event that occurred for a data stream.
+ */
+public record DataStreamAutoShardingEvent(String triggerIndexName, int targetNumberOfShards, long timestamp)
+    implements
+        SimpleDiffable<DataStreamAutoShardingEvent>,
+        ToXContentFragment {
+
+    public static final ParseField TRIGGER_INDEX_NAME = new ParseField("trigger_index_name");
+    public static final ParseField TARGET_NUMBER_OF_SHARDS = new ParseField("target_number_of_shards");
+    public static final ParseField EVENT_TIME = new ParseField("event_time");
+    public static final ParseField EVENT_TIME_MILLIS = new ParseField("event_time_millis");
+
+    public static final ConstructingObjectParser<DataStreamAutoShardingEvent, Void> PARSER = new ConstructingObjectParser<>(
+        "auto_sharding",
+        false,
+        (args, unused) -> new DataStreamAutoShardingEvent((String) args[0], (int) args[1], (long) args[2])
+    );
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TRIGGER_INDEX_NAME);
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), TARGET_NUMBER_OF_SHARDS);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), EVENT_TIME_MILLIS);
+    }
+
+    public static DataStreamAutoShardingEvent fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    static Diff<DataStreamAutoShardingEvent> readDiffFrom(StreamInput in) throws IOException {
+        return SimpleDiffable.readDiffFrom(DataStreamAutoShardingEvent::new, in);
+    }
+
+    DataStreamAutoShardingEvent(StreamInput in) throws IOException {
+        this(in.readString(), in.readVInt(), in.readVLong());
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(TRIGGER_INDEX_NAME.getPreferredName(), triggerIndexName);
+        builder.field(TARGET_NUMBER_OF_SHARDS.getPreferredName(), targetNumberOfShards);
+        builder.humanReadableField(
+            EVENT_TIME_MILLIS.getPreferredName(),
+            EVENT_TIME.getPreferredName(),
+            TimeValue.timeValueMillis(timestamp)
+        );
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(triggerIndexName);
+        out.writeVInt(targetNumberOfShards);
+        out.writeVLong(timestamp);
+    }
+
+    public TimeValue getTimeSinceLastAutoShardingEvent(LongSupplier now) {
+        return TimeValue.timeValueMillis(Math.max(0L, now.getAsLong() - timestamp));
+    }
+}
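
A small usage sketch for the new record (timestamps made up); note that `getTimeSinceLastAutoShardingEvent` clamps at zero if the supplied clock is behind the event timestamp:

```java
DataStreamAutoShardingEvent event = new DataStreamAutoShardingEvent(".ds-logs-2024.02.12-000002", 3, 1_000L);
TimeValue elapsed = event.getTimeSinceLastAutoShardingEvent(() -> 1_500L); // 500ms
TimeValue clamped = event.getTimeSinceLastAutoShardingEvent(() -> 900L);  // 0ms, clamped
```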

+ 2 - 1
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java

@@ -314,7 +314,8 @@ public class MetadataCreateDataStreamService {
             indexMode,
             lifecycle == null && isDslOnlyMode ? DataStreamLifecycle.DEFAULT : lifecycle,
             template.getDataStreamTemplate().hasFailureStore(),
-            failureIndices
+            failureIndices,
+            null
         );
         Metadata.Builder builder = Metadata.builder(currentState.metadata()).put(newDataStream);
 

+ 4 - 2
server/src/main/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsService.java

@@ -212,7 +212,8 @@ public class MetadataDataStreamsService {
                     dataStream.getIndexMode(),
                     lifecycle,
                     dataStream.isFailureStore(),
-                    dataStream.getFailureIndices()
+                    dataStream.getFailureIndices(),
+                    dataStream.getAutoShardingEvent()
                 )
             );
         }
@@ -249,7 +250,8 @@ public class MetadataDataStreamsService {
                 dataStream.getLifecycle(),
                 dataStream.isFailureStore(),
                 dataStream.getFailureIndices(),
-                rolloverOnWrite
+                rolloverOnWrite,
+                dataStream.getAutoShardingEvent()
             )
         );
         return ClusterState.builder(currentState).metadata(builder.build()).build();

+ 2 - 1
server/src/main/java/org/elasticsearch/snapshots/RestoreService.java

@@ -716,7 +716,8 @@ public final class RestoreService implements ClusterStateApplier {
             dataStream.getIndexMode(),
             dataStream.getLifecycle(),
             dataStream.isFailureStore(),
-            dataStream.getFailureIndices()
+            dataStream.getFailureIndices(),
+            dataStream.getAutoShardingEvent()
         );
     }
 

+ 771 - 0
server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java

@@ -0,0 +1,771 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.datastreams.autosharding;
+
+import org.elasticsearch.action.admin.indices.rollover.MaxAgeCondition;
+import org.elasticsearch.action.admin.indices.rollover.RolloverInfo;
+import org.elasticsearch.cluster.ClusterName;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.metadata.DataStream;
+import org.elasticsearch.cluster.metadata.DataStreamAutoShardingEvent;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.metadata.IndexMetadataStats;
+import org.elasticsearch.cluster.metadata.IndexWriteLoad;
+import org.elasticsearch.cluster.metadata.Metadata;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.features.FeatureSpecification;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.IndexMode;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.After;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingResult.NOT_APPLICABLE_RESULT;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.COOLDOWN_PREVENTED_DECREASE;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.COOLDOWN_PREVENTED_INCREASE;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.DECREASE_SHARDS;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.INCREASE_SHARDS;
+import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.NO_CHANGE_REQUIRED;
+import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
+import static org.hamcrest.Matchers.is;
+
+public class DataStreamAutoShardingServiceTests extends ESTestCase {
+
+    private ClusterService clusterService;
+    private ThreadPool threadPool;
+    private DataStreamAutoShardingService service;
+    private long now;
+    String dataStreamName;
+
+    @Before
+    public void setupService() {
+        threadPool = new TestThreadPool(getTestName());
+        Set<Setting<?>> builtInClusterSettings = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        builtInClusterSettings.add(DataStreamAutoShardingService.CLUSTER_AUTO_SHARDING_MIN_WRITE_THREADS);
+        builtInClusterSettings.add(DataStreamAutoShardingService.CLUSTER_AUTO_SHARDING_MAX_WRITE_THREADS);
+        builtInClusterSettings.add(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_INCREASE_SHARDS_COOLDOWN);
+        builtInClusterSettings.add(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_DECREASE_SHARDS_COOLDOWN);
+        builtInClusterSettings.add(
+            Setting.boolSetting(
+                DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_ENABLED,
+                false,
+                Setting.Property.Dynamic,
+                Setting.Property.NodeScope
+            )
+        );
+        ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, builtInClusterSettings);
+        clusterService = createClusterService(threadPool, clusterSettings);
+        now = System.currentTimeMillis();
+        service = new DataStreamAutoShardingService(
+            Settings.builder()
+                .put(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_ENABLED, true)
+                .putList(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING.getKey(), List.of())
+                .build(),
+            clusterService,
+            new FeatureService(List.of(new FeatureSpecification() {
+                @Override
+                public Set<NodeFeature> getFeatures() {
+                    return Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE);
+                }
+            })),
+            () -> now
+        );
+        dataStreamName = randomAlphaOfLengthBetween(10, 100);
+        logger.info("-> data stream name is [{}]", dataStreamName);
+    }
+
+    @After
+    public void cleanup() {
+        clusterService.close();
+        threadPool.shutdownNow();
+    }
+
+    public void testCalculateValidations() {
+        Metadata.Builder builder = Metadata.builder();
+        DataStream dataStream = createDataStream(
+            builder,
+            dataStreamName,
+            1,
+            now,
+            List.of(now - 3000, now - 2000, now - 1000),
+            getWriteLoad(1, 2.0),
+            null
+        );
+        builder.put(dataStream);
+        ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+            .nodeFeatures(
+                Map.of(
+                    "n1",
+                    Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                    "n2",
+                    Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                )
+            )
+            .metadata(builder)
+            .build();
+
+        {
+            // autosharding disabled
+            DataStreamAutoShardingService disabledAutoshardingService = new DataStreamAutoShardingService(
+                Settings.EMPTY,
+                clusterService,
+                new FeatureService(List.of(new FeatureSpecification() {
+                    @Override
+                    public Set<NodeFeature> getFeatures() {
+                        return Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE);
+                    }
+                })),
+                System::currentTimeMillis
+            );
+
+            AutoShardingResult autoShardingResult = disabledAutoshardingService.calculate(state, dataStream, 2.0);
+            assertThat(autoShardingResult, is(NOT_APPLICABLE_RESULT));
+        }
+
+        {
+            // cluster doesn't have feature
+            ClusterState stateNoFeature = ClusterState.builder(ClusterName.DEFAULT).metadata(Metadata.builder()).build();
+
+            DataStreamAutoShardingService noFeatureService = new DataStreamAutoShardingService(
+                Settings.builder()
+                    .put(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_ENABLED, true)
+                    .putList(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING.getKey(), List.of())
+                    .build(),
+                clusterService,
+                new FeatureService(List.of()),
+                () -> now
+            );
+
+            AutoShardingResult autoShardingResult = noFeatureService.calculate(stateNoFeature, dataStream, 2.0);
+            assertThat(autoShardingResult, is(NOT_APPLICABLE_RESULT));
+        }
+
+        {
+            // patterns are configured to exclude the current data stream
+            DataStreamAutoShardingService excludedDataStreamService = new DataStreamAutoShardingService(
+                Settings.builder()
+                    .put(DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_ENABLED, true)
+                    .putList(
+                        DataStreamAutoShardingService.DATA_STREAMS_AUTO_SHARDING_EXCLUDES_SETTING.getKey(),
+                        List.of("foo", dataStreamName + "*")
+                    )
+                    .build(),
+                clusterService,
+                new FeatureService(List.of()),
+                () -> now
+            );
+
+            AutoShardingResult autoShardingResult = excludedDataStreamService.calculate(state, dataStream, 2.0);
+            assertThat(autoShardingResult, is(NOT_APPLICABLE_RESULT));
+        }
+
+        {
+            // null write load passed
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, null);
+            assertThat(autoShardingResult, is(NOT_APPLICABLE_RESULT));
+        }
+    }
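+    // A sketch of the guard conditions the scenarios above exercise (inferred from this test's expectations,
+    // not the production code; excludePatternsMatch is a hypothetical helper standing in for the excludes
+    // setting check):
+    //     if (enabled == false || featureSupported == false
+    //         || excludePatternsMatch(dataStream.getName()) || writeIndexLoad == null) {
+    //         return NOT_APPLICABLE_RESULT;
+    //     }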
+
+    public void testCalculateIncreaseShardingRecommendations() {
+        // the input is a data stream with 5 backing indices with 1 shard each
+        // the 4 non-write backing indices have a write load of 2.0 (the write index does not report a write load)
+        // we'll recreate the data stream in each scenario, adding an auto sharding event as we iterate
+        {
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                1,
+                now,
+                List.of(now - 10_000, now - 7000, now - 5000, now - 2000, now - 1000),
+                getWriteLoad(1, 2.0),
+                autoShardingEvent
+            );
+
+            DataStream dataStream = dataStreamSupplier.apply(null);
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 2.5);
+            assertThat(autoShardingResult.type(), is(INCREASE_SHARDS));
+            // no pre-existing scaling event so the cool down must be zero
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.ZERO));
+            assertThat(autoShardingResult.targetNumberOfShards(), is(3));
+        }
+
+        {
+            // let's add a pre-existing sharding event so that a cool down period prevents an INCREASE_SHARDS
+            // event; the result type we're expecting is COOLDOWN_PREVENTED_INCREASE
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                1,
+                now,
+                List.of(now - 10_000, now - 7000, now - 5000, now - 2000, now - 1000),
+                getWriteLoad(1, 2.0),
+                autoShardingEvent
+            );
+
+            // generation 4 triggered an auto sharding event to 2 shards
+            DataStream dataStream = dataStreamSupplier.apply(
+                new DataStreamAutoShardingEvent(DataStream.getDefaultBackingIndexName(dataStreamName, 4), 2, now - 1005)
+            );
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 2.5);
+            assertThat(autoShardingResult.type(), is(COOLDOWN_PREVENTED_INCREASE));
+            // the recommended target number of shards is still 3, even though the cool down prevents applying it yet
+            assertThat(autoShardingResult.targetNumberOfShards(), is(3));
+            // it's been 1005 millis since the last auto sharding event and the cool down is 270 seconds (270_000 millis)
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.timeValueMillis(268995)));
+        }
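+        // Cool down bookkeeping for increases, consistent with the assertion above (an assumption inferred
+        // from the expected values, not the production code):
+        //     remaining = increaseShardsCooldown - (now - lastAutoShardingEvent.timestamp())
+        // here: 270_000 millis - 1005 millis = 268_995 millis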
+
+        {
+            // let's test a subsequent increase in the number of shards after a previous auto sharding event
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                1,
+                now,
+                List.of(now - 10_000_000, now - 7_000_000, now - 2_000_000, now - 1_000_000, now - 1000),
+                getWriteLoad(1, 2.0),
+                autoShardingEvent
+            );
+
+            // a previous auto sharding event (trigger index: generation 4) increased the number of shards to 2
+            DataStream dataStream = dataStreamSupplier.apply(
+                new DataStreamAutoShardingEvent(DataStream.getDefaultBackingIndexName(dataStreamName, 4), 2, now - 2_000_100)
+            );
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 2.5);
+            assertThat(autoShardingResult.type(), is(INCREASE_SHARDS));
+            // the previous auto sharding event is older than the cool down period, so the increase can be applied immediately
+            assertThat(autoShardingResult.targetNumberOfShards(), is(3));
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.ZERO));
+        }
+    }
+
+    public void testCalculateDecreaseShardingRecommendations() {
+        // the input is a data stream with 5 backing indices with 3 shards each
+        {
+            // testing a decrease shards event prevented by the cool down period, because the oldest generation index
+            // is "too new" (i.e. the cool down period hasn't lapsed since the oldest generation index was created)
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                3,
+                now,
+                List.of(now - 10_000, now - 7000, now - 5000, now - 2000, now - 1000),
+                getWriteLoad(3, 0.25),
+                autoShardingEvent
+            );
+
+            DataStream dataStream = dataStreamSupplier.apply(null);
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 1.0);
+            // the cooldown period for the decrease shards event hasn't lapsed since the data stream was created
+            assertThat(autoShardingResult.type(), is(COOLDOWN_PREVENTED_DECREASE));
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.timeValueMillis(TimeValue.timeValueDays(3).millis() - 10_000)));
+            // based on the total write load of 0.75 (3 shards x 0.25) we should be reducing the number of shards to 1
+            assertThat(autoShardingResult.targetNumberOfShards(), is(1));
+        }
+
+        {
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                3,
+                now,
+                List.of(
+                    now - TimeValue.timeValueDays(21).getMillis(),
+                    now - TimeValue.timeValueDays(15).getMillis(),
+                    now - TimeValue.timeValueDays(4).getMillis(),
+                    now - TimeValue.timeValueDays(2).getMillis(),
+                    now - 1000
+                ),
+                getWriteLoad(3, 0.333),
+                autoShardingEvent
+            );
+
+            DataStream dataStream = dataStreamSupplier.apply(null);
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 1.0);
+            assertThat(autoShardingResult.type(), is(DECREASE_SHARDS));
+            assertThat(autoShardingResult.targetNumberOfShards(), is(1));
+            // there is no pre-existing auto sharding event, but the backing indices are old enough (older than the
+            // cool down period) for a decision to reduce the number of shards
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.ZERO));
+        }
+
+        {
+            // let's test a decrease in number of shards after a previous decrease event
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                3,
+                now,
+                List.of(
+                    now - TimeValue.timeValueDays(21).getMillis(),
+                    now - TimeValue.timeValueDays(15).getMillis(), // triggers auto sharding event
+                    now - TimeValue.timeValueDays(4).getMillis(),
+                    now - TimeValue.timeValueDays(2).getMillis(),
+                    now - 1000
+                ),
+                getWriteLoad(3, 0.333),
+                autoShardingEvent
+            );
+
+            // generation 2 triggered a decrease in shards event to 2 shards
+            DataStream dataStream = dataStreamSupplier.apply(
+                new DataStreamAutoShardingEvent(
+                    DataStream.getDefaultBackingIndexName(dataStreamName, 2),
+                    2,
+                    now - TimeValue.timeValueDays(4).getMillis()
+                )
+            );
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 1.0);
+            assertThat(autoShardingResult.type(), is(DECREASE_SHARDS));
+            assertThat(autoShardingResult.targetNumberOfShards(), is(1));
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.ZERO));
+        }
+
+        {
+            // let's test a decrease in number of shards that's prevented by the cool down period due to a previous sharding event
+            // the expected result type here is COOLDOWN_PREVENTED_DECREASE
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                3,
+                now,
+                List.of(
+                    now - TimeValue.timeValueDays(21).getMillis(),
+                    now - TimeValue.timeValueDays(2).getMillis(), // triggers auto sharding event
+                    now - TimeValue.timeValueDays(1).getMillis(),
+                    now - 1000
+                ),
+                getWriteLoad(3, 0.25),
+                autoShardingEvent
+            );
+
+            // generation 2 triggered a decrease in shards event to 2 shards
+            DataStream dataStream = dataStreamSupplier.apply(
+                new DataStreamAutoShardingEvent(
+                    DataStream.getDefaultBackingIndexName(dataStreamName, 2),
+                    2,
+                    now - TimeValue.timeValueDays(2).getMillis()
+                )
+            );
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 1.0);
+            assertThat(autoShardingResult.type(), is(COOLDOWN_PREVENTED_DECREASE));
+            assertThat(autoShardingResult.targetNumberOfShards(), is(1));
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.timeValueDays(1)));
+        }
+
+        {
+            // no change required
+            Metadata.Builder builder = Metadata.builder();
+            Function<DataStreamAutoShardingEvent, DataStream> dataStreamSupplier = (autoShardingEvent) -> createDataStream(
+                builder,
+                dataStreamName,
+                3,
+                now,
+                List.of(
+                    now - TimeValue.timeValueDays(21).getMillis(),
+                    now - TimeValue.timeValueDays(15).getMillis(),
+                    now - TimeValue.timeValueDays(4).getMillis(),
+                    now - TimeValue.timeValueDays(2).getMillis(),
+                    now - 1000
+                ),
+                getWriteLoad(3, 1.333),
+                autoShardingEvent
+            );
+
+            // no pre-existing auto sharding event
+            DataStream dataStream = dataStreamSupplier.apply(null);
+            builder.put(dataStream);
+            ClusterState state = ClusterState.builder(ClusterName.DEFAULT)
+                .nodeFeatures(
+                    Map.of(
+                        "n1",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id()),
+                        "n2",
+                        Set.of(DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE.id())
+                    )
+                )
+                .metadata(builder)
+                .build();
+
+            AutoShardingResult autoShardingResult = service.calculate(state, dataStream, 4.0);
+            assertThat(autoShardingResult.type(), is(NO_CHANGE_REQUIRED));
+            assertThat(autoShardingResult.targetNumberOfShards(), is(3));
+            assertThat(autoShardingResult.coolDownRemaining(), is(TimeValue.ZERO));
+        }
+    }
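+    // Cool down bookkeeping for decreases, consistent with the assertions above (an assumption inferred from
+    // the expected values, not the production code):
+    //     remaining = decreaseShardsCooldown - (now - reference)
+    // where reference is the timestamp of the last auto sharding event if one exists, and the creation date of
+    // the oldest backing index otherwise, e.g. 3 days - 2 days = 1 day in the COOLDOWN_PREVENTED_DECREASE
+    // scenario above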
+
+    public void testComputeOptimalNumberOfShards() {
+        int minWriteThreads = 2;
+        int maxWriteThreads = 32;
+        {
+            // small write loads will be very common so let's randomise to make sure we never go below 1L
+            double indexingLoad = randomDoubleBetween(0.0001, 1.0, true);
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(1L));
+        }
+
+        {
+            double indexingLoad = 2.0;
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(2L));
+        }
+
+        {
+            // there's a broad range of common values (a write index starting to get very busy, using between 3 and
+            // all of the 32 write threads), so let's randomise this too to make sure we stay at 3 recommended shards
+            double indexingLoad = randomDoubleBetween(3.0002, 32.0, true);
+            logger.info("-> indexingLoad {}", indexingLoad);
+
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(3L));
+        }
+
+        {
+            double indexingLoad = 49.0;
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(4L));
+        }
+
+        {
+            double indexingLoad = 70.0;
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(5L));
+        }
+
+        {
+            double indexingLoad = 100.0;
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(7L));
+        }
+
+        {
+            double indexingLoad = 180.0;
+            logger.info("-> indexingLoad {}", indexingLoad);
+            assertThat(DataStreamAutoShardingService.computeOptimalNumberOfShards(minWriteThreads, maxWriteThreads, indexingLoad), is(12L));
+        }
+    }
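+    // The assertions above are consistent with the following formula (a sketch inferred from the expected
+    // values, not necessarily the exact production implementation, with roundUp being Math.ceil):
+    //     optimalShards = Math.max(
+    //         Math.min(roundUp(load / (minWriteThreads / 2.0)), 3L),
+    //         roundUp(load / (maxWriteThreads / 2.0))
+    //     );
+    // e.g. load 180.0 with minWriteThreads 2 and maxWriteThreads 32: min(roundUp(180.0), 3) = 3 and
+    // roundUp(180.0 / 16.0) = 12, so the recommendation is 12 shards, matching the last assertion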
+
+    public void testGetMaxIndexLoadWithinCoolingPeriod() {
+        final TimeValue coolingPeriod = TimeValue.timeValueDays(3);
+
+        final Metadata.Builder metadataBuilder = Metadata.builder();
+        final int numberOfBackingIndicesOutsideCoolingPeriod = randomIntBetween(3, 10);
+        final int numberOfBackingIndicesWithinCoolingPeriod = randomIntBetween(3, 10);
+        final List<Index> backingIndices = new ArrayList<>();
+        final String dataStreamName = "logs";
+        long now = System.currentTimeMillis();
+
+        // to cover the entire cooling period the calculation must also include the last backing index created before
+        // the cooling period starts; this flag controls whether the indices created before the cooling period (only
+        // the last of which is taken into account) have a very low or a very high write load
+        boolean lastIndexBeforeCoolingPeriodHasLowWriteLoad = randomBoolean();
+        for (int i = 0; i < numberOfBackingIndicesOutsideCoolingPeriod; i++) {
+            long creationDate = now - (coolingPeriod.millis() * 2);
+            IndexMetadata indexMetadata = createIndexMetadata(
+                DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), creationDate),
+                1,
+                getWriteLoad(1, 999.0),
+                creationDate
+            );
+
+            if (lastIndexBeforeCoolingPeriodHasLowWriteLoad) {
+                indexMetadata = createIndexMetadata(
+                    DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), creationDate),
+                    1,
+                    getWriteLoad(1, 1.0),
+                    creationDate
+                );
+            }
+            backingIndices.add(indexMetadata.getIndex());
+            metadataBuilder.put(indexMetadata, false);
+        }
+
+        for (int i = 0; i < numberOfBackingIndicesWithinCoolingPeriod; i++) {
+            final long createdAt = now - (coolingPeriod.getMillis() / 2);
+            IndexMetadata indexMetadata;
+            if (i == numberOfBackingIndicesWithinCoolingPeriod - 1) {
+                indexMetadata = createIndexMetadata(
+                    DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), createdAt),
+                    3,
+                    getWriteLoad(3, 5.0), // the index with the max write load (3 x 5.0 = 15.0) within the cooling period
+                    createdAt
+                );
+            } else {
+                indexMetadata = createIndexMetadata(
+                    DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), createdAt),
+                    3,
+                    getWriteLoad(3, 3.0), // each of the other indices has a total write load of 9.0 (3 shards x 3.0)
+                    createdAt
+                );
+            }
+            backingIndices.add(indexMetadata.getIndex());
+            metadataBuilder.put(indexMetadata, false);
+        }
+
+        final String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size());
+        final IndexMetadata writeIndexMetadata = createIndexMetadata(writeIndexName, 3, getWriteLoad(3, 1.0), System.currentTimeMillis());
+        backingIndices.add(writeIndexMetadata.getIndex());
+        metadataBuilder.put(writeIndexMetadata, false);
+
+        final DataStream dataStream = new DataStream(
+            dataStreamName,
+            backingIndices,
+            backingIndices.size(),
+            Collections.emptyMap(),
+            false,
+            false,
+            false,
+            false,
+            IndexMode.STANDARD
+        );
+
+        metadataBuilder.put(dataStream);
+
+        double maxIndexLoadWithinCoolingPeriod = DataStreamAutoShardingService.getMaxIndexLoadWithinCoolingPeriod(
+            metadataBuilder.build(),
+            dataStream,
+            3.0,
+            coolingPeriod,
+            () -> now
+        );
+        // to cover the entire cooling period, the last index created before the cooling period is also taken into account
+        assertThat(maxIndexLoadWithinCoolingPeriod, is(lastIndexBeforeCoolingPeriodHasLowWriteLoad ? 15.0 : 999.0));
+    }
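+    // A sketch of the computation this test exercises (assumptions inferred from the setup above, not the
+    // production code; indicesWithinCoolingPeriodPlusPrevious and sumOfShardWriteLoads are hypothetical names):
+    //     double max = currentWriteIndexLoad;
+    //     for (Index index : indicesWithinCoolingPeriodPlusPrevious) {
+    //         max = Math.max(max, sumOfShardWriteLoads(metadata.getSafe(index)));
+    //     }
+    // an index's load is the sum of its per-shard write loads, so the expected maximum is 3 x 5.0 = 15.0,
+    // unless the index right before the cooling period reports the very high load of 999.0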
+
+    public void testAutoShardingResultValidation() {
+        {
+            // constructing a result with a non-zero cool down for types that shouldn't report cool downs throws an exception
+            expectThrows(
+                IllegalArgumentException.class,
+                () -> new AutoShardingResult(INCREASE_SHARDS, 1, 3, TimeValue.timeValueSeconds(3), 3.0)
+            );
+
+            expectThrows(
+                IllegalArgumentException.class,
+                () -> new AutoShardingResult(DECREASE_SHARDS, 3, 1, TimeValue.timeValueSeconds(3), 1.0)
+            );
+
+        }
+
+        {
+            // results with a cool down period can successfully be created for the designated types
+            AutoShardingResult cooldownPreventedIncrease = new AutoShardingResult(
+                COOLDOWN_PREVENTED_INCREASE,
+                1,
+                3,
+                TimeValue.timeValueSeconds(3),
+                3.0
+            );
+            assertThat(cooldownPreventedIncrease.coolDownRemaining(), is(TimeValue.timeValueSeconds(3)));
+
+            AutoShardingResult cooldownPreventedDecrease = new AutoShardingResult(
+                COOLDOWN_PREVENTED_DECREASE,
+                3,
+                1,
+                TimeValue.timeValueSeconds(7),
+                1.0
+            );
+            assertThat(cooldownPreventedDecrease.coolDownRemaining(), is(TimeValue.timeValueSeconds(7)));
+        }
+    }
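+    // The invariant these assertions suggest (an assumption, not the record's verbatim code): only the
+    // COOLDOWN_PREVENTED_* types may carry a non-zero cool down, i.e. something along the lines of
+    //     if ((type == INCREASE_SHARDS || type == DECREASE_SHARDS)
+    //         && coolDownRemaining.equals(TimeValue.ZERO) == false) {
+    //         throw new IllegalArgumentException("cool down remaining is only supported for cooldown-prevented types");
+    //     }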
+
+    private DataStream createDataStream(
+        Metadata.Builder builder,
+        String dataStreamName,
+        int numberOfShards,
+        Long now,
+        List<Long> indicesCreationDate,
+        IndexWriteLoad backingIndicesWriteLoad,
+        @Nullable DataStreamAutoShardingEvent autoShardingEvent
+    ) {
+        final List<Index> backingIndices = new ArrayList<>();
+        int backingIndicesCount = indicesCreationDate.size();
+        for (int k = 0; k < indicesCreationDate.size(); k++) {
+            long createdAt = indicesCreationDate.get(k);
+            IndexMetadata.Builder indexMetaBuilder;
+            if (k < backingIndicesCount - 1) {
+                indexMetaBuilder = IndexMetadata.builder(
+                    createIndexMetadata(
+                        DataStream.getDefaultBackingIndexName(dataStreamName, k + 1),
+                        numberOfShards,
+                        backingIndicesWriteLoad,
+                        createdAt
+                    )
+                );
+                // add rollover info only for non-write indices
+                MaxAgeCondition rolloverCondition = new MaxAgeCondition(TimeValue.timeValueMillis(now - 2000L));
+                indexMetaBuilder.putRolloverInfo(new RolloverInfo(dataStreamName, List.of(rolloverCondition), now - 2000L));
+            } else {
+                // write index
+                indexMetaBuilder = IndexMetadata.builder(
+                    createIndexMetadata(DataStream.getDefaultBackingIndexName(dataStreamName, k + 1), numberOfShards, null, createdAt)
+                );
+            }
+            IndexMetadata indexMetadata = indexMetaBuilder.build();
+            builder.put(indexMetadata, false);
+            backingIndices.add(indexMetadata.getIndex());
+        }
+        return new DataStream(
+            dataStreamName,
+            backingIndices,
+            backingIndicesCount,
+            null,
+            false,
+            false,
+            false,
+            false,
+            null,
+            null,
+            false,
+            List.of(),
+            autoShardingEvent
+        );
+    }
+
+    private IndexMetadata createIndexMetadata(
+        String indexName,
+        int numberOfShards,
+        @Nullable IndexWriteLoad indexWriteLoad,
+        long createdAt
+    ) {
+        return IndexMetadata.builder(indexName)
+            .settings(
+                Settings.builder()
+                    .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
+                    .build()
+            )
+            .stats(indexWriteLoad == null ? null : new IndexMetadataStats(indexWriteLoad, 1, 1))
+            .creationDate(createdAt)
+            .build();
+    }
+
+    private IndexWriteLoad getWriteLoad(int numberOfShards, double shardWriteLoad) {
+        IndexWriteLoad.Builder builder = IndexWriteLoad.builder(numberOfShards);
+        for (int shardId = 0; shardId < numberOfShards; shardId++) {
+            builder.withShardWriteLoad(shardId, shardWriteLoad, randomLongBetween(1, 10));
+        }
+        return builder.build();
+    }
+
+}

+ 62 - 0
server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamAutoShardingEventTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.cluster.metadata;
+
+import org.elasticsearch.cluster.Diff;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.SimpleDiffableSerializationTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class DataStreamAutoShardingEventTests extends SimpleDiffableSerializationTestCase<DataStreamAutoShardingEvent> {
+
+    @Override
+    protected DataStreamAutoShardingEvent doParseInstance(XContentParser parser) throws IOException {
+        return DataStreamAutoShardingEvent.fromXContent(parser);
+    }
+
+    @Override
+    protected Writeable.Reader<DataStreamAutoShardingEvent> instanceReader() {
+        return DataStreamAutoShardingEvent::new;
+    }
+
+    @Override
+    protected DataStreamAutoShardingEvent createTestInstance() {
+        return DataStreamAutoShardingEventTests.randomInstance();
+    }
+
+    @Override
+    protected DataStreamAutoShardingEvent mutateInstance(DataStreamAutoShardingEvent instance) {
+        String triggerIndex = instance.triggerIndexName();
+        long timestamp = instance.timestamp();
+        int targetNumberOfShards = instance.targetNumberOfShards();
+        switch (randomInt(2)) {
+            case 0 -> triggerIndex = randomValueOtherThan(triggerIndex, () -> randomAlphaOfLengthBetween(10, 50));
+            case 1 -> timestamp = randomValueOtherThan(timestamp, ESTestCase::randomNonNegativeLong);
+            case 2 -> targetNumberOfShards = randomValueOtherThan(targetNumberOfShards, ESTestCase::randomNonNegativeInt);
+        }
+        return new DataStreamAutoShardingEvent(triggerIndex, targetNumberOfShards, timestamp);
+    }
+
+    static DataStreamAutoShardingEvent randomInstance() {
+        return new DataStreamAutoShardingEvent(randomAlphaOfLengthBetween(10, 40), randomNonNegativeInt(), randomNonNegativeLong());
+    }
+
+    @Override
+    protected DataStreamAutoShardingEvent makeTestChanges(DataStreamAutoShardingEvent testInstance) {
+        return mutateInstance(testInstance);
+    }
+
+    @Override
+    protected Writeable.Reader<Diff<DataStreamAutoShardingEvent>> diffReader() {
+        return DataStreamAutoShardingEvent::readDiffFrom;
+    }
+}

+ 117 - 7
server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java

@@ -39,6 +39,7 @@ import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -93,7 +94,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
         var lifecycle = instance.getLifecycle();
         var failureStore = instance.isFailureStore();
         var failureIndices = instance.getFailureIndices();
-        switch (between(0, 10)) {
+        var autoShardingEvent = instance.getAutoShardingEvent();
+        switch (between(0, 11)) {
             case 0 -> name = randomAlphaOfLength(10);
             case 1 -> indices = randomValueOtherThan(List.of(), DataStreamTestHelper::randomIndexInstances);
             case 2 -> generation = instance.getGeneration() + randomIntBetween(1, 10);
@@ -130,6 +132,15 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
                     failureStore = true;
                 }
             }
+            case 11 -> {
+                autoShardingEvent = randomBoolean() && autoShardingEvent != null
+                    ? null
+                    : new DataStreamAutoShardingEvent(
+                        indices.get(indices.size() - 1).getName(),
+                        randomIntBetween(1, 10),
+                        randomMillisUpToYear9999()
+                    );
+            }
         }
 
         return new DataStream(
@@ -144,7 +155,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             indexMode,
             lifecycle,
             failureStore,
-            failureIndices
+            failureIndices,
+            autoShardingEvent
         );
     }
 
@@ -201,7 +213,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             indexMode,
             ds.getLifecycle(),
             ds.isFailureStore(),
-            ds.getFailureIndices()
+            ds.getFailureIndices(),
+            ds.getAutoShardingEvent()
         );
         var newCoordinates = ds.nextWriteIndexAndGeneration(Metadata.EMPTY_METADATA);
 
@@ -228,7 +241,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             IndexMode.TIME_SERIES,
             ds.getLifecycle(),
             ds.isFailureStore(),
-            ds.getFailureIndices()
+            ds.getFailureIndices(),
+            ds.getAutoShardingEvent()
         );
         var newCoordinates = ds.nextWriteIndexAndGeneration(Metadata.EMPTY_METADATA);
 
@@ -590,7 +604,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             preSnapshotDataStream.getIndexMode(),
             preSnapshotDataStream.getLifecycle(),
             preSnapshotDataStream.isFailureStore(),
-            preSnapshotDataStream.getFailureIndices()
+            preSnapshotDataStream.getFailureIndices(),
+            preSnapshotDataStream.getAutoShardingEvent()
         );
 
         var reconciledDataStream = postSnapshotDataStream.snapshot(
@@ -634,7 +649,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             preSnapshotDataStream.getIndexMode(),
             preSnapshotDataStream.getLifecycle(),
             preSnapshotDataStream.isFailureStore(),
-            preSnapshotDataStream.getFailureIndices()
+            preSnapshotDataStream.getFailureIndices(),
+            preSnapshotDataStream.getAutoShardingEvent()
         );
 
         assertNull(postSnapshotDataStream.snapshot(preSnapshotDataStream.getIndices().stream().map(Index::getName).toList()));
@@ -1654,7 +1670,8 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
             lifecycle,
             failureStore,
             failureIndices,
-            false
+            false,
+            null
         );
 
         try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
@@ -1671,6 +1688,99 @@ public class DataStreamTests extends AbstractXContentSerializingTestCase<DataStr
         }
     }
 
+    public void testGetIndicesWithinMaxAgeRange() {
+        final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
+
+        final Metadata.Builder metadataBuilder = Metadata.builder();
+        final int numberOfBackingIndicesOlderThanMinAge = randomIntBetween(0, 10);
+        final int numberOfBackingIndicesWithinMinAge = randomIntBetween(0, 10);
+        final int numberOfShards = 1;
+        final List<Index> backingIndices = new ArrayList<>();
+        final String dataStreamName = "logs-es";
+        final List<Index> backingIndicesOlderThanMinAge = new ArrayList<>();
+        for (int i = 0; i < numberOfBackingIndicesOlderThanMinAge; i++) {
+            long creationDate = System.currentTimeMillis() - maxIndexAge.millis() * 2;
+            final IndexMetadata indexMetadata = createIndexMetadata(
+                DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), creationDate),
+                randomIndexWriteLoad(numberOfShards),
+                creationDate
+            );
+            backingIndices.add(indexMetadata.getIndex());
+            backingIndicesOlderThanMinAge.add(indexMetadata.getIndex());
+            metadataBuilder.put(indexMetadata, false);
+        }
+
+        final List<Index> backingIndicesWithinMinAge = new ArrayList<>();
+        for (int i = 0; i < numberOfBackingIndicesWithinMinAge; i++) {
+            final long createdAt = System.currentTimeMillis() - (maxIndexAge.getMillis() / 2);
+            final IndexMetadata indexMetadata = createIndexMetadata(
+                DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), createdAt),
+                randomIndexWriteLoad(numberOfShards),
+                createdAt
+            );
+            backingIndices.add(indexMetadata.getIndex());
+            backingIndicesWithinMinAge.add(indexMetadata.getIndex());
+            metadataBuilder.put(indexMetadata, false);
+        }
+
+        final String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size());
+        final IndexMetadata writeIndexMetadata = createIndexMetadata(writeIndexName, null, System.currentTimeMillis());
+        backingIndices.add(writeIndexMetadata.getIndex());
+        metadataBuilder.put(writeIndexMetadata, false);
+
+        final DataStream dataStream = new DataStream(
+            dataStreamName,
+            backingIndices,
+            backingIndices.size(),
+            Collections.emptyMap(),
+            false,
+            false,
+            false,
+            false,
+            randomBoolean() ? IndexMode.STANDARD : IndexMode.TIME_SERIES
+        );
+
+        metadataBuilder.put(dataStream);
+
+        final List<Index> indicesWithinMaxAgeRange = DataStream.getIndicesWithinMaxAgeRange(
+            dataStream,
+            metadataBuilder::getSafe,
+            maxIndexAge,
+            System::currentTimeMillis
+        );
+
+        final List<Index> expectedIndicesWithinMaxAgeRange = new ArrayList<>();
+        if (numberOfBackingIndicesOlderThanMinAge > 0) {
+            expectedIndicesWithinMaxAgeRange.add(backingIndicesOlderThanMinAge.get(backingIndicesOlderThanMinAge.size() - 1));
+        }
+        expectedIndicesWithinMaxAgeRange.addAll(backingIndicesWithinMinAge);
+        expectedIndicesWithinMaxAgeRange.add(writeIndexMetadata.getIndex());
+
+        assertThat(indicesWithinMaxAgeRange, is(equalTo(expectedIndicesWithinMaxAgeRange)));
+    }
+
+    private IndexWriteLoad randomIndexWriteLoad(int numberOfShards) {
+        IndexWriteLoad.Builder builder = IndexWriteLoad.builder(numberOfShards);
+        for (int shardId = 0; shardId < numberOfShards; shardId++) {
+            builder.withShardWriteLoad(shardId, randomDoubleBetween(0, 64, true), randomLongBetween(1, 10));
+        }
+        return builder.build();
+    }
+
+    private IndexMetadata createIndexMetadata(String indexName, IndexWriteLoad indexWriteLoad, long createdAt) {
+        return IndexMetadata.builder(indexName)
+            .settings(
+                Settings.builder()
+                    .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
+                    .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+                    .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
+                    .build()
+            )
+            .stats(indexWriteLoad == null ? null : new IndexMetadataStats(indexWriteLoad, 1, 1))
+            .creationDate(createdAt)
+            .build();
+    }
+
     private record DataStreamMetadata(Long creationTimeInMillis, Long rolloverTimeInMillis, Long originationTimeInMillis) {
         public static DataStreamMetadata dataStreamMetadata(Long creationTimeInMillis, Long rolloverTimeInMillis) {
             return new DataStreamMetadata(creationTimeInMillis, rolloverTimeInMillis, null);

+ 2 - 1
server/src/test/java/org/elasticsearch/cluster/metadata/MetadataDataStreamsServiceTests.java

@@ -356,7 +356,8 @@ public class MetadataDataStreamsServiceTests extends MapperServiceTestCase {
             original.getIndexMode(),
             original.getLifecycle(),
             original.isFailureStore(),
-            original.getFailureIndices()
+            original.getFailureIndices(),
+            original.getAutoShardingEvent()
         );
         var brokenState = ClusterState.builder(state).metadata(Metadata.builder(state.getMetadata()).put(broken).build()).build();
 

+ 10 - 1
test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java

@@ -74,6 +74,7 @@ import static org.elasticsearch.test.ESTestCase.generateRandomStringArray;
 import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomBoolean;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
+import static org.elasticsearch.test.ESTestCase.randomIntBetween;
 import static org.elasticsearch.test.ESTestCase.randomMap;
 import static org.elasticsearch.test.ESTestCase.randomMillisUpToYear9999;
 import static org.mockito.ArgumentMatchers.any;
@@ -136,7 +137,8 @@ public final class DataStreamTestHelper {
             null,
             lifecycle,
             failureStores.size() > 0,
-            failureStores
+            failureStores,
+            null
         );
     }
 
@@ -307,7 +309,14 @@ public final class DataStreamTestHelper {
             randomBoolean() ? DataStreamLifecycle.newBuilder().dataRetention(randomMillisUpToYear9999()).build() : null,
             failureStore,
             failureIndices,
+            randomBoolean(),
             randomBoolean()
+                ? new DataStreamAutoShardingEvent(
+                    indices.get(indices.size() - 1).getName(),
+                    randomIntBetween(1, 10),
+                    randomMillisUpToYear9999()
+                )
+                : null
         );
     }
 

+ 4 - 2
x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/TransportPutFollowAction.java

@@ -334,7 +334,8 @@ public final class TransportPutFollowAction extends TransportMasterNodeAction<Pu
                 remoteDataStream.getIndexMode(),
                 remoteDataStream.getLifecycle(),
                 remoteDataStream.isFailureStore(),
-                remoteDataStream.getFailureIndices()
+                remoteDataStream.getFailureIndices(),
+                remoteDataStream.getAutoShardingEvent()
             );
         } else {
             if (localDataStream.isReplicated() == false) {
@@ -387,7 +388,8 @@ public final class TransportPutFollowAction extends TransportMasterNodeAction<Pu
                 localDataStream.getIndexMode(),
                 localDataStream.getLifecycle(),
                 localDataStream.isFailureStore(),
-                localDataStream.getFailureIndices()
+                localDataStream.getFailureIndices(),
+                localDataStream.getAutoShardingEvent()
             );
         }
     }

+ 2 - 1
x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/action/DataStreamLifecycleUsageTransportActionIT.java

@@ -134,7 +134,8 @@ public class DataStreamLifecycleUsageTransportActionIT extends ESIntegTestCase {
                     IndexMode.STANDARD,
                     lifecycle,
                     false,
-                    List.of()
+                    List.of(),
+                    null
                 );
                 dataStreamMap.put(dataStream.getName(), dataStream);
             }

+ 7 - 20
x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java

@@ -76,7 +76,13 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
 
         clearPreviousForecast(dataStream, metadata);
 
-        final List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange = getIndicesWithinMaxAgeRange(dataStream, metadata).stream()
+        final List<IndexWriteLoad> indicesWriteLoadWithinMaxAgeRange = DataStream.getIndicesWithinMaxAgeRange(
+            dataStream,
+            metadata::getSafe,
+            maxIndexAge,
+            threadPool::absoluteTimeInMillis
+        )
+            .stream()
             .filter(index -> index.equals(dataStream.getWriteIndex()) == false)
             .map(metadata::getSafe)
             .map(IndexMetadata::getStats)
@@ -134,25 +140,6 @@ class LicensedWriteLoadForecaster implements WriteLoadForecaster {
         return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime);
     }
 
-    // Visible for testing
-    List<Index> getIndicesWithinMaxAgeRange(DataStream dataStream, Metadata.Builder metadata) {
-        final List<Index> dataStreamIndices = dataStream.getIndices();
-        final long currentTimeMillis = threadPool.absoluteTimeInMillis();
-        // Consider at least 1 index (including the write index) for cases where rollovers happen less often than maxIndexAge
-        int firstIndexWithinAgeRange = Math.max(dataStreamIndices.size() - 2, 0);
-        for (int i = 0; i < dataStreamIndices.size(); i++) {
-            Index index = dataStreamIndices.get(i);
-            final IndexMetadata indexMetadata = metadata.getSafe(index);
-            final long indexAge = currentTimeMillis - indexMetadata.getCreationDate();
-            if (indexAge < maxIndexAge.getMillis()) {
-                // We need to consider the previous index too in order to cover the entire max-index-age range.
-                firstIndexWithinAgeRange = i == 0 ? 0 : i - 1;
-                break;
-            }
-        }
-        return dataStreamIndices.subList(firstIndexWithinAgeRange, dataStreamIndices.size());
-    }
-
     @Override
     @SuppressForbidden(reason = "This is the only place where IndexMetadata#getForecastedWriteLoad is allowed to be used")
     public OptionalDouble getForecastedWriteLoad(IndexMetadata indexMetadata) {

+ 0 - 59
x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java

@@ -287,65 +287,6 @@ public class LicensedWriteLoadForecasterTests extends ESTestCase {
         }
     }
 
-    public void testGetIndicesWithinMaxAgeRange() {
-        final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
-        final LicensedWriteLoadForecaster writeLoadForecaster = new LicensedWriteLoadForecaster(() -> true, threadPool, maxIndexAge);
-
-        final Metadata.Builder metadataBuilder = Metadata.builder();
-        final int numberOfBackingIndicesOlderThanMinAge = randomIntBetween(0, 10);
-        final int numberOfBackingIndicesWithinMinAnge = randomIntBetween(0, 10);
-        final int numberOfShards = 1;
-        final List<Index> backingIndices = new ArrayList<>();
-        final String dataStreamName = "logs-es";
-        final List<Index> backingIndicesOlderThanMinAge = new ArrayList<>();
-        for (int i = 0; i < numberOfBackingIndicesOlderThanMinAge; i++) {
-            long creationDate = System.currentTimeMillis() - maxIndexAge.millis() * 2;
-            final IndexMetadata indexMetadata = createIndexMetadata(
-                DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), creationDate),
-                numberOfShards,
-                randomIndexWriteLoad(numberOfShards),
-                creationDate
-            );
-            backingIndices.add(indexMetadata.getIndex());
-            backingIndicesOlderThanMinAge.add(indexMetadata.getIndex());
-            metadataBuilder.put(indexMetadata, false);
-        }
-
-        final List<Index> backingIndicesWithinMinAge = new ArrayList<>();
-        for (int i = 0; i < numberOfBackingIndicesWithinMinAnge; i++) {
-            final long createdAt = System.currentTimeMillis() - (maxIndexAge.getMillis() / 2);
-            final IndexMetadata indexMetadata = createIndexMetadata(
-                DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size(), createdAt),
-                numberOfShards,
-                randomIndexWriteLoad(numberOfShards),
-                createdAt
-            );
-            backingIndices.add(indexMetadata.getIndex());
-            backingIndicesWithinMinAge.add(indexMetadata.getIndex());
-            metadataBuilder.put(indexMetadata, false);
-        }
-
-        final String writeIndexName = DataStream.getDefaultBackingIndexName(dataStreamName, backingIndices.size());
-        final IndexMetadata writeIndexMetadata = createIndexMetadata(writeIndexName, numberOfShards, null, System.currentTimeMillis());
-        backingIndices.add(writeIndexMetadata.getIndex());
-        metadataBuilder.put(writeIndexMetadata, false);
-
-        final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
-
-        metadataBuilder.put(dataStream);
-
-        final List<Index> indicesWithinMaxAgeRange = writeLoadForecaster.getIndicesWithinMaxAgeRange(dataStream, metadataBuilder);
-
-        final List<Index> expectedIndicesWithinMaxAgeRange = new ArrayList<>();
-        if (numberOfBackingIndicesOlderThanMinAge > 0) {
-            expectedIndicesWithinMaxAgeRange.add(backingIndicesOlderThanMinAge.get(backingIndicesOlderThanMinAge.size() - 1));
-        }
-        expectedIndicesWithinMaxAgeRange.addAll(backingIndicesWithinMinAge);
-        expectedIndicesWithinMaxAgeRange.add(writeIndexMetadata.getIndex());
-
-        assertThat(indicesWithinMaxAgeRange, is(equalTo(expectedIndicesWithinMaxAgeRange)));
-    }
-
     private IndexWriteLoad randomIndexWriteLoad(int numberOfShards) {
         IndexWriteLoad.Builder builder = IndexWriteLoad.builder(numberOfShards);
         for (int shardId = 0; shardId < numberOfShards; shardId++) {