Browse Source

[ML] refactor model allocation builders (#76463)

refactoring model allocation builders. The current design is fairly unintuitive and can make adding new states or predicates annoying.
Benjamin Trent 4 years ago
parent
commit
964816091e

+ 38 - 22
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

@@ -215,8 +215,12 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         });
     }
 
-    private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
-        if (modelAllocations.isChanged()) {
+    private static ClusterState update(
+        ClusterState currentState,
+        TrainedModelAllocationMetadata.Builder modelAllocations,
+        boolean force
+    ) {
+        if (force || modelAllocations.isChanged()) {
             return ClusterState.builder(currentState)
                 .metadata(
                     Metadata.builder(currentState.metadata()).putCustom(TrainedModelAllocationMetadata.NAME, modelAllocations.build())
@@ -227,6 +231,10 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         }
     }
 
+    private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
+        return update(currentState, modelAllocations, false);
+    }
+
     ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) {
         if (MlMetadata.getMlMetadata(currentState).isResetMode()) {
             throw new ElasticsearchStatusException(
@@ -239,20 +247,21 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         if (builder.hasModel(params.getModelId())) {
             throw new ResourceAlreadyExistsException("allocation for model with id [{}] already exist", params.getModelId());
         }
+        TrainedModelAllocation.Builder allocationBuilder = TrainedModelAllocation.Builder.empty(params);
 
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
-        builder.addNewAllocation(params);
         for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
             if (StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
                 && shuttingDownNodes.contains(node.getId()) == false) {
                 Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
                 if (maybeError.isPresent()) {
-                    builder.addFailedNode(params.getModelId(), node.getId(), maybeError.get());
+                    allocationBuilder.addNewFailedRoutingEntry(node.getId(), maybeError.get());
                 } else {
-                    builder.addNode(params.getModelId(), node.getId());
+                    allocationBuilder.addNewRoutingEntry(node.getId());
                 }
             }
         }
+        builder.addNewAllocation(params.getModelId(), allocationBuilder);
         return update(currentState, builder);
     }
 
@@ -266,10 +275,9 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
             return clusterState;
         }
-
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(clusterState);
-        builder.setAllocationToStopping(modelId);
-        return update(clusterState, builder);
+        final boolean isChanged = builder.getAllocation(modelId).stopAllocation().isChanged();
+        return update(clusterState, builder, isChanged);
     }
 
     static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAllocationStateAction.Request request) {
@@ -280,13 +288,14 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
             () -> new ParameterizedMessage("[{}] [{}] current metadata before update {}", modelId, nodeId, Strings.toString(metadata))
         );
         final TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
