Browse Source

Don't immediately scale down startups triggered by non-master nodes in inference adaptive allocations. (#125297)

Jan Kuipers 6 months ago
parent
commit
a16eaf1423

+ 32 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java

@@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
+import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.MachineLearning;
@@ -213,6 +214,7 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
     private volatile Scheduler.Cancellable cancellable;
     private final AtomicBoolean busy;
     private final long scaleToZeroAfterNoRequestsSeconds;
+    private final long scaleUpCooldownTimeMillis;
     private final Set<String> deploymentIdsWithInFlightScaleFromZeroRequests = new ConcurrentSkipListSet<>();
     private final Map<String, String> lastWarningMessages = new ConcurrentHashMap<>();
 
@@ -224,7 +226,17 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
         MeterRegistry meterRegistry,
         boolean isNlpEnabled
     ) {
-        this(threadPool, clusterService, client, inferenceAuditor, meterRegistry, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS);
+        this(
+            threadPool,
+            clusterService,
+            client,
+            inferenceAuditor,
+            meterRegistry,
+            isNlpEnabled,
+            DEFAULT_TIME_INTERVAL_SECONDS,
+            SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS,
+            SCALE_UP_COOLDOWN_TIME_MILLIS
+        );
     }
 
     // visible for testing
@@ -235,7 +247,9 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
         InferenceAuditor inferenceAuditor,
         MeterRegistry meterRegistry,
         boolean isNlpEnabled,
-        int timeIntervalSeconds
+        int timeIntervalSeconds,
+        long scaleToZeroAfterNoRequestsSeconds,
+        long scaleUpCooldownTimeMillis
     ) {
         this.threadPool = threadPool;
         this.clusterService = clusterService;
@@ -244,6 +258,8 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
         this.meterRegistry = meterRegistry;
         this.isNlpEnabled = isNlpEnabled;
         this.timeIntervalSeconds = timeIntervalSeconds;
+        this.scaleToZeroAfterNoRequestsSeconds = scaleToZeroAfterNoRequestsSeconds;
+        this.scaleUpCooldownTimeMillis = scaleUpCooldownTimeMillis;
 
         lastInferenceStatsByDeploymentAndNode = new HashMap<>();
         lastInferenceStatsTimestampMillis = null;
@@ -251,7 +267,6 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
         scalers = new HashMap<>();
         metrics = new Metrics();
         busy = new AtomicBoolean(false);
-        scaleToZeroAfterNoRequestsSeconds = SCALE_TO_ZERO_AFTER_NO_REQUESTS_TIME_SECONDS;
     }
 
     public synchronized void start() {
@@ -375,6 +390,9 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
 
         Map<String, Stats> recentStatsByDeployment = new HashMap<>();
         Map<String, Integer> numberOfAllocations = new HashMap<>();
+        // Check for recent scale ups in the deployment stats, because a different node may have
+        // caused a scale up when an inference request arrives and there were zero allocations.
+        Set<String> hasRecentObservedScaleUp = new HashSet<>();
 
         for (AssignmentStats assignmentStats : statsResponse.getStats().results()) {
             String deploymentId = assignmentStats.getDeploymentId();
@@ -401,6 +419,12 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
                         (key, value) -> value == null ? recentStats : value.add(recentStats)
                     );
                 }
+                if (nodeStats.getRoutingState() != null && nodeStats.getRoutingState().getState() == RoutingState.STARTING) {
+                    hasRecentObservedScaleUp.add(deploymentId);
+                }
+                if (nodeStats.getStartTime() != null && now < nodeStats.getStartTime().toEpochMilli() + scaleUpCooldownTimeMillis) {
+                    hasRecentObservedScaleUp.add(deploymentId);
+                }
             }
         }
 
@@ -416,9 +440,12 @@ public class AdaptiveAllocationsScalerService implements ClusterStateListener {
             Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale();
             if (newNumberOfAllocations != null) {
                 Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId);
+                // hasRecentScaleUp indicates whether this service has recently scaled up the deployment.
+                // hasRecentObservedScaleUp indicates whether a deployment recently has started,
+                // potentially triggered by another node.
+                boolean hasRecentScaleUp = lastScaleUpTimeMillis != null && now < lastScaleUpTimeMillis + scaleUpCooldownTimeMillis;
                 if (newNumberOfAllocations < numberOfAllocations.get(deploymentId)
-                    && lastScaleUpTimeMillis != null
-                    && now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) {
+                    && (hasRecentScaleUp || hasRecentObservedScaleUp.contains(deploymentId))) {
                     logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId);
                     continue;
                 }

+ 178 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java

@@ -36,8 +36,8 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 import org.junit.After;
 import org.junit.Before;
 
-import java.io.IOException;
 import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -114,7 +114,12 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         return clusterState;
     }
 
