Browse Source

Auto sharding uses the sum of shards write loads (#106785)

Data stream auto sharding uses the index write load to decide the
optimal number of shards. We read this previously from the indexing
stats output, using the `total/write_load` value however, this
proved to be wrong as that value takes into account the search shard
write load (which will always be 0).
Even more, the `total/write_load` value averages the write loads for
every shard so you can end up with indices that only have one primary
and one replica, with the primary shard having a write load of 1.7 and
the `total/write_load` reporting to be `0.8`.

For data stream auto sharding we're interested in the **total** index
write load, defined as the sum of all the shards write loads (yes we
can include the replica shard write loads in this sum as they're 0).

This PR changes the rollover write load computation to sum all the shard
write loads for the data stream write index, and in the
`DataStreamAutoShardingService` when looking at the historic write load
over the cooldown period to, again, sum the write loads of every shard
in the index metadata/stats.
Andrei Dan 1 year ago
parent
commit
9776f54928

+ 72 - 113
modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/DataStreamAutoshardingIT.java

@@ -129,22 +129,17 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
             for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
                 // the shard stats will yield a write load of 75.0 which will make the auto sharding service recommend an optimal number
                 // of 5 shards
-                shards.add(getShardStats(firstGenerationMeta, i, 75, assignedShardNodeId));
-            }
-
-            for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                MockTransportService.getInstance(node.getName())
-                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                        TransportIndicesStatsAction instance = internalCluster().getInstance(
-                            TransportIndicesStatsAction.class,
-                            node.getName()
-                        );
-                        channel.sendResponse(
-                            instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
-                        );
-                    });
+                shards.add(
+                    getShardStats(
+                        firstGenerationMeta,
+                        i,
+                        (long) Math.ceil(75.0 / firstGenerationMeta.getNumberOfShards()),
+                        assignedShardNodeId
+                    )
+                );
             }
 
+            mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
             assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());
 
             ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
@@ -180,21 +175,16 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
             for (int i = 0; i < secondGenerationMeta.getNumberOfShards(); i++) {
                 // the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal number of
                 // 7 shards
-                shards.add(getShardStats(secondGenerationMeta, i, 100, assignedShardNodeId));
-            }
-
-            for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                MockTransportService.getInstance(node.getName())
-                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                        TransportIndicesStatsAction instance = internalCluster().getInstance(
-                            TransportIndicesStatsAction.class,
-                            node.getName()
-                        );
-                        channel.sendResponse(
-                            instance.new NodeResponse(node.getId(), secondGenerationMeta.getNumberOfShards(), shards, List.of())
-                        );
-                    });
+                shards.add(
+                    getShardStats(
+                        secondGenerationMeta,
+                        i,
+                        (long) Math.ceil(100.0 / secondGenerationMeta.getNumberOfShards()),
+                        assignedShardNodeId
+                    )
+                );
             }
+            mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenerationMeta, shards);
 
             RolloverResponse response = indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet();
             assertAcked(response);
@@ -232,21 +222,11 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
                 for (int i = 0; i < thirdGenIndex.getNumberOfShards(); i++) {
                     // the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal
                     // number of 7 shards
-                    shards.add(getShardStats(thirdGenIndex, i, 100, assignedShardNodeId));
-                }
-
-                for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                    MockTransportService.getInstance(node.getName())
-                        .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                            TransportIndicesStatsAction instance = internalCluster().getInstance(
-                                TransportIndicesStatsAction.class,
-                                node.getName()
-                            );
-                            channel.sendResponse(
-                                instance.new NodeResponse(node.getId(), thirdGenIndex.getNumberOfShards(), shards, List.of())
-                            );
-                        });
+                    shards.add(
+                        getShardStats(thirdGenIndex, i, (long) Math.ceil(100.0 / thirdGenIndex.getNumberOfShards()), assignedShardNodeId)
+                    );
                 }
+                mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, thirdGenIndex, shards);
 
                 RolloverRequest request = new RolloverRequest(dataStreamName, null);
                 request.setConditions(RolloverConditions.newBuilder().addMaxIndexDocsCondition(1_000_000L).build());
@@ -309,22 +289,10 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
             for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
                 // the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an optimal number
                 // of 2 shards
