Browse Source

[8.19] [ML] Make AssignmentPlan to consider only assigned allocations (#132170) (#132274)

* [ML] Make AssignmentPlan to consider only assigned allocations (#132170)

A follow-up to #131990. This PR ensures that only assigned allocations and not current allocations are used in the memory requirements calculation in AssignmentPlan.

This change led to the simplification of the code in ZoneAwareAssignmentPlanner and TrainedModelRebalancer.

This PR also improves readability by adding comments, code documentation, renaming variables, and making the flow of if statements more straightforward.

Marking is a non-issue since the bug was already documented in #131990.

(cherry picked from commit 80c47f3aaa7c56da92b820c04652079bbc2115d6)

* Fix backport errors

* Fix unit test
Valeriy Khakhutskyy 1 month ago
parent
commit
e546bb8c2f

+ 9 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

@@ -134,23 +134,15 @@ class TrainedModelAssignmentRebalancer {
         return finalPlanBuilder.build();
     }
 
-    private static void copyAssignments(
-        AssignmentPlan source,
-        AssignmentPlan.Builder dest,
-        Map<String, AssignmentPlan.Node> originalNodeById
-    ) {
-        for (AssignmentPlan.Deployment m : source.models()) {
-            Map<AssignmentPlan.Node, Integer> nodeAssignments = source.assignments(m).orElse(Map.of());
-            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
-                AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
-                dest.assignModelToNode(m, originalNode, assignment.getValue());
-                if (m.currentAllocationsByNodeId().containsKey(originalNode.id())) {
-                    // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
-                    // As the node has all its available memory we need to manually account memory of models with
-                    // current allocations.
-                    long requiredMemory = m.estimateMemoryUsageBytes(m.currentAllocationsByNodeId().get(originalNode.id()));
-                    dest.accountMemory(m, originalNode, requiredMemory);
-                }
+    /**
+     *  Transfers assignments from the source AssignmentPlan to the destination AssignmentPlan.Builder.
+     */
+    static void copyAssignments(AssignmentPlan source, AssignmentPlan.Builder dest, Map<String, AssignmentPlan.Node> originalNodeById) {
+        for (AssignmentPlan.Deployment deployment : source.models()) {
+            Map<AssignmentPlan.Node, Integer> sourceNodeAssignments = source.assignments(deployment).orElse(Map.of());
+            for (Map.Entry<AssignmentPlan.Node, Integer> sourceAssignment : sourceNodeAssignments.entrySet()) {
+                AssignmentPlan.Node node = originalNodeById.get(sourceAssignment.getKey().id());
+                dest.assignModelToNode(deployment, node, sourceAssignment.getValue());
             }
         }
     }

+ 28 - 22
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

@@ -207,23 +207,43 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
         return Comparator.comparing(AssignmentPlan::computeQuality).compare(this, o);
     }
 
+    /**
+     * Checks whether all deployments in the current {@link AssignmentPlan} have at least as many
+     * allocations as currently assigned.
+     */
     public boolean satisfiesCurrentAssignments() {
         return models().stream().allMatch(this::isSatisfyingCurrentAssignmentsForModel);
     }
 
+    /**
+     * Checks whether the current assignments for a given {@link Deployment} meet its allocation requirements.
+     *
+     * It ensures that the total number of allocations assigned to the deployment across all nodes is
+     * at least equal to the deployment's current assigned allocations.
+     */
     private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
         if (m.currentAllocationsByNodeId().isEmpty()) {
             return true;
         }
         Map<Node, Integer> nodeAssignments = assignments.get(m);
-        int currentAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
-        return currentAllocations >= m.getCurrentAssignedAllocations();
+        int inPlanAssignedAllocations = nodeAssignments.values().stream().mapToInt(Integer::intValue).sum();
+        return inPlanAssignedAllocations >= m.getCurrentAssignedAllocations();
     }
 
-    public boolean satisfiesAllocations(Deployment m) {
-        return remainingModelAllocations.getOrDefault(m, 0) == 0;
+    /**
+     * Checks if the current assignments satisfy the deployment's allocation requirements.
+     * @param deployment the deployment to check
+     * @return true if the current assignments satisfy the deployment's allocation requirements, false otherwise
+     */
+    public boolean satisfiesAllocations(Deployment deployment) {
+        return remainingModelAllocations.getOrDefault(deployment, 0) == 0;
     }
 
+    /**
+     * Checks if the current assignments satisfy all deployments' allocation requirements. This means that
+     * each deployment has no remaining allocations left to assign.
+     * @return true if the current assignments satisfy the deployments' allocation requirements, false otherwise
+     */
     public boolean satisfiesAllModels() {
         return models().stream().allMatch(this::satisfiesAllocations);
     }
@@ -408,8 +428,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
             if (allocations <= 0) {
                 return this;
             }
-            if (/*isAlreadyAssigned(deployment, node) == false
-                &&*/ requiredMemory > remainingNodeMemory.get(node)) {
+            if (requiredMemory > remainingNodeMemory.get(node)) {
                 throw new IllegalArgumentException(
                     "not enough memory on node ["
                         + node.id()
@@ -434,7 +453,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
                 );
             }
 
-            assignments.get(deployment).compute(node, (n, remAllocations) -> remAllocations + allocations);
+            assignments.get(deployment).compute(node, (n, assignedAllocations) -> assignedAllocations + allocations);
             accountMemory(deployment, node, requiredMemory);
 
             if (deployment.priority == Priority.NORMAL) {
@@ -445,23 +464,10 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
         }
 
         private int getAssignedAllocations(Deployment deployment, Node node) {
-            int currentAllocations = getCurrentAllocations(deployment, node);
-            int assignmentAllocations = assignments.get(deployment).get(node);
-            return currentAllocations + assignmentAllocations;
-        }
-
-        private static int getCurrentAllocations(Deployment m, Node n) {
-            return m.currentAllocationsByNodeId.containsKey(n.id()) ? m.currentAllocationsByNodeId.get(n.id()) : 0;
-        }
-
-        public void accountMemory(Deployment m, Node n) {
-            // TODO (#101612) remove or refactor unused method
-            long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
-            accountMemory(m, n, requiredMemory);
+            return assignments.get(deployment).get(node);
         }
 
-        public void accountMemory(Deployment m, Node n, long requiredMemory) {
-            // TODO (#101612) computation of required memory should be done internally
+        void accountMemory(Deployment m, Node n, long requiredMemory) {
             remainingNodeMemory.computeIfPresent(n, (k, v) -> v - requiredMemory);
             if (remainingNodeMemory.containsKey(n) && remainingNodeMemory.get(n) < 0) {
                 throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]");

+ 31 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

@@ -57,13 +57,26 @@ public class AssignmentPlanner {
         return computePlan(true);
     }
 
-    public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
+    /**
+     * Computes an {@link AssignmentPlan} for the given nodes and deployments.
+     * If {@code tryAssigningAllPreviouslyAllocatedModels} is true, then the plan will
+     * attempt to assign at least one allocation to previously assigned models.
+     * Otherwise, it will only ensure that deployments assigned to existing nodes will preserve at least one allocation
+     *
+     * @param tryAssigningAllPreviouslyAllocatedModels whether to do the best effort assigning previously assigned models somewhere
+     *                                                 with at least one allocation
+     * @return the computed assignment plan
+     */
+    public AssignmentPlan computePlan(boolean tryAssigningAllPreviouslyAllocatedModels) {
         logger.debug(() -> format("Computing plan for nodes = %s; deployments = %s", nodes, deployments));
 
         AssignmentPlan bestPlan;
         AssignmentPlan planSatisfyingCurrentAssignments = solveSatisfyingCurrentAssignments();
         logger.debug(() -> "Plan satisfying current assignments =\n" + planSatisfyingCurrentAssignments.prettyPrint());
-        if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() == false && tryAssigningPreviouslyAssignedModels) {
+        if (planSatisfyingCurrentAssignments.arePreviouslyAssignedModelsAssigned() || tryAssigningAllPreviouslyAllocatedModels == false) {
+            bestPlan = planSatisfyingCurrentAssignments;
+        } else {
+            // try to reuse any deployment that would otherwise drop to zero allocations
             AssignmentPlan planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated =
                 solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated();
             logger.debug(
@@ -79,8 +92,6 @@ public class AssignmentPlanner {
                             ? planSatisfyingCurrentAssignments
                             : planAllocatingAtLeastOnceModelsThatWerePreviouslyAllocated;
             }
-        } else {
-            bestPlan = planSatisfyingCurrentAssignments;
         }
 
         logger.debug(() -> "Best plan =\n" + bestPlan.prettyPrint());
@@ -88,19 +99,30 @@ public class AssignmentPlanner {
         return bestPlan;
     }
 
+    /**
+     * Computes the best assignment plan from two strategies:
+     * 1. Preserving one allocation on current assignments, which is the most flexible
+     * 2. Preserving all allocations on current assignments, which is more conservative
+     * @return the best assignment plan
+     */
     private AssignmentPlan solveSatisfyingCurrentAssignments() {
         AssignmentPlan bestPlan;
         // First solve preserving one allocation per assignment because that is most flexible
         AssignmentPlan planKeepingOneAllocationOnCurrentAssignments = solveKeepingOneAllocationOnCurrentAssignments();
-        if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
+
+        if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels()) {
+            // If the plan satisfies all models, then we can use it as is
+            bestPlan = planKeepingOneAllocationOnCurrentAssignments;
+        } else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesCurrentAssignments() == false) {
+            // If in the new assignment plan, some deployments have fewer allocations than in the current assignments,
+            // try explicitly preserving all allocations on current assignments.
             bestPlan = solvePreservingAllAllocationsOnCurrentAssignments();
-        } else if (planKeepingOneAllocationOnCurrentAssignments.satisfiesAllModels() == false) {
+        } else {
+            // Choose the best strategy according to {@link AssignmentPlan#computeQuality(AssignmentPlan)}
             AssignmentPlan planKeepingAllAllocationsOnCurrentAssignments = solvePreservingAllAllocationsOnCurrentAssignments();
             bestPlan = planKeepingAllAllocationsOnCurrentAssignments.compareTo(planKeepingOneAllocationOnCurrentAssignments) >= 0
                 ? planKeepingAllAllocationsOnCurrentAssignments
                 : planKeepingOneAllocationOnCurrentAssignments;
-        } else {
-            bestPlan = planKeepingOneAllocationOnCurrentAssignments;
         }
         return bestPlan;
     }
@@ -116,7 +138,7 @@ public class AssignmentPlanner {
                     1,
                     m.threadsPerAllocation(),
                     // don't rely on the current allocation
-                    new HashMap<>(),
+                    Map.of(),
                     m.maxAssignedAllocations(),
                     m.getAdaptiveAllocationsSettings(),
                     m.perDeploymentMemoryBytes(),

+ 21 - 18
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

@@ -167,36 +167,39 @@ public class ZoneAwareAssignmentPlanner {
         List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
         AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
         plan = preserveAllAllocations.mergePreservedAllocations(plan);
-        return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
+        return swapOriginalDeploymentsInPlan(plan, allNodes, modelsAccountingPlans);
     }
 
-    private AssignmentPlan swapOriginalModelsInPlan(
+    /**
+     * The method is responsible for reconstructing an AssignmentPlan
+     * by replacing the deployments and nodes in the given plan with their original counterparts.
+     * This ensures that the final plan uses the original objects while preserving the assignments
+     * and memory accounting from the input plan.
+     *
+     * @param plan AssignmentPlan to reconstruct with original models and nodes
+     * @param allNodes List of all nodes in the system, used to find original nodes
+     * @param planDeployments List of deployments in the plan, not the original deployments
+     * @return final plan with original models and nodes swapped in
+     */
+    private AssignmentPlan swapOriginalDeploymentsInPlan(
         AssignmentPlan plan,
         List<Node> allNodes,
         List<AssignmentPlan.Deployment> planDeployments
     ) {
-        final Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
+        final Map<String, AssignmentPlan.Deployment> originalDeploymentsById = deployments.stream()
             .collect(Collectors.toMap(AssignmentPlan.Deployment::id, Function.identity()));
         final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
-        AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
-        for (AssignmentPlan.Deployment m : planDeployments) {
-            AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.id());
-            Map<Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
+        AssignmentPlan.Builder finalPlanBuilder = AssignmentPlan.builder(allNodes, deployments);
+
+        for (AssignmentPlan.Deployment planDeployment : planDeployments) {
+            AssignmentPlan.Deployment originalDeployment = originalDeploymentsById.get(planDeployment.id());
+            Map<Node, Integer> nodeAssignments = plan.assignments(planDeployment).orElse(Map.of());
             for (Map.Entry<Node, Integer> assignment : nodeAssignments.entrySet()) {
                 Node originalNode = originalNodeById.get(assignment.getKey().id());
-                planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
-                if (originalDeployment.currentAllocationsByNodeId().containsKey(originalNode.id())) {
-                    // TODO (#101612) requiredMemory should be calculated by the AssignmentPlan.Builder
-                    // As the node has all its available memory we need to manually account memory of models with
-                    // current allocations.
-                    long requiredMemory = originalDeployment.estimateMemoryUsageBytes(
-                        originalDeployment.currentAllocationsByNodeId().get(originalNode.id())
-                    );
-                    planBuilder.accountMemory(m, originalNode, requiredMemory);
-                }
+                finalPlanBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
             }
         }
-        return planBuilder.build();
+        return finalPlanBuilder.build();
     }
 
     private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByModelId(List<AssignmentPlan> plans) {

+ 70 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 
 import java.util.ArrayList;
@@ -28,6 +29,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.anEmptyMap;
@@ -1153,6 +1156,73 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         assertThat(assignment.getReason().isPresent(), is(false));
     }
 
+    public void testCopyAssignments() {
+        // Create test nodes
+        AssignmentPlan.Node node1 = new AssignmentPlan.Node("node-1", ByteSizeValue.ofGb(1).getBytes(), 4);
+        AssignmentPlan.Node node2 = new AssignmentPlan.Node("node-2", ByteSizeValue.ofGb(1).getBytes(), 8);
+        List<AssignmentPlan.Node> nodes = List.of(node1, node2);
+
+        // Create test deployments
+        AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
+            "model-1",
+            ByteSizeValue.ofMb(100).getBytes(),
+            2,
+            1,
+            Map.of(),
+            0,
+            null,
+            Priority.NORMAL,
+            0,
+            0
+        );
+        AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
+
+            "model-2",
+            ByteSizeValue.ofMb(100).getBytes(),
+            1,
+            2,
+            Map.of(),
+            0,
+            null,
+            Priority.LOW,
+            0,
+            0
+        );
+        List<AssignmentPlan.Deployment> deployments = List.of(deployment1, deployment2);
+
+        // Create source plan and assign models to nodes
+        AssignmentPlan.Builder sourceBuilder = AssignmentPlan.builder(nodes, deployments);
+        sourceBuilder.assignModelToNode(deployment1, node1, 1);
+        sourceBuilder.assignModelToNode(deployment1, node2, 1);
+        sourceBuilder.assignModelToNode(deployment2, node2, 1);
+        AssignmentPlan source = sourceBuilder.build();
+
+        // Create destination plan
+        AssignmentPlan.Builder dest = AssignmentPlan.builder(nodes, deployments);
+
+        // Create map of node IDs to original nodes
+        Map<String, AssignmentPlan.Node> originalNodeById = nodes.stream()
+            .collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
+
+        // Call copyAssignments
+        TrainedModelAssignmentRebalancer.copyAssignments(source, dest, originalNodeById);
+
+        // Build the destination plan
+        AssignmentPlan result = dest.build();
+
+        // Verify assignments
+        Optional<Map<AssignmentPlan.Node, Integer>> deployment1Assignments = result.assignments(deployment1);
+        assertThat(deployment1Assignments.isPresent(), is(true));
+        assertThat(deployment1Assignments.get().size(), equalTo(2));
+        assertThat(deployment1Assignments.get().get(node1), equalTo(1));
+        assertThat(deployment1Assignments.get().get(node2), equalTo(1));
+
+        Optional<Map<AssignmentPlan.Node, Integer>> deployment2Assignments = result.assignments(deployment2);
+        assertThat(deployment2Assignments.isPresent(), is(true));
+        assertThat(deployment2Assignments.get().size(), equalTo(1));
+        assertThat(deployment2Assignments.get().get(node2), equalTo(1));
+    }
+
     private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(String deploymentId, long modelSize) {
         return lowPriorityParams(deploymentId, deploymentId, modelSize);
     }

+ 23 - 18
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java

@@ -16,6 +16,7 @@ import java.util.List;
 import java.util.Map;
 
 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
@@ -128,7 +129,7 @@ public class AssignmentPlanTests extends ESTestCase {
             builder.assignModelToNode(m, n, 1);
 
             assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(350).getBytes()));
+            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(50).getBytes()));
             assertThat(builder.getRemainingAllocations(m), equalTo(1));
             assertThat(builder.getRemainingThreads(m), equalTo(2));
 
@@ -156,7 +157,7 @@ public class AssignmentPlanTests extends ESTestCase {
             builder.assignModelToNode(m, n, 1);
 
             assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(325).getBytes()));
+            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(0).getBytes()));
             assertThat(builder.getRemainingAllocations(m), equalTo(1));
             assertThat(builder.getRemainingThreads(m), equalTo(2));
 