-
+        final TrainedModelAllocationMetadata.Builder builder =  TrainedModelAllocationMetadata.builder(currentState);
         // If state is stopped, this indicates the node process is closed, remove the node from the allocation
         if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) {
             if (existingAllocation == null || existingAllocation.isRoutedToNode(nodeId) == false) {
                 return currentState;
             }
-            return update(currentState, TrainedModelAllocationMetadata.builder(currentState).removeNode(modelId, nodeId));
+            final boolean isChanged = builder.getAllocation(modelId).removeRoutingEntry(nodeId).isChanged();
+            return update(currentState, builder, isChanged);
         }
 
         if (existingAllocation == null) {
@@ -305,9 +314,8 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         if (existingAllocation.isRoutedToNode(nodeId) == false) {
             throw new ResourceNotFoundException("allocation for model with id [{}]] is not routed to node [{}]", modelId, nodeId);
         }
-        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
-        builder.updateAllocation(modelId, nodeId, request.getRoutingState());
-        return update(currentState, builder);
+        final boolean isChanged = builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).isChanged();
+        return update(currentState, builder, isChanged);
     }
 
     static ClusterState removeAllocation(ClusterState currentState, String modelId) {
@@ -332,14 +340,15 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     }
 
     ClusterState addRemoveAllocationNodes(ClusterState currentState) {
-        TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
-        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
+        final TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
+        final TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
         Map<String, List<String>> removedNodeModelLookUp = new HashMap<>();
         Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
         // TODO: make more efficient, right now this is O(nm) where n = sizeof(models) and m = sizeof(nodes)
         // It could probably be O(max(n, m))
         // Add nodes and keep track of currently routed nodes
         // Should we indicate a partial allocation somehow if some nodes don't have space?
+        boolean isChanged = false;
         for (Map.Entry<String, TrainedModelAllocation> modelAllocationEntry : previousState.modelAllocations().entrySet()) {
             // Don't bother adding/removing nodes if this allocation is stopping
             if (modelAllocationEntry.getValue().getAllocationState().equals(AllocationState.STOPPING)) {
@@ -351,10 +360,14 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
                 if (shuttingDownNodes.contains(node.getId()) == false
                     && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
                     && modelAllocationEntry.getValue().isRoutedToNode(node.getId()) == false) {
-                    nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node).ifPresentOrElse(
-                        (error) -> builder.addFailedNode(modelAllocationEntry.getKey(), node.getId(), error),
-                        () -> builder.addNode(modelAllocationEntry.getKey(), node.getId())
-                    );
+                    Optional<String> failure = nodeHasCapacity(currentState, modelAllocationEntry.getValue().getTaskParams(), node);
+                    if (failure.isPresent()) {
+                        isChanged |= builder.getAllocation(modelAllocationEntry.getKey())
+                            .addNewFailedRoutingEntry(node.getId(), failure.get())
+                            .isChanged();
+                    } else {
+                        isChanged |= builder.getAllocation(modelAllocationEntry.getKey()).addNewRoutingEntry(node.getId()).isChanged();
+                    }
                 }
             }
             for (String nodeId : modelAllocationEntry.getValue().getNodeRoutingTable().keySet()) {
@@ -376,10 +389,12 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         for (Map.Entry<String, List<String>> nodeToModels : removedNodeModelLookUp.entrySet()) {
             final String nodeId = nodeToModels.getKey();
             for (String modelId : nodeToModels.getValue()) {
-                builder.removeNode(modelId, nodeId);
+                isChanged |= Optional.ofNullable(builder.getAllocation(modelId))
+                    .map(allocation -> allocation.removeRoutingEntry(nodeId).isChanged())
+                    .orElse(false);
             }
         }
-        return update(currentState, builder);
+        return update(currentState, builder, isChanged);
     }
 
     static boolean shouldAllocateModels(final ClusterChangedEvent event) {
@@ -429,7 +444,8 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
                         load.getAssignedJobMemory(),
                         ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(),
                         params.estimateMemoryUsageBytes(),
-                        ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() }
+                        ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString()
+                    }
                 )
             );
         }

+ 10 - 62
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java

@@ -7,7 +7,7 @@
 
 package org.elasticsearch.xpack.ml.inference.allocation;
 
-import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.AbstractDiffable;
 import org.elasticsearch.cluster.ClusterState;
@@ -19,8 +19,6 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
@@ -155,57 +153,19 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return modelRoutingEntries.containsKey(modelId);
         }
 