-                shards.add(getShardStats(firstGenerationMeta, i, 2, assignedShardNodeId));
-            }
-
-            for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                MockTransportService.getInstance(node.getName())
-                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                        TransportIndicesStatsAction instance = internalCluster().getInstance(
-                            TransportIndicesStatsAction.class,
-                            node.getName()
-                        );
-                        channel.sendResponse(
-                            instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
-                        );
-                    });
+                shards.add(getShardStats(firstGenerationMeta, i, i < 2 ? 1 : 0, assignedShardNodeId));
             }
 
+            mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
             assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());
 
             ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
@@ -356,23 +324,11 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
                     .index(dataStreamBeforeRollover.getIndices().get(1));
                 List<ShardStats> shards = new ArrayList<>(secondGenerationIndex.getNumberOfShards());
                 for (int i = 0; i < secondGenerationIndex.getNumberOfShards(); i++) {
-                    // the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an optimal
-                    // number of 2 shards
-                    shards.add(getShardStats(secondGenerationIndex, i, 2, assignedShardNodeId));
-                }
-
-                for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                    MockTransportService.getInstance(node.getName())
-                        .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                            TransportIndicesStatsAction instance = internalCluster().getInstance(
-                                TransportIndicesStatsAction.class,
-                                node.getName()
-                            );
-                            channel.sendResponse(
-                                instance.new NodeResponse(node.getId(), secondGenerationIndex.getNumberOfShards(), shards, List.of())
-                            );
-                        });
+                    // the shard stats will yield a write load of 2.0 which will make the auto sharding service recommend an
+                    // optimal number of 2 shards
+                    shards.add(getShardStats(secondGenerationIndex, i, i < 2 ? 1 : 0, assignedShardNodeId));
                 }
+                mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenerationIndex, shards);
 
                 RolloverRequest request = new RolloverRequest(dataStreamName, null);
                 // adding condition that does NOT match
@@ -438,6 +394,11 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
             IndexMetadata firstGenerationMeta = clusterStateBeforeRollover.getMetadata().index(firstGenerationIndex);
 
             List<ShardStats> shards = new ArrayList<>(firstGenerationMeta.getNumberOfShards());
+            String assignedShardNodeId = clusterStateBeforeRollover.routingTable()
+                .index(dataStreamBeforeRollover.getWriteIndex())
+                .shard(0)
+                .primaryShard()
+                .currentNodeId();
             for (int i = 0; i < firstGenerationMeta.getNumberOfShards(); i++) {
                 // the shard stats will yield a write load of 75.0 which will make the auto sharding service recommend an optimal number
                 // of 5 shards
@@ -445,29 +406,13 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
                     getShardStats(
                         firstGenerationMeta,
                         i,
-                        75,
-                        clusterStateBeforeRollover.routingTable()
-                            .index(dataStreamBeforeRollover.getWriteIndex())
-                            .shard(0)
-                            .primaryShard()
-                            .currentNodeId()
+                        (long) Math.ceil(75.0 / firstGenerationMeta.getNumberOfShards()),
+                        assignedShardNodeId
                     )
                 );
             }
 
-            for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                MockTransportService.getInstance(node.getName())
-                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                        TransportIndicesStatsAction instance = internalCluster().getInstance(
-                            TransportIndicesStatsAction.class,
-                            node.getName()
-                        );
-                        channel.sendResponse(
-                            instance.new NodeResponse(node.getId(), firstGenerationMeta.getNumberOfShards(), shards, List.of())
-                        );
-                    });
-            }
-
+            mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, firstGenerationMeta, shards);
             assertAcked(indicesAdmin().rolloverIndex(new RolloverRequest(dataStreamName, null)).actionGet());
 
             ClusterState clusterStateAfterRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
@@ -491,37 +436,22 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
                 ClusterState clusterStateBeforeRollover = internalCluster().getCurrentMasterNodeInstance(ClusterService.class).state();
                 DataStream dataStreamBeforeRollover = clusterStateBeforeRollover.getMetadata().dataStreams().get(dataStreamName);
 
