Browse Source

[ML] Prevent resource over-subscription in model allocation planner (#100392)

When rescaling or changing the number of nodes, the allocator tries to ensure that previously allocated model deployments always have at least one allocation available. However, we did not ensure that the Resource Tracker would always have a non-negative amount of memory remaining on each node once the memory of the previously allocated deployments was taken into account. In a particular scenario, when we downsize the nodes, this can lead to the over-allocation of resources.

This PR adds the missing assertion and hardens the unit tests to ensure that we never use more resources than are available.
Valeriy Khakhutskyy 2 years ago
parent
commit
8ff7dee0f0

+ 5 - 0
docs/changelog/100392.yaml

@@ -0,0 +1,5 @@
+pr: 100392
+summary: Prevent resource over-subscription in model allocation planner
+area: Machine Learning
+type: bug
+issues: []

+ 6 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java

@@ -80,10 +80,12 @@ abstract class AbstractPreserveAllocations {
             for (Node n : nodes) {
                 int allocations = assignmentsByModelNodeIdPair.getOrDefault(Tuple.tuple(m.id(), n.id()), 0);
                 if (m.currentAllocationsByNodeId().containsKey(n.id())) {
-                    allocations += addPreservedAllocations(n, m);
-                    // As the node has all its available memory we need to manually account memory of models with
-                    // current allocations.
-                    mergedPlanBuilder.accountMemory(m, n);
+                    if (mergedPlanBuilder.getRemainingMemory(n) >= m.memoryBytes()) {
+                        allocations += addPreservedAllocations(n, m);
+                        // As the node has all its available memory we need to manually account memory of models with
+                        // current allocations.
+                        mergedPlanBuilder.accountMemory(m, n);
+                    }
                 }
                 if (allocations > 0) {
                     mergedPlanBuilder.assignModelToNode(m, n, allocations);

+ 3 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

@@ -349,6 +349,9 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
 
         public void accountMemory(Deployment m, Node n) {
             remainingNodeMemory.computeIfPresent(n, (k, v) -> v - m.memoryBytes());
+            if (remainingNodeMemory.get(n) < 0) {
+                throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]");
+            }
         }
 
         public AssignmentPlan build() {

+ 41 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java

@@ -361,6 +361,9 @@ public class AssignmentPlannerTests extends ESTestCase {
         Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0);
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2))
             .computePlan();
+        assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory("n_2"), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory("n_3"), greaterThanOrEqualTo(0L));
         {
             assertThat(assignmentPlan.assignments(deployment1).isPresent(), is(true));
             Map<Node, Integer> assignments = assignmentPlan.assignments(deployment1).get();
@@ -403,6 +406,8 @@ public class AssignmentPlannerTests extends ESTestCase {
         assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2"));
         assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
         assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
+        assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory("n_2"), greaterThanOrEqualTo(0L));
     }
 
     public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() {
@@ -419,7 +424,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2);
         Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0);
-        Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
+        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
         Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0);
 
         // First only start m_1
@@ -458,13 +463,21 @@ public class AssignmentPlannerTests extends ESTestCase {
 
         // First, one node goes away.
         assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan();
+        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
 
         // Then, a node double in memory size is added.
         assignmentPlan = new AssignmentPlanner(List.of(node1, node3), createModelsFromPlan(assignmentPlan)).computePlan();
+        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node3.id()), greaterThanOrEqualTo(0L));
         // And another.
         assignmentPlan = new AssignmentPlanner(List.of(node1, node3, node4), createModelsFromPlan(assignmentPlan)).computePlan();
+        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node3.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node4.id()), greaterThanOrEqualTo(0L));
         // Finally, the remaining smaller node is removed
         assignmentPlan = new AssignmentPlanner(List.of(node3, node4), createModelsFromPlan(assignmentPlan)).computePlan();
+        assertThat(assignmentPlan.getRemainingNodeMemory(node3.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node4.id()), greaterThanOrEqualTo(0L));
 
         indexedBasedPlan = convertToIdIndexed(assignmentPlan);
         assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
@@ -477,6 +490,33 @@ public class AssignmentPlannerTests extends ESTestCase {
         assertThat(assignmentPlan.getRemainingNodeCores("n_2"), equalTo(0));
     }
 
+    public void testGivenClusterResize_ShouldRemoveAllocatedModels() {
+        // Ensure that plan is removing previously allocated models if not enough memory is available
+        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0);
+        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
+        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0);
+
+        // Create a plan where all deployments are assigned at least once
+        AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
+            .computePlan();
+        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(assignmentPlan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
+        assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1)));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node2.id()), greaterThanOrEqualTo(0L));
+
+        // Now the cluster starts getting resized. Ensure that resources are not over-allocated.
+        assignmentPlan = new AssignmentPlanner(List.of(node1), createModelsFromPlan(assignmentPlan)).computePlan();
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+        assertThat(assignmentPlan.getRemainingNodeMemory(node1.id()), greaterThanOrEqualTo(0L));
+        assertThat(assignmentPlan.getRemainingNodeCores(node1.id()), greaterThanOrEqualTo(0));
+
+    }
+
     public static List<Deployment> createModelsFromPlan(AssignmentPlan plan) {
         List<Deployment> deployments = new ArrayList<>();
         for (Deployment m : plan.models()) {