-        public Builder addNewAllocation(StartTrainedModelDeploymentAction.TaskParams taskParams) {
-            if (modelRoutingEntries.containsKey(taskParams.getModelId())) {
-                return this;
-            }
-            modelRoutingEntries.put(taskParams.getModelId(), TrainedModelAllocation.Builder.empty(taskParams));
-            isChanged = true;
-            return this;
-        }
-
-        public Builder updateAllocation(String modelId, String nodeId, RoutingStateAndReason state) {
-            TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
-            if (allocation == null) {
-                return this;
-            }
-            isChanged |= allocation.updateExistingRoutingEntry(nodeId, state).isChanged();
-            return this;
-        }
-
-        public Builder addNode(String modelId, String nodeId) {
-            TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
-            if (allocation == null) {
-                throw new ResourceNotFoundException(
-                    "unable to add node [{}] to model [{}] routing table as allocation does not exist",
-                    nodeId,
+        public Builder addNewAllocation(String modelId, TrainedModelAllocation.Builder allocation) {
+            if (modelRoutingEntries.containsKey(modelId)) {
+                throw new ResourceAlreadyExistsException(
+                    "[{}] allocation already exists",
                     modelId
-                );
-            }
-            isChanged |= allocation.addNewRoutingEntry(nodeId).isChanged();
-            return this;
-        }
-
-        public Builder addFailedNode(String modelId, String nodeId, String reason) {
-            TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
-            if (allocation == null) {
-                throw new ResourceNotFoundException(
-                    "unable to add failed node [{}] to model [{}] routing table as allocation does not exist",
-                    nodeId,
-                    modelId
-                );
-            }
-            isChanged |= allocation.addNewFailedRoutingEntry(nodeId, reason).isChanged();
+                );            }
+            modelRoutingEntries.put(modelId, allocation);
+            isChanged = true;
             return this;
         }
 
-        Builder removeNode(String modelId, String nodeId) {
-            TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
-            if (allocation == null) {
-                return this;
-            }
-            isChanged |= allocation.removeRoutingEntry(nodeId).isChanged();
-            return this;
+        public TrainedModelAllocation.Builder getAllocation(String modelId) {
+            return modelRoutingEntries.get(modelId);
         }
 
         public Builder removeAllocation(String modelId) {
@@ -213,18 +173,6 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return this;
         }
 
-        public Builder setAllocationToStopping(String modelId) {
-            TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
-            if (allocation == null) {
-                throw new ResourceNotFoundException(
-                    "unable to set model allocation [{}] to stopping as it does not exist",
-                    modelId
-                );
-            }
-            isChanged |= allocation.stopAllocation().isChanged();
-            return this;
-        }
-
         public boolean isChanged() {
             return isChanged;
         }

+ 111 - 47
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

@@ -82,8 +82,11 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                 Metadata.builder()
                     .putCustom(
                         TrainedModelAllocationMetadata.NAME,
-                        TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, 10_000L))
-                            .addNode(modelId, nodeId)
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(
+                                modelId,
+                                TrainedModelAllocation.Builder.empty(newParams(modelId, 10_000L)).addNewRoutingEntry(nodeId)
+                            )
                             .build()
                     )
                     .build()
@@ -168,7 +171,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                 Metadata.builder()
                     .putCustom(
                         TrainedModelAllocationMetadata.NAME,
-                        TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, randomNonNegativeLong())).build()
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(modelId, TrainedModelAllocation.Builder.empty(newParams(modelId, randomNonNegativeLong())))
+                            .build()
                     )
                     .build()
             )
@@ -281,15 +286,22 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                     .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down"))
                     .putCustom(
                         TrainedModelAllocationMetadata.NAME,
-                        TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 10_000))
-                            .addNode("model-1", "ml-node-with-room")
-                            .updateAllocation("model-1", "ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, ""))
-                            .addNode("model-1", "old-ml-node-with-room")
-                            .updateAllocation("model-1", "old-ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, ""))
-                            .addNode("model-1", "ml-node-shutting-down")
-                            .addNewAllocation(newParams("model-2", 10_000))
-                            .addNode("model-2", "old-ml-node-with-room")
-                            .updateAllocation("model-2", "old-ml-node-with-room", new RoutingStateAndReason(RoutingState.STARTED, ""))
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(
+                                "model-1",
+                                TrainedModelAllocation.Builder.empty(newParams("model-1", 10_000))
+                                    .addNewRoutingEntry("ml-node-with-room")
+                                    .updateExistingRoutingEntry("ml-node-with-room", started())
+                                    .addNewRoutingEntry("old-ml-node-with-room")
+                                    .updateExistingRoutingEntry("old-ml-node-with-room", started())
+                                    .addNewRoutingEntry("ml-node-shutting-down")
+                            )
+                            .addNewAllocation(
+                                "model-2",
+                                TrainedModelAllocation.Builder.empty(newParams("model-2", 10_000))
+                                    .addNewRoutingEntry("old-ml-node-with-room")
+                                    .updateExistingRoutingEntry("old-ml-node-with-room", started())
+                            )
                             .build()
                     )
             )
@@ -327,6 +339,8 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
     }
 
     public void testShouldAllocateModels() {
+        String model1 = "model-1";
+        String model2 = "model-2";
         String mlNode1 = "ml-node-with-room";
         String mlNode2 = "new-ml-node-with-room";
         DiscoveryNode mlNode1Node = buildNode(mlNode1, true, ByteSizeValue.ofGb(4).getBytes());
@@ -352,7 +366,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -397,7 +413,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -407,7 +425,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -427,7 +447,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -437,7 +459,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -457,7 +481,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -467,7 +493,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -488,8 +516,10 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
                                     TrainedModelAllocationMetadata.Builder.empty()
-                                        .addNewAllocation(newParams(mlNode1, 100))
-                                        .setAllocationToStopping(mlNode1)
+                                        .addNewAllocation(
+                                            model1,
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).stopAllocation()
+                                        )
                                         .build()
                                 )
                                 .build()