@@ -180,7 +181,9 @@ public class AssignmentPlanTests extends ESTestCase {
             builder.assignModelToNode(m, n, 1);
 
             assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+            // Since perDeployment memory is not specified, we compute the base memory usage.
+            // The remaining memory is 300MB - (240 MB + 2*30 MB) = 0MB
+            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(0).getBytes()));
             assertThat(builder.getRemainingAllocations(m), equalTo(1));
             assertThat(builder.getRemainingThreads(m), equalTo(2));
 
@@ -209,7 +212,11 @@ public class AssignmentPlanTests extends ESTestCase {
             builder.assignModelToNode(m, n, 1);
 
             assertThat(builder.getRemainingCores(n), equalTo(2));
-            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(275).getBytes()));
+            // base memory: 240+2*25 = 290MB
+            // since perDeployment memory is specified, we compute the new memory format usage:
+            // 250 (perDeployment) + 1*25 (perAllocation) + 25 (modelDefinition) = 300MB
+            // Then we take the maximum of 290 and 300, which is 300MB
+            assertThat(builder.getRemainingMemory(n), equalTo(ByteSizeValue.ofMb(0).getBytes()));
             assertThat(builder.getRemainingAllocations(m), equalTo(1));
             assertThat(builder.getRemainingThreads(m), equalTo(2));
 
