Browse Source

Fix memory usage estimation for ELSER models (#131630)

* Pass model ID instead of deployment ID to memory estimator

* Update docs/changelog/131630.yaml
Jan Kuipers 3 months ago
parent
commit
f393dbab36

+ 5 - 0
docs/changelog/131630.yaml

@@ -0,0 +1,5 @@
+pr: 131630
+summary: Fix memory usage estimation for ELSER models
+area: Machine Learning
+type: bug
+issues: []

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

@@ -169,6 +169,7 @@ class TrainedModelAssignmentRebalancer {
                     .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getTargetAllocations()));
                 return new AssignmentPlan.Deployment(
                     assignment.getDeploymentId(),
+                    assignment.getModelId(),
                     assignment.getTaskParams().getModelBytes(),
                     assignment.getTaskParams().getNumberOfAllocations(),
                     assignment.getTaskParams().getThreadsPerAllocation(),
@@ -185,6 +186,7 @@ class TrainedModelAssignmentRebalancer {
             planDeployments.add(
                 new AssignmentPlan.Deployment(
                     taskParams.getDeploymentId(),
+                    taskParams.getModelId(),
                     taskParams.getModelBytes(),
                     taskParams.getNumberOfAllocations(),
                     taskParams.getThreadsPerAllocation(),
@@ -225,6 +227,7 @@ class TrainedModelAssignmentRebalancer {
             .map(
                 assignment -> new AssignmentPlan.Deployment(
                     assignment.getDeploymentId(),
+                    assignment.getModelId(),
                     assignment.getTaskParams().getModelBytes(),
                     assignment.getTaskParams().getNumberOfAllocations(),
                     assignment.getTaskParams().getThreadsPerAllocation(),
@@ -242,6 +245,7 @@ class TrainedModelAssignmentRebalancer {
             planDeployments.add(
                 new AssignmentPlan.Deployment(
                     taskParams.getDeploymentId(),
+                    taskParams.getModelId(),
                     taskParams.getModelBytes(),
                     taskParams.getNumberOfAllocations(),
                     taskParams.getThreadsPerAllocation(),

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java

@@ -55,6 +55,7 @@ abstract class AbstractPreserveAllocations {
 
         return new Deployment(
             m.deploymentId(),
+            m.modelId(),
             m.memoryBytes(),
             m.allocations() - calculatePreservedAllocations(m),
             m.threadsPerAllocation(),

+ 7 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java

@@ -47,6 +47,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
      */
     public record Deployment(
         String deploymentId,
+        String modelId,
         long memoryBytes,
         int allocations,
         int threadsPerAllocation,
@@ -59,6 +60,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
     ) {
         public Deployment(
             String deploymentId,
+            String modelId,
             long modelBytes,
             int allocations,
             int threadsPerAllocation,
@@ -70,6 +72,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
         ) {
             this(
                 deploymentId,
+                modelId,
                 modelBytes,
                 allocations,
                 threadsPerAllocation,
@@ -96,7 +99,7 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
 
         public long estimateMemoryUsageBytes(int allocations) {
             return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
-                deploymentId,
+                modelId,
                 memoryBytes,
                 perDeploymentMemoryBytes,
                 perAllocationMemoryBytes,
@@ -106,24 +109,23 @@ public class AssignmentPlan implements Comparable<AssignmentPlan> {
 
         long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
             return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
-                deploymentId,
+                modelId,
                 memoryBytes,
                 perDeploymentMemoryBytes,
                 perAllocationMemoryBytes,
                 allocationsNew
             ) - StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
-                deploymentId,
+                modelId,
                 memoryBytes,
                 perDeploymentMemoryBytes,
                 perAllocationMemoryBytes,
                 allocationsOld
             );
-
         }
 
         long minimumMemoryRequiredBytes() {
             return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
-                deploymentId,
+                modelId,
                 memoryBytes,
                 perDeploymentMemoryBytes,
                 perAllocationMemoryBytes,

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

@@ -115,6 +115,7 @@ public class AssignmentPlanner {
             .map(
                 m -> new AssignmentPlan.Deployment(
                     m.deploymentId(),
+                    m.modelId(),
                     m.memoryBytes(),
                     1,
                     m.threadsPerAllocation(),
@@ -148,6 +149,7 @@ public class AssignmentPlanner {
                 : Map.of();
             return new AssignmentPlan.Deployment(
                 m.deploymentId(),
+                m.modelId(),
                 m.memoryBytes(),
                 m.allocations(),
                 m.threadsPerAllocation(),

+ 2 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

@@ -127,6 +127,7 @@ public class ZoneAwareAssignmentPlanner {
                 d -> new AssignmentPlan.Deployment(
                     // replace each deployment with a new deployment
                     d.deploymentId(),
+                    d.modelId(),
                     d.memoryBytes(),
                     deploymentIdToTargetAllocationsPerZone.get(d.deploymentId()),
                     d.threadsPerAllocation(),
@@ -163,6 +164,7 @@ public class ZoneAwareAssignmentPlanner {
             .map(
                 d -> new AssignmentPlan.Deployment(
                     d.deploymentId(),
+                    d.modelId(),
                     d.memoryBytes(),
                     d.allocations(),
                     d.threadsPerAllocation(),

+ 74 - 18
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java

@@ -25,14 +25,14 @@ public class AssignmentPlanTests extends ESTestCase {
 
     public void testBuilderCtor_GivenDuplicateNode() {
         Node n = new Node("n_1", 100, 4);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", 40, 1, 2, Map.of(), 0, null, 0, 0);
 
         expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n, n), List.of(m)));
     }
 
     public void testBuilderCtor_GivenDuplicateModel() {
         Node n = new Node("n_1", 100, 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0);
+        Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", 40, 1, 2, Map.of(), 0, null, 0, 0);
 
         expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m)));
     }
@@ -42,6 +42,7 @@ public class AssignmentPlanTests extends ESTestCase {
 
         { // old memory format
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(40).getBytes(),
                 1,
@@ -75,6 +76,7 @@ public class AssignmentPlanTests extends ESTestCase {
         }
         { // new memory format
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(20).getBytes(),
                 1,
@@ -112,6 +114,7 @@ public class AssignmentPlanTests extends ESTestCase {
         Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4);
         {   // old memory format
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -140,6 +143,7 @@ public class AssignmentPlanTests extends ESTestCase {
         }
         {   // new memory format
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(25).getBytes(),
                 2,
@@ -173,7 +177,7 @@ public class AssignmentPlanTests extends ESTestCase {
         Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4);
         {
             // old memory format
-            Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, null, 0, 0);
+            Deployment m = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, null, 0, 0);
 
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
@@ -193,6 +197,7 @@ public class AssignmentPlanTests extends ESTestCase {
         {
             // new memory format
             Deployment m = new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(25).getBytes(),
                 2,
@@ -223,7 +228,7 @@ public class AssignmentPlanTests extends ESTestCase {
 
     public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() {
         Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, null, 0, 0);
+        Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
         Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1));
@@ -235,6 +240,7 @@ public class AssignmentPlanTests extends ESTestCase {
         { // old memory format
             Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(50).getBytes(),
                 2,
@@ -258,6 +264,7 @@ public class AssignmentPlanTests extends ESTestCase {
         { // new memory format
             Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4);
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -282,7 +289,7 @@ public class AssignmentPlanTests extends ESTestCase {
 
     public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() {
         Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4);
-        Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, null, 0, 0);
+        Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
         Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5));
@@ -296,6 +303,7 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() {
         Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5);
         AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             3,
@@ -319,6 +327,7 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testAssignModelToNode_GivenSameModelAssignedTwice() {
         Node n = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
         Deployment m = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(50).getBytes(),
             4,
@@ -362,7 +371,7 @@ public class AssignmentPlanTests extends ESTestCase {
 
     public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() {
         Node n = new Node("n_1", 100, 5);
-        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, null, 0, 0);
+        AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", "m_1", 101, 1, 1, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
 
@@ -373,13 +382,14 @@ public class AssignmentPlanTests extends ESTestCase {
         Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
         {
             // old memory format
-            Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, null, 0, 0);
+            Deployment m = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, null, 0, 0);
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             assertThat(builder.canAssign(m, n, 1), is(true));
         }
         {
             // new memory format
             Deployment m = new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(25).getBytes(),
                 1,
@@ -398,6 +408,7 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testCanAssign_GivenEnoughMemory() {
         Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5);
         AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             3,
@@ -422,13 +433,25 @@ public class AssignmentPlanTests extends ESTestCase {
         Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
 
         {
-            Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, null, 0, 0);
+            Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
+                "m_1",
+                ByteSizeValue.ofMb(30).getBytes(),
+                3,
+                2,
+                Map.of("n_1", 2),
+                0,
+                null,
+                0,
+                0
+            );
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planSatisfyingPreviousAssignments = builder.build();
         }
         {
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 3,
@@ -453,6 +476,7 @@ public class AssignmentPlanTests extends ESTestCase {
         AssignmentPlan planWithFewerAllocations;
         Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
         AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(30).getBytes(),
             3,
@@ -485,13 +509,25 @@ public class AssignmentPlanTests extends ESTestCase {
         Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5);
 
         {
-            Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, null, 0, 0);
+            Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
+                "m_1",
+                ByteSizeValue.ofMb(30).getBytes(),
+                3,
+                2,
+                Map.of("n_1", 1),
+                0,
+                null,
+                0,
+                0
+            );
             AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m));
             builder.assignModelToNode(m, n, 2);
             planUsingMoreMemory = builder.build();
         }
         {
             AssignmentPlan.Deployment m = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(29).getBytes(),
                 3,
@@ -517,6 +553,7 @@ public class AssignmentPlanTests extends ESTestCase {
         {
             // old memory format
             AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(50).getBytes(),
                 1,
@@ -528,6 +565,7 @@ public class AssignmentPlanTests extends ESTestCase {
                 0
             );
             AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -539,6 +577,7 @@ public class AssignmentPlanTests extends ESTestCase {
                 0
             );
             AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
+                "m_3",
                 "m_3",
                 ByteSizeValue.ofMb(20).getBytes(),
                 4,
@@ -560,6 +599,7 @@ public class AssignmentPlanTests extends ESTestCase {
         {
             // new memory format
             AssignmentPlan.Deployment deployment1 = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(50).getBytes(),
                 1,
@@ -571,6 +611,7 @@ public class AssignmentPlanTests extends ESTestCase {
                 ByteSizeValue.ofMb(10).getBytes()
             );
             AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -582,6 +623,7 @@ public class AssignmentPlanTests extends ESTestCase {
                 ByteSizeValue.ofMb(10).getBytes()
             );
             AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
+                "m_3",
                 "m_3",
                 ByteSizeValue.ofMb(20).getBytes(),
                 4,
@@ -605,9 +647,9 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testSatisfiesAllDeployments_GivenOneModelHasOneAllocationLess() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
             .assignModelToNode(deployment1, node1, 1)
             .assignModelToNode(deployment2, node2, 2)
@@ -620,9 +662,9 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testArePreviouslyAssignedDeploymentsAssigned_GivenTrue() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
             .assignModelToNode(deployment1, node1, 1)
             .assignModelToNode(deployment2, node2, 1)
@@ -633,8 +675,8 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testArePreviouslyAssignedDeploymentsAssigned_GivenFalse() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0);
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2))
             .assignModelToNode(deployment1, node1, 1)
             .build();
@@ -644,8 +686,20 @@ public class AssignmentPlanTests extends ESTestCase {
     public void testCountPreviouslyAssignedThatAreStillAssigned() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0);
+        Deployment deployment1 = new AssignmentPlan.Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(50).getBytes(),
+            1,
+            2,
+            Map.of(),
+            3,
+            null,
+            0,
+            0
+        );
         AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
+            "m_2",
             "m_2",
             ByteSizeValue.ofMb(30).getBytes(),
             2,
@@ -657,6 +711,7 @@ public class AssignmentPlanTests extends ESTestCase {
             0
         );
         AssignmentPlan.Deployment deployment3 = new AssignmentPlan.Deployment(
+            "m_3",
             "m_3",
             ByteSizeValue.ofMb(20).getBytes(),
             4,
@@ -668,6 +723,7 @@ public class AssignmentPlanTests extends ESTestCase {
             0
         );
         AssignmentPlan.Deployment deployment4 = new AssignmentPlan.Deployment(
+            "m_4",
             "m_4",
             ByteSizeValue.ofMb(20).getBytes(),
             4,

+ 181 - 43
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java

@@ -42,13 +42,25 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testModelThatDoesNotFitInMemory() {
         { // Without perDeploymentMemory and perAllocationMemory specified
             List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(50), 4));
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+            Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
+                "m_1",
+                ByteSizeValue.ofMb(51).getBytes(),
+                4,
+                1,
+                Map.of(),
+                0,
+                null,
+                0,
+                0
+            );
             AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
             assertThat(plan.assignments(deployment), isEmpty());
         }
         { // With perDeploymentMemory and perAllocationMemory specified
             List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(55), 4));
             Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(50).getBytes(),
                 4,
@@ -66,7 +78,18 @@ public class AssignmentPlannerTests extends ESTestCase {
 
     public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() {
         List<Node> nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5));
-        Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, null, 0, 0);
+        Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(1).getBytes(),
+            1,
+            6,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
         AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan();
         assertThat(plan.assignments(deployment), isEmpty());
     }
@@ -74,19 +97,31 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testSingleModelThatFitsFullyOnSingleNode() {
         {
             Node node = new Node("n_1", scaleNodeSize(100), 4);
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+            Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
+                "m_1",
+                ByteSizeValue.ofMb(100).getBytes(),
+                1,
+                1,
+                Map.of(),
+                0,
+                null,
+                0,
+                0
+            );
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
         {
             Node node = new Node("n_1", scaleNodeSize(1000), 8);
-            Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, null, 0, 0);
+            Deployment deployment = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, null, 0, 0);
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
         {
             Node node = new Node("n_1", scaleNodeSize(10000), 16);
             AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(10000).getBytes(),
                 1,
@@ -102,7 +137,18 @@ public class AssignmentPlannerTests extends ESTestCase {
         }
         {
             Node node = new Node("n_1", scaleNodeSize(100), 4);
-            Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+            Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
+                "m_1",
+                ByteSizeValue.ofMb(100).getBytes(),
+                1,
+                1,
+                Map.of(),
+                0,
+                null,
+                0,
+                0
+            );
             AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan();
             assertModelFullyAssignedToNode(plan, deployment, node);
         }
@@ -112,6 +158,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         {
             Node node = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4);
             Deployment deployment = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(100).getBytes(),
                 1,
@@ -128,6 +175,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         {
             Node node = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
             Deployment deployment = new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(100).getBytes(),
                 8,
@@ -146,7 +194,18 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() {
         Node node1 = new Node("n_1", scaleNodeSize(100), 4);
         Node node2 = new Node("n_2", scaleNodeSize(100), 4);
-        AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        AssignmentPlan.Deployment deployment = new Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(100).getBytes(),
+            4,
+            1,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
 
         AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan();
 
@@ -162,6 +221,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
         AssignmentPlan.Deployment deployment = new Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             4,
@@ -184,7 +244,18 @@ public class AssignmentPlannerTests extends ESTestCase {
     }
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() {
-        AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, null, 0, 0);
+        AssignmentPlan.Deployment deployment = new Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(30).getBytes(),
+            10,
+            1,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
         // Single node
         {
             Node node = new Node("n_1", scaleNodeSize(100), 4);
@@ -219,6 +290,7 @@ public class AssignmentPlannerTests extends ESTestCase {
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation_NewMemoryFields() {
         AssignmentPlan.Deployment deployment = new Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             10,
@@ -266,10 +338,10 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7);
         Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2);
         Node node4 = new Node("n_4", 2 * scaleNodeSize(50), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0);
-        Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0);
+        Deployment deployment4 = new Deployment("m_4", "m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan plan = new AssignmentPlanner(
             List.of(node1, node2, node3, node4),
@@ -322,6 +394,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node3 = new Node("n_3", ByteSizeValue.ofMb(900).getBytes(), 2);
         Node node4 = new Node("n_4", ByteSizeValue.ofMb(900).getBytes(), 2);
         Deployment deployment1 = new Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(50).getBytes(),
             2,
@@ -333,6 +406,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             ByteSizeValue.ofMb(50).getBytes()
         );
         Deployment deployment2 = new Deployment(
+            "m_2",
             "m_2",
             ByteSizeValue.ofMb(50).getBytes(),
             2,
@@ -344,6 +418,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             ByteSizeValue.ofMb(50).getBytes()
         );
         Deployment deployment3 = new Deployment(
+            "m_3",
             "m_3",
             ByteSizeValue.ofMb(50).getBytes(),
             1,
@@ -355,6 +430,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             ByteSizeValue.ofMb(50).getBytes()
         );
         Deployment deployment4 = new Deployment(
+            "m_4",
             "m_4",
             ByteSizeValue.ofMb(50).getBytes(),
             2,
@@ -412,7 +488,18 @@ public class AssignmentPlannerTests extends ESTestCase {
     }
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() {
-        Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, null, 0, 0);
+        Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(30).getBytes(),
+            10,
+            3,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
         // Single node
         {
             Node node = new Node("n_1", scaleNodeSize(100), 4);
@@ -447,6 +534,7 @@ public class AssignmentPlannerTests extends ESTestCase {
 
     public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation_NewMemoryFields() {
         Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(50).getBytes(),
             10,
@@ -492,6 +580,7 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() {
         Node node = new Node("n_1", scaleNodeSize(100), 4);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(30).getBytes(),
             4,
@@ -518,18 +607,18 @@ public class AssignmentPlannerTests extends ESTestCase {
             new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16)
         );
         List<Deployment> deployments = List.of(
-            new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, null, 0, 0),
-            new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0),
-            new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, null, 0, 0),
-            new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, null, 0, 0),
-            new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, null, 0, 0),
-            new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
-            new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0)
+            new Deployment("m_1", "m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, null, 0, 0),
+            new AssignmentPlan.Deployment("m_2", "m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0),
+            new AssignmentPlan.Deployment("m_3", "m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_4", "m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, null, 0, 0),
+            new Deployment("m_5", "m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_6", "m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, null, 0, 0),
+            new AssignmentPlan.Deployment("m_7", "m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, null, 0, 0),
+            new Deployment("m_8", "m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_9", "m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
+            new AssignmentPlan.Deployment("m_10", "m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_11", "m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_12", "m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0)
         );
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan();
@@ -556,6 +645,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         // Use mix of old and new memory fields
         List<Deployment> deployments = List.of(
             new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(100).getBytes(),
                 10,
@@ -566,8 +656,9 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(400).getBytes(),
                 ByteSizeValue.ofMb(100).getBytes()
             ),
-            new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0),
+            new Deployment("m_2", "m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0),
             new Deployment(
+                "m_3",
                 "m_3",
                 ByteSizeValue.ofMb(50).getBytes(),
                 3,
@@ -579,6 +670,7 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(50).getBytes()
             ),
             new Deployment(
+                "m_4",
                 "m_4",
                 ByteSizeValue.ofMb(50).getBytes(),
                 4,
@@ -590,6 +682,7 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(100).getBytes()
             ),
             new Deployment(
+                "m_5",
                 "m_5",
                 ByteSizeValue.ofMb(500).getBytes(),
                 2,
@@ -601,6 +694,7 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(100).getBytes()
             ),
             new Deployment(
+                "m_6",
                 "m_6",
                 ByteSizeValue.ofMb(50).getBytes(),
                 12,
@@ -612,6 +706,7 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(20).getBytes()
             ),
             new Deployment(
+                "m_7",
                 "m_7",
                 ByteSizeValue.ofMb(50).getBytes(),
                 12,
@@ -622,11 +717,11 @@ public class AssignmentPlannerTests extends ESTestCase {
                 ByteSizeValue.ofMb(300).getBytes(),
                 ByteSizeValue.ofMb(50).getBytes()
             ),
-            new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
-            new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0)
+            new Deployment("m_8", "m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_9", "m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_10", "m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_11", "m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0),
+            new Deployment("m_12", "m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0)
         );
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan();
@@ -731,6 +826,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             previousModelsPlusNew.add(
                 new AssignmentPlan.Deployment(
                     m.deploymentId(),
+                    m.modelId(),
                     m.memoryBytes(),
                     m.allocations(),
                     m.threadsPerAllocation(),
@@ -754,6 +850,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node2 = new Node("n_2", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
         Node node3 = new Node("n_3", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
         Deployment deployment1 = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(1200).getBytes(),
             3,
@@ -764,7 +861,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             0,
             0
         );
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2))
             .computePlan();
         assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L));
@@ -790,6 +887,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofGb(6).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofGb(6).getBytes(), 2);
         AssignmentPlan.Deployment deployment1 = new Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(1200).getBytes(),
             3,
@@ -801,6 +899,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             0
         );
         AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment(
+            "m_2",
             "m_2",
             ByteSizeValue.ofMb(1100).getBytes(),
             1,
@@ -829,8 +928,30 @@ public class AssignmentPlannerTests extends ESTestCase {
 
     public void testGivenPreviouslyAssignedDeployments_CannotAllBeAllocated() {
         Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2);
-        AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, null, 0, 0);
-        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, null, 0, 0);
+        AssignmentPlan.Deployment deployment1 = new Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(1200).getBytes(),
+            1,
+            1,
+            Map.of(),
+            1,
+            null,
+            0,
+            0
+        );
+        AssignmentPlan.Deployment deployment2 = new Deployment(
+            "m_2",
+            "m_2",
+            ByteSizeValue.ofMb(1100).getBytes(),
+            1,
+            1,
+            Map.of(),
+            1,
+            null,
+            0,
+            0
+        );
 
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan();
 
@@ -840,9 +961,20 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new AssignmentPlan.Deployment(
+            "m_2",
+            "m_2",
+            ByteSizeValue.ofMb(800).getBytes(),
+            1,
+            1,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
 
         // First only start m_1
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan();
@@ -882,9 +1014,9 @@ public class AssignmentPlannerTests extends ESTestCase {
     public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
 
         // First only start m_1
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan();
@@ -953,9 +1085,9 @@ public class AssignmentPlannerTests extends ESTestCase {
         // Ensure that plan is removing previously allocated models if not enough memory is available
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
 
         // Create a plan where all deployments are assigned at least once
         AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3))
@@ -981,6 +1113,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(700).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 2);
         Deployment deployment1 = new Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             2,
@@ -992,6 +1125,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             ByteSizeValue.ofMb(100).getBytes()
         );
         Deployment deployment2 = new Deployment(
+            "m_2",
             "m_2",
             ByteSizeValue.ofMb(100).getBytes(),
             1,
@@ -1003,6 +1137,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             ByteSizeValue.ofMb(150).getBytes()
         );
         Deployment deployment3 = new Deployment(
+            "m_3",
             "m_3",
             ByteSizeValue.ofMb(50).getBytes(),
             1,
@@ -1048,6 +1183,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             deployments.add(
                 new Deployment(
                     m.deploymentId(),
+                    m.modelId(),
                     m.memoryBytes(),
                     m.allocations(),
                     m.threadsPerAllocation(),
@@ -1116,6 +1252,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         // randomly choose between old and new memory fields format
         if (randomBoolean()) {
             return new Deployment(
+                "m_" + idSuffix,
                 "m_" + idSuffix,
                 randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(10).getBytes()),
                 randomIntBetween(1, 32),
@@ -1128,6 +1265,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             );
         } else {
             return new Deployment(
+                "m_" + idSuffix,
                 "m_" + idSuffix,
                 randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
                 randomIntBetween(1, 32),
@@ -1165,7 +1303,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         }
         List<Deployment> deployments = new ArrayList<>();
         for (int i = 0; i < modelsSize; i++) {
-            deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, null, 0, 0));
+            deployments.add(new Deployment("m_" + i, "m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, null, 0, 0));
         }
 
         // Check plan is computed without OOM exception

+ 7 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java

@@ -25,8 +25,8 @@ public class PreserveAllAllocationsTests extends ESTestCase {
     public void testGivenNoPreviousAssignments() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0);
         PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(
             List.of(node1, node2),
             List.of(deployment1, deployment2)
@@ -39,6 +39,7 @@ public class PreserveAllAllocationsTests extends ESTestCase {
             Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8);
             Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8);
             Deployment deployment1 = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -50,6 +51,7 @@ public class PreserveAllAllocationsTests extends ESTestCase {
                 0
             );
             Deployment deployment2 = new Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(50).getBytes(),
                 6,
@@ -122,6 +124,7 @@ public class PreserveAllAllocationsTests extends ESTestCase {
             Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
             Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8);
             Deployment deployment1 = new AssignmentPlan.Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -133,6 +136,7 @@ public class PreserveAllAllocationsTests extends ESTestCase {
                 ByteSizeValue.ofMb(10).getBytes()
             );
             Deployment deployment2 = new Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(50).getBytes(),
                 6,
@@ -208,7 +212,7 @@ public class PreserveAllAllocationsTests extends ESTestCase {
 
     public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() {
         Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4);
-        Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0);
+        Deployment deployment = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0);
         PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment));
 
         AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build();

+ 30 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java

@@ -26,8 +26,30 @@ public class PreserveOneAllocationTests extends ESTestCase {
     public void testGivenNoPreviousAssignments() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4);
-        Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new AssignmentPlan.Deployment(
+            "m_1",
+            "m_1",
+            ByteSizeValue.ofMb(30).getBytes(),
+            2,
+            1,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
+        AssignmentPlan.Deployment deployment2 = new Deployment(
+            "m_2",
+            "m_2",
+            ByteSizeValue.ofMb(30).getBytes(),
+            2,
+            4,
+            Map.of(),
+            0,
+            null,
+            0,
+            0
+        );
         PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2));
 
         List<Node> nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations();
@@ -42,8 +64,9 @@ public class PreserveOneAllocationTests extends ESTestCase {
             // old memory format
             Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8);
             Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8);
-            Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, null, 0, 0);
+            Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, null, 0, 0);
             Deployment deployment2 = new Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(50).getBytes(),
                 6,
@@ -121,6 +144,7 @@ public class PreserveOneAllocationTests extends ESTestCase {
             Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 8);
             Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 8);
             Deployment deployment1 = new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,
@@ -132,6 +156,7 @@ public class PreserveOneAllocationTests extends ESTestCase {
                 ByteSizeValue.ofMb(10).getBytes()
             );
             Deployment deployment2 = new Deployment(
+                "m_2",
                 "m_2",
                 ByteSizeValue.ofMb(50).getBytes(),
                 6,
@@ -211,7 +236,7 @@ public class PreserveOneAllocationTests extends ESTestCase {
         {
             // old memory format
             Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4);
-            Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0);
+            Deployment deployment = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0);
             PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment));
 
             AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build();
@@ -227,6 +252,7 @@ public class PreserveOneAllocationTests extends ESTestCase {
             // new memory format
             Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4);
             Deployment deployment = new Deployment(
+                "m_1",
                 "m_1",
                 ByteSizeValue.ofMb(30).getBytes(),
                 2,

+ 16 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java

@@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
 
     public void testGivenOneModel_OneNode_OneZone_DoesNotFit() {
         Node node = new Node("n_1", 100, 1);
-        AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, null, 0, 0);
+        AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", "m_1", 100, 1, 2, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan();
 
@@ -46,6 +46,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
     public void testGivenOneModel_OneNode_OneZone_FullyFits() {
         Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             2,
@@ -65,6 +66,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
     public void testGivenOneModel_OneNode_OneZone_PartiallyFits() {
         Node node = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             3,
@@ -87,6 +89,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             1,
@@ -115,6 +118,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             2,
@@ -142,6 +146,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofGb(16).getBytes(), 8);
         Node node2 = new Node("n_2", ByteSizeValue.ofGb(16).getBytes(), 8);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             4,
@@ -169,6 +174,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4);
         AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment(
+            "m_1",
             "m_1",
             ByteSizeValue.ofMb(100).getBytes(),
             3,
@@ -200,9 +206,9 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, null, 0, 0);
 
         Map<List<String>, List<Node>> nodesByZone = Map.of(
             List.of("z_1"),
@@ -248,8 +254,8 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4);
         Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
 
         AssignmentPlan plan = new ZoneAwareAssignmentPlanner(
             Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)),
@@ -282,6 +288,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
             previousModelsPlusNew.add(
                 new AssignmentPlan.Deployment(
                     m.deploymentId(),
+                    m.modelId(),
                     m.memoryBytes(),
                     m.allocations(),
                     m.threadsPerAllocation(),
@@ -303,9 +310,9 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
     public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() {
         Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2);
         Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2);
-        Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
-        Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment1 = new Deployment("m_1", "m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment2 = new Deployment("m_2", "m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0);
+        Deployment deployment3 = new Deployment("m_3", "m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0);
 
         // First only start m_1
         AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1))