@@ -500,7 +530,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -520,7 +552,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(mlNode2))
                                 .build()
@@ -531,7 +565,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(mlNode1, 100)).build()
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(model1, TrainedModelAllocation.Builder.empty(newParams(model1, 100)))
+                                        .build()
                                 )
                                 .build()
                         )
@@ -551,11 +587,17 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 100))
-                                        .addNode("model-1", mlNode1)
-                                        .addNewAllocation(newParams("model-2", 100))
-                                        .addNode("model-2", mlNode1)
-                                        .addNode("model-2", mlNode2)
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(
+                                            model1,
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1)
+                                        )
+                                        .addNewAllocation(
+                                            model2,
+                                            TrainedModelAllocation.Builder.empty(newParams("model-2", 100))
+                                                .addNewRoutingEntry(mlNode1)
+                                                .addNewRoutingEntry(mlNode2)
+                                        )
                                         .build()
                                 )
                                 .build()
@@ -566,11 +608,17 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                             Metadata.builder()
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
-                                    TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams("model-1", 100))
-                                        .addNode("model-1", mlNode1)
-                                        .addNewAllocation(newParams("model-2", 100))
-                                        .addNode("model-2", mlNode1)
-                                        .addNode("model-2", mlNode2)
+                                    TrainedModelAllocationMetadata.Builder.empty()
+                                        .addNewAllocation(
+                                            model1,
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1)
+                                        )
+                                        .addNewAllocation(
+                                            model2,
+                                            TrainedModelAllocation.Builder.empty(newParams("model-2", 100))
+                                                .addNewRoutingEntry(mlNode1)
+                                                .addNewRoutingEntry(mlNode2)
+                                        )
                                         .build()
                                 )
                                 .build()
@@ -592,12 +640,17 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
                                     TrainedModelAllocationMetadata.Builder.empty()
-                                        .addNewAllocation(newParams("model-1", 100))
-                                        .addNode("model-1", mlNode1)
-                                        .addNewAllocation(newParams("model-2", 100))
-                                        .addNode("model-2", mlNode1)
-                                        .addNode("model-2", mlNode2)
-                                        .setAllocationToStopping("model-2")
+                                        .addNewAllocation(
+                                            model1,
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1)
+                                        )
+                                        .addNewAllocation(
+                                            model2,
+                                            TrainedModelAllocation.Builder.empty(newParams("model-2", 100))
+                                                .addNewRoutingEntry(mlNode1)
+                                                .addNewRoutingEntry(mlNode2)
+                                                .stopAllocation()
+                                        )
                                         .build()
                                 )
                                 .build()
@@ -609,11 +662,16 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                                 .putCustom(
                                     TrainedModelAllocationMetadata.NAME,
                                     TrainedModelAllocationMetadata.Builder.empty()
-                                        .addNewAllocation(newParams("model-1", 100))
-                                        .addNode("model-1", mlNode1)
-                                        .addNewAllocation(newParams("model-2", 100))
-                                        .addNode("model-2", mlNode1)
-                                        .addNode("model-2", mlNode2)
+                                        .addNewAllocation(
+                                            model1,
+                                            TrainedModelAllocation.Builder.empty(newParams(model1, 100)).addNewRoutingEntry(mlNode1)
+                                        )
+                                        .addNewAllocation(
+                                            model2,
+                                            TrainedModelAllocation.Builder.empty(newParams("model-2", 100))
+                                                .addNewRoutingEntry(mlNode1)
+                                                .addNewRoutingEntry(mlNode2)
+                                        )
                                         .build()
                                 )
                                 .build()
@@ -641,7 +699,9 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                 Metadata.builder()
                     .putCustom(
                         TrainedModelAllocationMetadata.NAME,
-                        TrainedModelAllocationMetadata.Builder.empty().addNewAllocation(newParams(modelId, randomNonNegativeLong())).build()
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(modelId, TrainedModelAllocation.Builder.empty(newParams(modelId, randomNonNegativeLong())))
+                            .build()
                     )
                     .build()
             )