+                String assignedShardNodeId = clusterStateBeforeRollover.routingTable()
+                    .index(dataStreamBeforeRollover.getWriteIndex())
+                    .shard(0)
+                    .primaryShard()
+                    .currentNodeId();
                 IndexMetadata secondGenIndex = clusterStateBeforeRollover.metadata().index(dataStreamBeforeRollover.getIndices().get(1));
                 List<ShardStats> shards = new ArrayList<>(secondGenIndex.getNumberOfShards());
                 for (int i = 0; i < secondGenIndex.getNumberOfShards(); i++) {
                     // the shard stats will yield a write load of 100.0 which will make the auto sharding service recommend an optimal
                     // number of 7 shards
                     shards.add(
-                        getShardStats(
-                            secondGenIndex,
-                            i,
-                            100,
-                            clusterStateBeforeRollover.routingTable()
-                                .index(dataStreamBeforeRollover.getWriteIndex())
-                                .shard(i)
-                                .primaryShard()
-                                .currentNodeId()
-                        )
+                        getShardStats(secondGenIndex, i, (long) Math.ceil(100.0 / secondGenIndex.getNumberOfShards()), assignedShardNodeId)
                     );
                 }
 
-                for (DiscoveryNode node : clusterStateBeforeRollover.nodes().getAllNodes()) {
-                    MockTransportService.getInstance(node.getName())
-                        .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
-                            TransportIndicesStatsAction instance = internalCluster().getInstance(
-                                TransportIndicesStatsAction.class,
-                                node.getName()
-                            );
-                            channel.sendResponse(
-                                instance.new NodeResponse(node.getId(), secondGenIndex.getNumberOfShards(), shards, List.of())
-                            );
-                        });
-                }
+                mockStatsForIndex(clusterStateBeforeRollover, assignedShardNodeId, secondGenIndex, shards);
 
                 RolloverRequest request = new RolloverRequest(dataStreamName, null);
                 request.lazy(true);
@@ -612,4 +542,33 @@ public class DataStreamAutoshardingIT extends ESIntegTestCase {
         }
     }
 
+    private static void mockStatsForIndex(
+        ClusterState clusterState,
+        String assignedShardNodeId,
+        IndexMetadata indexMetadata,
+        List<ShardStats> shards
+    ) {
+        for (DiscoveryNode node : clusterState.nodes().getAllNodes()) {
+            // one node returns the stats for all our shards, the other nodes don't return any stats
+            if (node.getId().equals(assignedShardNodeId)) {
+                MockTransportService.getInstance(node.getName())
+                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
+                        TransportIndicesStatsAction instance = internalCluster().getInstance(
+                            TransportIndicesStatsAction.class,
+                            node.getName()
+                        );
+                        channel.sendResponse(instance.new NodeResponse(node.getId(), indexMetadata.getNumberOfShards(), shards, List.of()));
+                    });
+            } else {
+                MockTransportService.getInstance(node.getName())
+                    .addRequestHandlingBehavior(IndicesStatsAction.NAME + "[n]", (handler, request, channel, task) -> {
+                        TransportIndicesStatsAction instance = internalCluster().getInstance(
+                            TransportIndicesStatsAction.class,
+                            node.getName()
+                        );
+                        channel.sendResponse(instance.new NodeResponse(node.getId(), 0, List.of(), List.of()));
+                    });
+            }
+        }
+    }
 }

+ 10 - 5
server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java

@@ -250,11 +250,16 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
                     final Optional<IndexStats> indexStats = Optional.ofNullable(statsResponse)
                         .map(stats -> stats.getIndex(dataStream.getWriteIndex().getName()));
 
-                    Double writeLoad = indexStats.map(stats -> stats.getTotal().getIndexing())
-                        .map(indexing -> indexing.getTotal().getWriteLoad())
-                        .orElse(null);
-
-                    rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, writeLoad);
+                    Double indexWriteLoad = indexStats.map(
+                        stats -> Arrays.stream(stats.getShards())
+                            .filter(shardStats -> shardStats.getStats().indexing != null)
+                            // only take primaries into account as in stateful the replicas also index data
+                            .filter(shardStats -> shardStats.getShardRouting().primary())
+                            .map(shardStats -> shardStats.getStats().indexing.getTotal().getWriteLoad())
+                            .reduce(0.0, Double::sum)
+                    ).orElse(null);
+
+                    rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, indexWriteLoad);
                     logger.debug("auto sharding result for data stream [{}] is [{}]", dataStream.getName(), rolloverAutoSharding);
 
                     // if auto sharding recommends increasing the number of shards we want to trigger a rollover even if there are no

+ 2 - 19
server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java