@@ -248,12 +255,8 @@ public class AssignmentPlanTests extends ESTestCase {
 
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
-            builder.assignModelToNode(m, n, 2);
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2)));
+            Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 2));
+            assertThat(e.getMessage(), containsString("not enough memory on node"));
         }
         { // new memory format
             Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
@@ -271,12 +274,8 @@ public class AssignmentPlanTests extends ESTestCase {
 
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
-            builder.assignModelToNode(m, n, 2);
-            AssignmentPlan plan = builder.build();
-
-            assertThat(plan.models(), contains(m));
-            assertThat(plan.satisfiesCurrentAssignments(), is(true));
-            assertThat(plan.assignments(m).get(), equalTo(Map.of(n, 2)));
+            Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 2));
+            assertThat(e.getMessage(), containsString("not enough memory on node"));
         }
     }
 
@@ -375,7 +374,9 @@ public class AssignmentPlanTests extends ESTestCase {
             // old memory format
             Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, null, 0, 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
-            assertThat(builder.canAssign(m, n, 1), is(true));
+            // 240 + 2*31 = 302MB, this doesn't fit in 300MB. We don't care that the deployment is currently allocated since
+            // only previous assignments should be considered
+            assertThat(builder.canAssign(m, n, 1), is(false));
         }
         {
             // new memory format
@@ -387,11 +388,15 @@ public class AssignmentPlanTests extends ESTestCase {
                 Map.of("n_1", 1),
                 0,
                 null,
-                ByteSizeValue.ofMb(300).getBytes(),
+                ByteSizeValue.ofMb(265).getBytes(),
                 ByteSizeValue.ofMb(10).getBytes()
             );
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
+            // 265 + 1*10 + 25 = 300MB, this doesn't fit in 300MB. We don't care that the deployment is currently allocated since
             assertThat(builder.canAssign(m, n, 1), is(true));
+            builder.assignModelToNode(m, n, 1);
+            // After assignment, no more memory is available
+            assertThat(builder.canAssign(m, n, 1), is(false));
         }
     }