-    private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) {
+    private GetDeploymentStatsAction.Response getDeploymentStatsResponse(
+        int numAllocations,
+        int inferenceCount,
+        double latency,
+        boolean recentStartup
+    ) {
         return new GetDeploymentStatsAction.Response(
             List.of(),
             List.of(),
@@ -127,7 +132,7 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
                     new AdaptiveAllocationsSettings(true, null, null),
                     1024,
                     ByteSizeValue.ZERO,
-                    Instant.now(),
+                    Instant.now().minus(1, ChronoUnit.DAYS),
                     List.of(
                         AssignmentStats.NodeStats.forStartedState(
                             randomBoolean() ? DiscoveryNodeUtils.create("node_1") : null,
@@ -140,7 +145,7 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
                             0,
                             0,
                             Instant.now(),
-                            Instant.now(),
+                            recentStartup ? Instant.now() : Instant.now().minus(1, ChronoUnit.HOURS),
                             1,
                             numAllocations,
                             inferenceCount,
@@ -156,7 +161,7 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         );
     }
 
-    public void test() throws IOException {
+    public void test_scaleUp() {
         // Initialize the cluster with a deployment with 1 allocation.
         ClusterState clusterState = getClusterState(1);
         when(clusterService.state()).thenReturn(clusterState);
@@ -168,7 +173,9 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
             inferenceAuditor,
             meterRegistry,
             true,
-            1
+            1,
+            60,
+            60_000
         );
         service.start();
 
@@ -182,7 +189,7 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("unchecked")
             var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
-            listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0));
+            listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
             return Void.TYPE;
         }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
 
@@ -198,7 +205,7 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("unchecked")
             var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
-            listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0));
+            listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0, false));
             return Void.TYPE;
         }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
         doAnswer(invocationOnMock -> {
@@ -229,7 +236,137 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("unchecked")
             var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
-            listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0));
+            listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0, false));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(null);
+            return Void.TYPE;
+        }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
+
+        safeSleep(1000);
+
+        verify(client, times(1)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        verifyNoMoreInteractions(client, clusterService);
+
+        service.stop();
+    }
+
+    public void test_scaleDownToZero_whenNoRequests() {
+        // Initialize the cluster with a deployment with 1 allocation.
+        ClusterState clusterState = getClusterState(1);
+        when(clusterService.state()).thenReturn(clusterState);
+
+        AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
+            threadPool,
+            clusterService,
+            client,
+            inferenceAuditor,
+            meterRegistry,
+            true,
+            1,
+            1,
+            2_000
+        );
+        service.start();
+
+        verify(clusterService).state();
+        verify(clusterService).addListener(same(service));
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // First cycle: 1 inference request, so no need for scaling.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, false));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+
+        safeSleep(1200);
+
+        verify(client, times(1)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // Second cycle: 0 inference requests for 1 second, so scale down to 0 allocations.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(null);
+            return Void.TYPE;
+        }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
+
+        safeSleep(1000);
+
+        verify(client, times(2)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
+        updateRequest.setNumberOfAllocations(0);
+        updateRequest.setIsInternal(true);
+        verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
+        verifyNoMoreInteractions(client, clusterService);
+
+        service.stop();
+    }
+
+    public void test_noScaleDownToZero_whenRecentlyScaledUpByOtherNode() {
+        // Initialize the cluster with a deployment with 1 allocation.
+        ClusterState clusterState = getClusterState(1);
+        when(clusterService.state()).thenReturn(clusterState);
+
+        AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService(
+            threadPool,
+            clusterService,
+            client,
+            inferenceAuditor,
+            meterRegistry,
+            true,
+            1,
+            1,
+            2_000
+        );
+        service.start();
+
+        verify(clusterService).state();
+        verify(clusterService).addListener(same(service));
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // First cycle: 1 inference request, so no need for scaling.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0, true));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+
+        safeSleep(1200);
+
+        verify(client, times(1)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // Second cycle: 0 inference requests for 1 second, but a recent scale up by another node.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, true));
             return Void.TYPE;
         }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
         doAnswer(invocationOnMock -> {
@@ -244,6 +381,32 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
         verify(client, times(1)).threadPool();
         verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
         verifyNoMoreInteractions(client, clusterService);
+        reset(client, clusterService);
+
+        // Third cycle: 0 inference requests for 1 second and no recent scale up, so scale down to 0 allocations.
+        when(client.threadPool()).thenReturn(threadPool);
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<GetDeploymentStatsAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(getDeploymentStatsResponse(1, 0, 10.0, false));
+            return Void.TYPE;
+        }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any());
+        doAnswer(invocationOnMock -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<CreateTrainedModelAssignmentAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(null);
+            return Void.TYPE;
+        }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any());
+
+        safeSleep(1000);
+
+        verify(client, times(2)).threadPool();
+        verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any());
+        var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment");
+        updateRequest.setNumberOfAllocations(0);
+        updateRequest.setIsInternal(true);
+        verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any());
+        verifyNoMoreInteractions(client, clusterService);
 
         service.stop();
     }
@@ -256,7 +419,9 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
             inferenceAuditor,
             meterRegistry,
             true,
-            1
+            1,
+            60,
+            60_000
         );
 
         when(client.threadPool()).thenReturn(threadPool);
@@ -289,7 +454,9 @@ public class AdaptiveAllocationsScalerServiceTests extends ESTestCase {
             inferenceAuditor,
             meterRegistry,
             true,
-            1
+            1,
+            60,
+            60_000
         );
 
         var latch = new CountDownLatch(1);