@@ -29,7 +29,6 @@ import org.elasticsearch.index.Index;
 import java.util.List;
 import java.util.Objects;
 import java.util.OptionalDouble;
-import java.util.OptionalLong;
 import java.util.function.Function;
 import java.util.function.LongSupplier;
 
@@ -381,27 +380,11 @@ public class DataStreamAutoShardingService {
         // assume the current write index load is the highest observed and look back to find the actual maximum
         double maxIndexLoadWithinCoolingPeriod = writeIndexLoad;
         for (IndexWriteLoad writeLoad : writeLoadsWithinCoolingPeriod) {
-            // the IndexWriteLoad stores _for each shard_ a shard average write load ( calculated using : shard indexing time / shard
-            // uptime ) and its corresponding shard uptime
-            //
-            // to reconstruct the average _index_ write load we recalculate the shard indexing time by multiplying the shard write load
-            // to its uptime, and then, having the indexing time and uptime for each shard we calculate the average _index_ write load using
-            // (indexingTime_shard0 + indexingTime_shard1) / (uptime_shard0 + uptime_shard1)
-            // as {@link org.elasticsearch.index.shard.IndexingStats#add} does
-            double totalShardIndexingTime = 0;
-            long totalShardUptime = 0;
+            double totalIndexLoad = 0;
             for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
                 final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
-                final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
-                if (writeLoadForShard.isPresent()) {
-                    assert uptimeInMillisForShard.isPresent();
-                    double shardIndexingTime = writeLoadForShard.getAsDouble() * uptimeInMillisForShard.getAsLong();
-                    long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
-                    totalShardIndexingTime += shardIndexingTime;
-                    totalShardUptime += shardUptimeInMillis;
-                }
+                totalIndexLoad += writeLoadForShard.orElse(0);
             }
-            double totalIndexLoad = totalShardUptime == 0 ? 0.0 : (totalShardIndexingTime / totalShardUptime);
             if (totalIndexLoad > maxIndexLoadWithinCoolingPeriod) {
                 maxIndexLoadWithinCoolingPeriod = totalIndexLoad;
             }

+ 5 - 6
server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java

@@ -51,9 +51,7 @@ import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType
 import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.INCREASE_SHARDS;
 import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.NO_CHANGE_REQUIRED;
 import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
-import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.lessThan;
 
 public class DataStreamAutoShardingServiceTests extends ESTestCase {
 
@@ -646,10 +644,10 @@ public class DataStreamAutoShardingServiceTests extends ESTestCase {
             () -> now
         );
         // to cover the entire cooldown period, the last index before the cooling period is taken into account
-        assertThat(maxIndexLoadWithinCoolingPeriod, is(lastIndexBeforeCoolingPeriodHasLowWriteLoad ? 5.0 : 999.0));
+        assertThat(maxIndexLoadWithinCoolingPeriod, is(lastIndexBeforeCoolingPeriodHasLowWriteLoad ? 15.0 : 999.0));
     }
 
-    public void testIndexLoadWithinCoolingPeriodIsShardLoadsAvg() {
+    public void testIndexLoadWithinCoolingPeriodIsSumOfShardsLoads() {
         final TimeValue coolingPeriod = TimeValue.timeValueDays(3);
 
         final Metadata.Builder metadataBuilder = Metadata.builder();
@@ -658,6 +656,8 @@ public class DataStreamAutoShardingServiceTests extends ESTestCase {
         final String dataStreamName = "logs";
         long now = System.currentTimeMillis();
 
+        double expectedIsSumOfShardLoads = 0.5 + 3.0 + 0.3333;
+
         for (int i = 0; i < numberOfBackingIndicesWithinCoolingPeriod; i++) {
             final long createdAt = now - (coolingPeriod.getMillis() / 2);
             IndexMetadata indexMetadata;
@@ -705,8 +705,7 @@ public class DataStreamAutoShardingServiceTests extends ESTestCase {
             coolingPeriod,
             () -> now
         );
-        assertThat(maxIndexLoadWithinCoolingPeriod, is(greaterThan(0.499)));
-        assertThat(maxIndexLoadWithinCoolingPeriod, is(lessThan(0.5)));
+        assertThat(maxIndexLoadWithinCoolingPeriod, is(expectedIsSumOfShardLoads));
     }
 
     public void testAutoShardingResultValidation() {