@@ -667,7 +727,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         }
         TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(original);
         for (String modelId : tempMetadata.modelAllocations().keySet()) {
-            builder.setAllocationToStopping(modelId);
+            builder.getAllocation(modelId).stopAllocation();
         }
         TrainedModelAllocationMetadata metadataWithStopping = builder.build();
         ClusterState originalWithStoppingAllocations = ClusterState.builder(original)
@@ -704,6 +764,10 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         );
     }
 
+    private static RoutingStateAndReason started() {
+        return new RoutingStateAndReason(RoutingState.STARTED, "");
+    }
+
     private static DiscoveryNode buildOldNode(String name, boolean isML, long nativeMemory) {
         return buildNode(name, isML, nativeMemory, Version.V_7_15_0);
     }

+ 3 - 33
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java

@@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
-import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReasonTests;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocationTests;
 
@@ -58,45 +57,15 @@ public class TrainedModelAllocationMetadataTests extends AbstractSerializingTest
         assertThat(builder.isChanged(), is(false));
 
         assertUnchanged(builder, b -> b.removeAllocation(newModel));
-        assertUnchanged(builder, b -> b.updateAllocation(newModel, "foo", RoutingStateAndReasonTests.randomInstance()));
-        assertUnchanged(builder, b -> b.removeNode(newModel, "foo"));
 
-        if (original.modelAllocations().isEmpty() == false) {
-            String randomExistingModel = randomFrom(original.modelAllocations().keySet().toArray(String[]::new));
-            assertUnchanged(builder, b -> b.addNewAllocation(randomParams(randomExistingModel)));
-        }
-
-        builder.addNewAllocation(new StartTrainedModelDeploymentAction.TaskParams(newModel, randomNonNegativeLong()));
-        assertThat(builder.isChanged(), is(true));
-    }
-
-    public void testBuilderChanged_WhenAddingRemovingNodeFromModel() {
-        String newModel = "foo_model";
-        TrainedModelAllocationMetadata original = TrainedModelAllocationMetadata.Builder.fromMetadata(randomInstance())
-            .addNewAllocation(randomParams(newModel))
-            .build();
-        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.Builder.fromMetadata(original);
-        assertThat(builder.isChanged(), is(false));
-
-        String newNode = "foo";
-        if (randomBoolean()) {
-            builder.addNode(newModel, newNode);
-        } else {
-            builder.addFailedNode(newModel, newNode, "failure");
-        }
-        assertThat(builder.isChanged(), is(true));
-
-        builder = TrainedModelAllocationMetadata.Builder.fromMetadata(builder.build());
-        assertThat(builder.isChanged(), is(false));
-
-        builder.removeNode(newModel, newNode);
+        builder.addNewAllocation(newModel, TrainedModelAllocation.Builder.empty(randomParams(newModel)));
         assertThat(builder.isChanged(), is(true));
     }
 
     public void testIsAllocated() {
         String allocatedModelId = "test_model_id";
         TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.Builder.empty()
-            .addNewAllocation(randomParams(allocatedModelId))
+            .addNewAllocation(allocatedModelId, TrainedModelAllocation.Builder.empty(randomParams(allocatedModelId)))
             .build();
         assertThat(metadata.isAllocated(allocatedModelId), is(true));
         assertThat(metadata.isAllocated("unknown_model_id"), is(false));
@@ -114,4 +83,5 @@ public class TrainedModelAllocationMetadataTests extends AbstractSerializingTest
     private static StartTrainedModelDeploymentAction.TaskParams randomParams(String modelId) {
         return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong());
     }
+
 }

+ 41 - 20
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeServiceTests.java

@@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
 import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
 import org.junit.After;
@@ -244,12 +245,18 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
                         .putCustom(
                             TrainedModelAllocationMetadata.NAME,
                             TrainedModelAllocationMetadata.Builder.empty()
-                                .addNewAllocation(newParams(modelOne))
-                                .addNode(modelOne, NODE_ID)
-                                .addNewAllocation(newParams(modelTwo))
-                                .addNode(modelTwo, NODE_ID)
-                                .addNewAllocation(newParams(notUsedModel))
-                                .addNode(notUsedModel, "some-other-node")
+                                .addNewAllocation(
+                                    modelOne,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID)
+                                )
+                                .addNewAllocation(
+                                    modelTwo,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelTwo)).addNewRoutingEntry(NODE_ID)
+                                )
+                                .addNewAllocation(
+                                    notUsedModel,
+                                    TrainedModelAllocation.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node")
+                                )
                                 .build()
                         )
                         .putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(true).build())
@@ -291,12 +298,18 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
                         .putCustom(
                             TrainedModelAllocationMetadata.NAME,
                             TrainedModelAllocationMetadata.Builder.empty()
-                                .addNewAllocation(newParams(modelOne))
-                                .addNode(modelOne, NODE_ID)
-                                .addNewAllocation(newParams(modelTwo))
-                                .addNode(modelTwo, NODE_ID)
-                                .addNewAllocation(newParams(notUsedModel))
-                                .addNode(notUsedModel, "some-other-node")
+                                .addNewAllocation(
+                                    modelOne,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID)
+                                )
+                                .addNewAllocation(
+                                    modelTwo,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelTwo)).addNewRoutingEntry(NODE_ID)
+                                )
+                                .addNewAllocation(
+                                    notUsedModel,
+                                    TrainedModelAllocation.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node")
+                                )
                                 .build()
                         )
                         .build()
@@ -316,12 +329,18 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
                         .putCustom(
                             TrainedModelAllocationMetadata.NAME,
                             TrainedModelAllocationMetadata.Builder.empty()
-                                .addNewAllocation(newParams(modelOne))
-                                .addNode(modelOne, NODE_ID)
-                                .addNewAllocation(newParams(modelTwo))
-                                .addNode(modelTwo, "some-other-node")
-                                .addNewAllocation(newParams(notUsedModel))
-                                .addNode(notUsedModel, "some-other-node")
+                                .addNewAllocation(
+                                    modelOne,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID)
+                                )
+                                .addNewAllocation(
+                                    modelTwo,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelTwo)).addNewRoutingEntry("some-other-node")
+                                )
+                                .addNewAllocation(
+                                    notUsedModel,
+                                    TrainedModelAllocation.Builder.empty(newParams(notUsedModel)).addNewRoutingEntry("some-other-node")
+                                )
                                 .build()
                         )
                         .build()
@@ -359,8 +378,10 @@ public class TrainedModelAllocationNodeServiceTests extends ESTestCase {
                         .putCustom(
                             TrainedModelAllocationMetadata.NAME,
                             TrainedModelAllocationMetadata.Builder.empty()
-                                .addNewAllocation(newParams(modelOne))
-                                .addNode(modelOne, NODE_ID)
+                                .addNewAllocation(
+                                    modelOne,
+                                    TrainedModelAllocation.Builder.empty(newParams(modelOne)).addNewRoutingEntry(NODE_ID)
+                                )
                                 .build()
                         )
                         .build()

+ 13 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
@@ -88,18 +89,19 @@ public class NodeLoadDetectorTests extends ESTestCase {
                             TrainedModelAllocationMetadata.NAME,
                             TrainedModelAllocationMetadata.Builder.empty()
                                 .addNewAllocation(
-                                    new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT)
-                                )
-                                .addNode("model1", "_node_id4")
-                                .addFailedNode("model1", "_node_id2", "test")
-                                .addNode("model1", "_node_id1")
-                                .updateAllocation(
                                     "model1",
-                                    "_node_id1",
-                                    new RoutingStateAndReason(
-                                        randomFrom(RoutingState.STOPPED, RoutingState.FAILED),
-                                        "test"
-                                    )
+                                    TrainedModelAllocation.Builder
+                                        .empty(new StartTrainedModelDeploymentAction.TaskParams("model1", MODEL_MEMORY_REQUIREMENT))
+                                        .addNewRoutingEntry("_node_id4")
+                                        .addNewFailedRoutingEntry("_node_id2", "test")
+                                        .addNewRoutingEntry("_node_id1")
+                                        .updateExistingRoutingEntry(
+                                            "_node_id1",
+                                            new RoutingStateAndReason(
+                                                randomFrom(RoutingState.STOPPED, RoutingState.FAILED),
+                                                "test"
+                                            )
+                                        )
                                 )
                                 .build()
                         )