Browse Source

[ML] Distribute trained model allocations across availability zones (#89822)

When a model deployment is started with 2 or more allocations
and availability zones are present we should distribute the allocations
across availability zones so that there is resilience.

This commit adds a `ZoneAwareAssignmentPlanner` that attempts to evenly
distribute the allocations of a deployment across the available zones.
Dimitris Athanasiou cách đây 3 năm
mục cha
commit
c733eb8908

+ 5 - 0
docs/changelog/89822.yaml

@@ -0,0 +1,5 @@
+pr: 89822
+summary: Distribute trained model allocations across availability zones
+area: Machine Learning
+type: enhancement
+issues: []

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -1119,7 +1119,8 @@ public class MachineLearning extends Plugin
                 clusterService,
                 threadPool,
                 new NodeLoadDetector(memoryTracker),
-                new SystemAuditor(client, clusterService)
+                new SystemAuditor(client, clusterService),
+                nodeAvailabilityZoneMapper
             )
         );
 

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java

@@ -42,6 +42,7 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 import org.elasticsearch.xpack.ml.notifications.SystemAuditor;
@@ -69,6 +70,7 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
     private final ThreadPool threadPool;
     private final NodeLoadDetector nodeLoadDetector;
     private final SystemAuditor systemAuditor;
+    private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
     private volatile int maxMemoryPercentage;
     private volatile boolean useAuto;
     private volatile int maxOpenJobs;
@@ -78,12 +80,14 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
         ClusterService clusterService,
         ThreadPool threadPool,
         NodeLoadDetector nodeLoadDetector,
-        SystemAuditor systemAuditor
+        SystemAuditor systemAuditor,
+        NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper
     ) {
         this.clusterService = Objects.requireNonNull(clusterService);
         this.threadPool = Objects.requireNonNull(threadPool);
         this.nodeLoadDetector = Objects.requireNonNull(nodeLoadDetector);
         this.systemAuditor = Objects.requireNonNull(systemAuditor);
+        this.nodeAvailabilityZoneMapper = Objects.requireNonNull(nodeAvailabilityZoneMapper);
         this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
         this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
         this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
@@ -462,6 +466,7 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
         TrainedModelAssignmentRebalancer rebalancer = new TrainedModelAssignmentRebalancer(
             TrainedModelAssignmentMetadata.fromState(currentState),
             nodeLoads,
+            nodeAvailabilityZoneMapper,
             modelToAdd
         );
         return rebalancer.rebalance();

+ 44 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java

@@ -19,11 +19,13 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
 import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
 import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
-import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.ZoneAwareAssignmentPlanner;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -41,15 +43,18 @@ class TrainedModelAssignmentRebalancer {
 
     private final TrainedModelAssignmentMetadata currentMetadata;
     private final Map<DiscoveryNode, NodeLoad> nodeLoads;
+    private final NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
     private final Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd;
 
     TrainedModelAssignmentRebalancer(
         TrainedModelAssignmentMetadata currentMetadata,
         Map<DiscoveryNode, NodeLoad> nodeLoads,
+        NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper,
         Optional<StartTrainedModelDeploymentAction.TaskParams> modelToAdd
     ) {
         this.currentMetadata = Objects.requireNonNull(currentMetadata);
         this.nodeLoads = Objects.requireNonNull(nodeLoads);
+        this.nodeAvailabilityZoneMapper = Objects.requireNonNull(nodeAvailabilityZoneMapper);
         this.modelToAdd = Objects.requireNonNull(modelToAdd);
     }
 
@@ -78,24 +83,16 @@ class TrainedModelAssignmentRebalancer {
     }
 
     AssignmentPlan computeAssignmentPlan() {
-        List<AssignmentPlan.Node> planNodes = nodeLoads.entrySet()
-            .stream()
-            .filter(e -> Strings.isNullOrEmpty(e.getValue().getError()))
-            .map(
-                e -> new AssignmentPlan.Node(
-                    e.getKey().getId(),
-                    // We subtract native inference memory as the planner expects available memory for
-                    // native inference including current assignments.
-                    getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(e.getValue()),
-                    getNodeAllocatedProcessors(e.getKey()).orElse(0)
-                )
-            )
-            .toList();
+        final Map<List<String>, List<AssignmentPlan.Node>> nodesByZone = createNodesByZoneMap();
 
         final List<AssignmentPlan.Model> planModels = new ArrayList<>(
             currentMetadata.modelAssignments().size() + (modelToAdd.isPresent() ? 1 : 0)
         );
-        final Set<String> assignableNodeIds = planNodes.stream().map(AssignmentPlan.Node::id).collect(Collectors.toSet());
+        final Set<String> assignableNodeIds = nodesByZone.values()
+            .stream()
+            .flatMap(List::stream)
+            .map(AssignmentPlan.Node::id)
+            .collect(Collectors.toSet());
         currentMetadata.modelAssignments().values().stream().map(assignment -> {
             Map<String, Integer> currentAssignments = assignment.getNodeRoutingTable()
                 .entrySet()
@@ -127,7 +124,38 @@ class TrainedModelAssignmentRebalancer {
                 )
             )
         );
-        return new AssignmentPlanner(planNodes, planModels).computePlan();
+        return new ZoneAwareAssignmentPlanner(nodesByZone, planModels).computePlan();
+    }
+
+    private Map<List<String>, List<AssignmentPlan.Node>> createNodesByZoneMap() {
+        Map<List<String>, Collection<DiscoveryNode>> mlNodesByZone = nodeAvailabilityZoneMapper.getMlNodesByAvailabilityZone();
+        return mlNodesByZone.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> {
+            Collection<DiscoveryNode> discoveryNodes = e.getValue();
+            List<AssignmentPlan.Node> nodes = new ArrayList<>();
+            for (DiscoveryNode discoveryNode : discoveryNodes) {
+                if (nodeLoads.containsKey(discoveryNode)) {
+                    NodeLoad load = nodeLoads.get(discoveryNode);
+                    if (Strings.isNullOrEmpty(load.getError())) {
+                        nodes.add(
+                            new AssignmentPlan.Node(
+                                discoveryNode.getId(),
+                                // We subtract native inference memory as the planner expects available memory for
+                                // native inference including current assignments.
+                                getNodeFreeMemoryExcludingPerNodeOverheadAndNativeInference(load),
+                                getNodeAllocatedProcessors(discoveryNode).orElse(0)
+                            )
+                        );
+                    } else {
+                        logger.warn(
+                            format("ignoring node [%s] as detecting its load failed with [%s]", discoveryNode.getId(), load.getError())
+                        );
+                    }
+                } else {
+                    logger.warn(format("ignoring node [%s] as no load could be detected", discoveryNode.getId()));
+                }
+            }
+            return nodes;
+        }));
     }
 
     private static OptionalInt getNodeAllocatedProcessors(DiscoveryNode node) {

+ 3 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java

@@ -42,14 +42,14 @@ import static org.elasticsearch.core.Strings.format;
  * attempt to find a solution that provides at least one allocation to
  * previously assigned models.
  */
-public class AssignmentPlanner {
+class AssignmentPlanner {
 
     private static final Logger logger = LogManager.getLogger(AssignmentPlanner.class);
 
     private final List<Node> nodes;
     private final List<Model> models;
 
-    public AssignmentPlanner(List<Node> nodes, List<Model> models) {
+    AssignmentPlanner(List<Node> nodes, List<Model> models) {
         this.nodes = nodes.stream().sorted(Comparator.comparing(Node::id)).toList();
         this.models = models.stream().sorted(Comparator.comparing(Model::id)).toList();
     }
@@ -58,7 +58,7 @@ public class AssignmentPlanner {
         return computePlan(true);
     }
 
-    private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
+    public AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
         logger.debug(() -> format("Computing plan for nodes = %s; models = %s", nodes, models));
 
         AssignmentPlan bestPlan;

+ 211 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java

@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.assignment.planning;
+
+import org.elasticsearch.logging.LogManager;
+import org.elasticsearch.logging.Logger;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Model;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.TreeMap;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static org.elasticsearch.core.Strings.format;
+
+/**
+ * An assignment planner that is aware of availability zones and tries to distribute
+ * model allocations evenly across zones in order to achieve better resilience in the
+ * case nodes in a particular zone become unavailable.
+ */
+public class ZoneAwareAssignmentPlanner {
+
+    private static final Logger logger = LogManager.getLogger(ZoneAwareAssignmentPlanner.class);
+
+    /**
+     * A map from zone attributes to node.
+     */
+    private final Map<List<String>, List<Node>> nodesByZone;
+
+    private final List<Model> models;
+
+    public ZoneAwareAssignmentPlanner(Map<List<String>, List<Node>> nodesByZone, List<Model> models) {
+        this.nodesByZone = sortByZone(Objects.requireNonNull(nodesByZone));
+        this.models = Objects.requireNonNull(models);
+    }
+
+    private static Map<List<String>, List<Node>> sortByZone(Map<List<String>, List<Node>> nodesByZone) {
+        Map<List<String>, List<Node>> sortedByZone = new TreeMap<>(
+            Comparator.comparing(zoneAttributes -> zoneAttributes.stream().collect(Collectors.joining()))
+        );
+        sortedByZone.putAll(nodesByZone);
+        return sortedByZone;
+    }
+
+    public AssignmentPlan computePlan() {
+        // There is only one zone; we can optimize and compute a plan directly.
+        if (nodesByZone.size() == 1) {
+            return new AssignmentPlanner(nodesByZone.values().iterator().next(), models).computePlan(true);
+        }
+
+        // First we try to compute a plan without forcing assigning previously assigned models as this may
+        // produce better plans. If that plan has failed to assign previously assigned models we then try
+        // again this time prioritizing assigning such models.
+        AssignmentPlan plan = computePlan(false);
+        if (plan.arePreviouslyAssignedModelsAssigned() == false) {
+            plan = computePlan(true);
+        }
+        return plan;
+    }
+
+    private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
+        logger.debug(
+            () -> format(
+                "computing plan%s trying to assign previously assigned models",
+                tryAssigningPreviouslyAssignedModels ? "" : " without"
+            )
+        );
+        // The idea here is that we solve per zone trying to distribute allocations evenly.
+        // After computing a plan for each zone it is possible that there are still unsatisfied allocations
+        // that can be allocated, so we solve a final time across all zones preserving the allocations we
+        // allocated on the first per zone assignment plans.
+
+        int remainingZones = nodesByZone.size();
+        Map<String, Integer> modelIdToRemainingAllocations = models.stream().collect(Collectors.toMap(Model::id, Model::allocations));
+        List<AssignmentPlan> plans = new ArrayList<>();
+        for (var zoneToNodes : nodesByZone.entrySet()) {
+            logger.debug(() -> format("computing plan for availability zone %s", zoneToNodes.getKey()));
+            AssignmentPlan plan = computeZonePlan(
+                zoneToNodes.getValue(),
+                modelIdToRemainingAllocations,
+                remainingZones,
+                tryAssigningPreviouslyAssignedModels
+            );
+            plan.models()
+                .forEach(
+                    m -> modelIdToRemainingAllocations.computeIfPresent(
+                        m.id(),
+                        (modelId, remainingAllocations) -> remainingAllocations - plan.totalAllocations(m)
+                    )
+                );
+            plans.add(plan);
+            remainingZones--;
+        }
+        AssignmentPlan plan = computePlanAcrossAllNodes(plans);
+        logger.debug(() -> "Zone aware plan =\n" + plan.prettyPrint());
+        return plan;
+    }
+
+    private AssignmentPlan computeZonePlan(
+        List<Node> nodes,
+        Map<String, Integer> modelIdToRemainingAllocations,
+        int remainingZones,
+        boolean tryAssigningPreviouslyAssignedModels
+    ) {
+        Map<String, Integer> modelIdToTargetAllocations = modelIdToRemainingAllocations.entrySet()
+            .stream()
+            .filter(e -> e.getValue() > 0)
+            .collect(Collectors.toMap(e -> e.getKey(), e -> (e.getValue() - 1) / remainingZones + 1));
+
+        List<Model> modifiedModels = models.stream()
+            .filter(m -> modelIdToTargetAllocations.getOrDefault(m.id(), 0) > 0)
+            .map(
+                m -> new Model(
+                    m.id(),
+                    m.memoryBytes(),
+                    modelIdToTargetAllocations.get(m.id()),
+                    m.threadsPerAllocation(),
+                    m.currentAllocationsByNodeId(),
+                    // Only force assigning at least once previously assigned models that have not had any allocation yet
+                    (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations())
+                        ? m.maxAssignedAllocations()
+                        : 0
+                )
+            )
+            .toList();
+        return new AssignmentPlanner(nodes, modifiedModels).computePlan(tryAssigningPreviouslyAssignedModels);
+    }
+
+    private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
+        logger.debug(() -> "computing plan across all nodes");
+        final List<Node> allNodes = new ArrayList<>();
+        nodesByZone.values().forEach(allNodes::addAll);
+
+        Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = mergeAllocationsByNodeIdByModelId(plans);
+
+        List<Model> modelsAccountingPlans = models.stream()
+            .map(
+                m -> new Model(
+                    m.id(),
+                    m.memoryBytes(),
+                    m.allocations(),
+                    m.threadsPerAllocation(),
+                    allocationsByNodeIdByModelId.get(m.id()),
+                    m.maxAssignedAllocations()
+                )
+            )
+            .toList();
+
+        PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(allNodes, modelsAccountingPlans);
+        List<Node> planNodes = preserveAllAllocations.nodesPreservingAllocations();
+        List<Model> planModels = preserveAllAllocations.modelsPreservingAllocations();
+        AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planModels).solvePlan(false);
+        plan = preserveAllAllocations.mergePreservedAllocations(plan);
+        return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
+    }
+
+    private AssignmentPlan swapOriginalModelsInPlan(AssignmentPlan plan, List<Node> allNodes, List<Model> planModels) {
+        final Map<String, Model> originalModelById = models.stream().collect(Collectors.toMap(Model::id, Function.identity()));
+        final Map<String, Node> originalNodeById = allNodes.stream().collect(Collectors.toMap(Node::id, Function.identity()));
+        AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, models);
+        for (Model m : planModels) {
+            Optional<Map<Node, Integer>> nodeAssignments = plan.assignments(m);
+            if (nodeAssignments.isPresent()) {
+                nodeAssignments.get()
+                    .entrySet()
+                    .forEach(
+                        e -> planBuilder.assignModelToNode(
+                            originalModelById.get(m.id()),
+                            originalNodeById.get(e.getKey().id()),
+                            e.getValue()
+                        )
+                    );
+            }
+        }
+        return planBuilder.build();
+    }
+
+    private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByModelId(List<AssignmentPlan> plans) {
+        Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = new HashMap<>();
+        models.forEach(m -> allocationsByNodeIdByModelId.put(m.id(), new HashMap<>()));
+        for (AssignmentPlan plan : plans) {
+            for (Model m : plan.models()) {
+                Map<String, Integer> nodeIdToAllocations = allocationsByNodeIdByModelId.get(m.id());
+                Optional<Map<Node, Integer>> assignments = plan.assignments(m);
+                if (assignments.isPresent()) {
+                    for (Map.Entry<Node, Integer> nodeAssignments : assignments.get().entrySet()) {
+                        nodeIdToAllocations.compute(
+                            nodeAssignments.getKey().id(),
+                            (nodeId, existingAllocations) -> existingAllocations == null
+                                ? nodeAssignments.getValue()
+                                : existingAllocations + nodeAssignments.getValue()
+                        );
+                    }
+                }
+            }
+        }
+        return allocationsByNodeIdByModelId;
+    }
+}

+ 37 - 11
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.collect.MapBuilder;
 import org.elasticsearch.common.settings.ClusterSettings;
@@ -45,6 +46,7 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReaso
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests;
 import org.elasticsearch.xpack.ml.notifications.SystemAuditor;
@@ -79,6 +81,7 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
     private ThreadPool threadPool;
     private NodeLoadDetector nodeLoadDetector;
     private SystemAuditor systemAuditor;
+    private NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper;
 
     @Before
     public void setupObjects() {
@@ -265,16 +268,22 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
     }
 
     public void testCreateAssignment() throws Exception {
+        Settings settings = Settings.EMPTY;
+        ClusterSettings clusterSettings = new ClusterSettings(
+            settings,
+            Set.of(AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING)
+        );
+        DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
+            .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2))
+            .add(buildNode("ml-node-without-room", true, 1000L, 2))
+            .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes(), 2))
+            .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes(), 2))
+            .add(buildOldNode("old-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2))
+            .build();
+        nodeAvailabilityZoneMapper = new NodeAvailabilityZoneMapper(settings, clusterSettings, discoveryNodes);
+
         ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAssignment"))
-            .nodes(
-                DiscoveryNodes.builder()
-                    .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2))
-                    .add(buildNode("ml-node-without-room", true, 1000L, 2))
-                    .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes(), 2))
-                    .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes(), 2))
-                    .add(buildOldNode("old-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 2))
-                    .build()
-            )
+            .nodes(discoveryNodes)
             .metadata(Metadata.builder().putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down")))
             .build();
 
@@ -300,8 +309,18 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
     }
 
     public void testCreateAssignmentWhileResetModeIsTrue() throws InterruptedException {
+        Settings settings = Settings.EMPTY;
+        ClusterSettings clusterSettings = new ClusterSettings(
+            settings,
+            Set.of(AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING)
+        );
+        DiscoveryNodes discoveryNodes = DiscoveryNodes.builder()
+            .add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 8))
+            .build();
+        nodeAvailabilityZoneMapper = new NodeAvailabilityZoneMapper(settings, clusterSettings, discoveryNodes);
+
         ClusterState currentState = ClusterState.builder(new ClusterName("testCreateAssignment"))
-            .nodes(DiscoveryNodes.builder().add(buildNode("ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes(), 8)).build())
+            .nodes(discoveryNodes)
             .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isResetMode(true).build()))
             .build();
         when(clusterService.state()).thenReturn(currentState);
@@ -1399,7 +1418,14 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
     }
 
     private TrainedModelAssignmentClusterService createClusterService() {
-        return new TrainedModelAssignmentClusterService(Settings.EMPTY, clusterService, threadPool, nodeLoadDetector, systemAuditor);
+        return new TrainedModelAssignmentClusterService(
+            Settings.EMPTY,
+            clusterService,
+            threadPool,
+            nodeLoadDetector,
+            systemAuditor,
+            nodeAvailabilityZoneMapper
+        );
     }
 
     private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory, int allocatedProcessors) {

+ 157 - 68
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java

@@ -11,7 +11,11 @@ import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.node.DiscoveryNodes;
+import org.elasticsearch.cluster.routing.allocation.decider.AwarenessAllocationDecider;
 import org.elasticsearch.common.collect.MapBuilder;
+import org.elasticsearch.common.settings.ClusterSettings;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -20,12 +24,14 @@ import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
 import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.autoscaling.NodeAvailabilityZoneMapper;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 
 import static org.hamcrest.Matchers.aMapWithSize;
 import static org.hamcrest.Matchers.anEmptyMap;
@@ -37,10 +43,17 @@ import static org.hamcrest.Matchers.notNullValue;
 
 public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
 
+    private NodeAvailabilityZoneMapper nodeAvailabilityZoneMapper = new NodeAvailabilityZoneMapper(
+        Settings.EMPTY,
+        new ClusterSettings(Settings.EMPTY, Set.of(AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING)),
+        DiscoveryNodes.EMPTY_NODES
+    );
+
     public void testRebalance_GivenNoAssignments() throws Exception {
         TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
             TrainedModelAssignmentMetadata.Builder.empty().build(),
             Map.of(),
+            nodeAvailabilityZoneMapper,
             Optional.empty()
         ).rebalance().build();
         assertThat(result.modelAssignments().isEmpty(), is(true));
@@ -68,14 +81,22 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         nodeLoads.put(buildNode("node-1", oneGbBytes, 4), NodeLoad.builder("node-1").setMaxMemory(oneGbBytes).build());
         nodeLoads.put(buildNode("node-2", oneGbBytes, 4), NodeLoad.builder("node-2").setMaxMemory(oneGbBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(currentMetadata, equalTo(result));
     }
 
     public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEntry_ShouldRebalance() throws Exception {
+        long oneGbBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", oneGbBytes, 4);
+        DiscoveryNode node2 = buildNode("node-2", oneGbBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).add(node2).build());
+
         String modelId1 = "model-1";
         String modelId2 = "model-2";
         StartTrainedModelDeploymentAction.TaskParams taskParams1 = newParams(modelId1, 1024L, 1, 2);
@@ -93,13 +114,15 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long oneGbBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", oneGbBytes, 4), NodeLoad.builder("node-1").setMaxMemory(oneGbBytes).build());
-        nodeLoads.put(buildNode("node-2", oneGbBytes, 4), NodeLoad.builder("node-2").setMaxMemory(oneGbBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(oneGbBytes).build());
+        nodeLoads.put(node2, NodeLoad.builder("node-2").setMaxMemory(oneGbBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(2)));
 
@@ -122,7 +145,8 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             .build();
         expectThrows(
             ResourceAlreadyExistsException.class,
-            () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Optional.of(taskParams)).rebalance()
+            () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), nodeAvailabilityZoneMapper, Optional.of(taskParams))
+                .rebalance()
         );
     }
 
@@ -131,9 +155,12 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 1);
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            Map.of(),
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -144,16 +171,23 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node = buildNode("node-1", nodeMemoryBytes, 3);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node).build());
+
         String modelId = "model-to-add";
         StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 4);
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 3), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        nodeLoads.put(node, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -177,9 +211,12 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
         nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 3), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -203,9 +240,12 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).setError("error detecting load").build()
         );
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -219,22 +259,23 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception {
+        DiscoveryNode node1 = buildNode("node-1", ByteSizeValue.ofGb(1).getBytes(), 8);
+        DiscoveryNode node2 = buildNode("node-2", ByteSizeValue.ofGb(10).getBytes(), 3);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).add(node2).build());
+
         String modelId = "model-to-add";
         StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, ByteSizeValue.ofGb(2).getBytes(), 1, 4);
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        nodeLoads.put(
-            buildNode("node-1", ByteSizeValue.ofGb(1).getBytes(), 8),
-            NodeLoad.builder("node-1").setMaxMemory(ByteSizeValue.ofGb(1).getBytes()).build()
-        );
-        nodeLoads.put(
-            buildNode("node-2", ByteSizeValue.ofGb(10).getBytes(), 3),
-            NodeLoad.builder("node-2").setMaxMemory(ByteSizeValue.ofGb(10).getBytes()).build()
-        );
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(ByteSizeValue.ofGb(1).getBytes()).build());
+        nodeLoads.put(node2, NodeLoad.builder("node-2").setMaxMemory(ByteSizeValue.ofGb(10).getBytes()).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -252,16 +293,22 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).build());
+
         String modelId = "model-to-add";
         StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelId, 1024L, 1, 1);
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty().build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         TrainedModelAssignment assignment = result.getModelAssignment(modelId);
         assertThat(assignment, is(notNullValue()));
@@ -275,6 +322,11 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_AllFit() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 4);
+        DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).add(node2).build());
+
         String modelToAddId = "model-to-add";
         String previousModelId = "previous-model";
         StartTrainedModelDeploymentAction.TaskParams taskParams = newParams(modelToAddId, 1024L, 1, 2);
@@ -287,13 +339,15 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
-        nodeLoads.put(buildNode("node-2", nodeMemoryBytes, 4), NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node2, NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.of(taskParams))
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.of(taskParams)
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(2)));
 
@@ -326,6 +380,12 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 4);
+        DiscoveryNode node2 = buildNode("node-2", nodeMemoryBytes, 4);
+        DiscoveryNode node3 = buildNode("node-3", nodeMemoryBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).add(node2).add(node3).build());
+
         String previousModel1Id = "previous-model-1";
         String previousModel2Id = "previous-model-2";
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty()
@@ -342,14 +402,16 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
-        nodeLoads.put(buildNode("node-2", nodeMemoryBytes, 4), NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build());
-        nodeLoads.put(buildNode("node-3", nodeMemoryBytes, 4), NodeLoad.builder("node-3").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node2, NodeLoad.builder("node-2").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node3, NodeLoad.builder("node-3").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(2)));
 
@@ -386,6 +448,10 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNodeNotLargeEnough() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).build());
+
         String previousModel1Id = "previous-model-1";
         String previousModel2Id = "previous-model-2";
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty()
@@ -402,12 +468,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(2)));
 
@@ -450,6 +518,10 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNodeLargeEnough() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 7);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).build());
+
         String previousModel1Id = "previous-model-1";
         String previousModel2Id = "previous-model-2";
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty()
@@ -466,12 +538,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 7), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(2)));
 
@@ -500,6 +574,10 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
     }
 
     public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exception {
+        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
+        DiscoveryNode node1 = buildNode("node-1", nodeMemoryBytes, 4);
+        givenNodeAvailabilityZoneMapper(DiscoveryNodes.builder().add(node1).build());
+
         String modelId = "model-1";
         TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty()
             .addNewAssignment(
@@ -509,12 +587,14 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
             )
             .build();
         Map<DiscoveryNode, NodeLoad> nodeLoads = new HashMap<>();
-        long nodeMemoryBytes = ByteSizeValue.ofGb(1).getBytes();
-        nodeLoads.put(buildNode("node-1", nodeMemoryBytes, 4), NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
+        nodeLoads.put(node1, NodeLoad.builder("node-1").setMaxMemory(nodeMemoryBytes).build());
 
-        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(currentMetadata, nodeLoads, Optional.empty())
-            .rebalance()
-            .build();
+        TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer(
+            currentMetadata,
+            nodeLoads,
+            nodeAvailabilityZoneMapper,
+            Optional.empty()
+        ).rebalance().build();
 
         assertThat(result.modelAssignments(), is(aMapWithSize(1)));
 
@@ -529,6 +609,15 @@ public class TrainedModelAssignmentRebalancerTests extends ESTestCase {
         assertThat(assignment.getReason().isPresent(), is(false));
     }
 
+    private void givenNodeAvailabilityZoneMapper(DiscoveryNodes discoveryNodes) {
+        Settings settings = Settings.EMPTY;
+        ClusterSettings clusterSettings = new ClusterSettings(
+            settings,
+            Set.of(AwarenessAllocationDecider.CLUSTER_ROUTING_ALLOCATION_AWARENESS_ATTRIBUTE_SETTING)
+        );
+        nodeAvailabilityZoneMapper = new NodeAvailabilityZoneMapper(settings, clusterSettings, discoveryNodes);
+    }
+
     private static StartTrainedModelDeploymentAction.TaskParams newParams(
         String modelId,
         long modelSize,

+ 14 - 10
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java

@@ -287,7 +287,7 @@ public class AssignmentPlannerTests extends ESTestCase {
             int scale = randomIntBetween(0, 10);
             double load = randomDoubleBetween(0.1, 1.0, true);
             List<Node> nodes = randomNodes(scale);
-            List<Model> models = randomModels(scale, load, nodes);
+            List<Model> models = randomModels(scale, load);
             nodeSizes.add(nodes.size());
             modelSizes.add(models.size());
             logger.debug("Nodes = " + nodes.size() + "; Models = " + models.size());
@@ -324,7 +324,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         int scale = randomIntBetween(0, 10);
         double load = randomDoubleBetween(0.1, 1.0, true);
         List<Node> nodes = randomNodes(scale);
-        List<Model> models = randomModels(scale, load, nodes);
+        List<Model> models = randomModels(scale, load);
         AssignmentPlan originalPlan = new AssignmentPlanner(nodes, models).computePlan();
 
         List<Model> previousModelsPlusNew = new ArrayList<>(models.size() + 1);
@@ -453,7 +453,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         assertThat(assignmentPlan.getRemainingNodeCores("n_2"), equalTo(0));
     }
 
-    private static List<Model> createModelsFromPlan(AssignmentPlan plan) {
+    public static List<Model> createModelsFromPlan(AssignmentPlan plan) {
         List<Model> models = new ArrayList<>();
         for (Model m : plan.models()) {
             Optional<Map<Node, Integer>> assignments = plan.assignments(m);
@@ -479,7 +479,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         return models;
     }
 
-    private static Map<String, Map<String, Integer>> convertToIdIndexed(AssignmentPlan plan) {
+    public static Map<String, Map<String, Integer>> convertToIdIndexed(AssignmentPlan plan) {
         Map<String, Map<String, Integer>> result = new HashMap<>();
         for (Model m : plan.models()) {
             Optional<Map<Node, Integer>> assignments = plan.assignments(m);
@@ -492,14 +492,18 @@ public class AssignmentPlannerTests extends ESTestCase {
         return result;
     }
 
-    private static void assertModelFullyAssignedToNode(AssignmentPlan plan, Model m, Node n) {
+    public static void assertModelFullyAssignedToNode(AssignmentPlan plan, Model m, Node n) {
         Optional<Map<Node, Integer>> assignments = plan.assignments(m);
         assertThat(assignments.isPresent(), is(true));
         assertThat(assignments.get().size(), equalTo(1));
         assertThat(assignments.get().get(n), equalTo(m.allocations()));
     }
 
-    private List<Node> randomNodes(int scale) {
+    public static List<Node> randomNodes(int scale) {
+        return randomNodes(scale, "");
+    }
+
+    public static List<Node> randomNodes(int scale, String nodeIdPrefix) {
         Long[] memBytesPerCoreValues = {
             ByteSizeValue.ofGb(1).getBytes() / 2,
             ByteSizeValue.ofGb(1).getBytes(),
@@ -511,12 +515,12 @@ public class AssignmentPlannerTests extends ESTestCase {
         for (int i = 0; i < 1 + 3 * scale; i++) {
             int cores = randomIntBetween(2, 32);
             long memBytesPerCore = randomFrom(memBytesPerCoreValues);
-            nodes.add(new Node("n_" + i, cores * memBytesPerCore, cores));
+            nodes.add(new Node(nodeIdPrefix + "n_" + i, cores * memBytesPerCore, cores));
         }
         return nodes;
     }
 
-    private List<Model> randomModels(int scale, double load, List<Node> nodes) {
+    public static List<Model> randomModels(int scale, double load) {
         List<Model> models = new ArrayList<>();
         for (int i = 0; i < Math.max(2, Math.round(load * (1 + 8 * scale))); i++) {
             models.add(randomModel(String.valueOf(i)));
@@ -524,7 +528,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         return models;
     }
 
-    private static Model randomModel(String idSuffix) {
+    public static Model randomModel(String idSuffix) {
         int allocations = randomIntBetween(1, 32);
         return new Model(
             "m_" + idSuffix,
@@ -536,7 +540,7 @@ public class AssignmentPlannerTests extends ESTestCase {
         );
     }
 
-    private static void assertPreviousAssignmentsAreSatisfied(List<Model> models, AssignmentPlan assignmentPlan) {
+    public static void assertPreviousAssignmentsAreSatisfied(List<Model> models, AssignmentPlan assignmentPlan) {
         for (Model m : models.stream().filter(m -> m.currentAllocationsByNodeId().isEmpty() == false).toList()) {
             Map<Node, Integer> assignments = assignmentPlan.assignments(m).get();
             Set<String> assignedNodeIds = new HashSet<>();

+ 281 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java

@@ -0,0 +1,281 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.assignment.planning;
+
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Model;
+import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan.Node;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.assertModelFullyAssignedToNode;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.assertPreviousAssignmentsAreSatisfied;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.convertToIdIndexed;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.createModelsFromPlan;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.randomModel;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.randomModels;
+import static org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlannerTests.randomNodes;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.hasItems;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+
+public class ZoneAwareAssignmentPlannerTests extends ESTestCase {
+
+    public void testGivenOneModel_OneNode_OneZone_DoesNotFit() {
+        Node node = new Node("n_1", 100, 1);
+        Model model = new Model("m_1", 100, 1, 2, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(model)).computePlan();
+
+        assertThat(plan.assignments(model).isEmpty(), is(true));
+    }
+
+    public void testGivenOneModel_OneNode_OneZone_FullyFits() {
+        Node node = new Node("n_1", 100, 4);
+        Model model = new Model("m_1", 100, 2, 2, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(model)).computePlan();
+
+        assertModelFullyAssignedToNode(plan, model, node);
+    }
+
+    public void testGivenOneModel_OneNode_OneZone_PartiallyFits() {
+        Node node = new Node("n_1", 100, 5);
+        Model model = new Model("m_1", 100, 3, 2, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(model)).computePlan();
+
+        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(plan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+    }
+
+    public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() {
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Model model = new Model("m_1", 100, 1, 2, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2)),
+            List.of(model)
+        ).computePlan();
+
+        assertThat(plan.satisfiesAllModels(), is(true));
+
+        assertThat(plan.assignments(model).isPresent(), is(true));
+        Map<Node, Integer> assignments = plan.assignments(model).get();
+        assertThat(assignments.keySet(), hasSize(1));
+        assertThat(assignments.get(assignments.keySet().iterator().next()), equalTo(1));
+    }
+
+    public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() {
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Model model = new Model("m_1", 100, 2, 2, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)),
+            List.of(model)
+        ).computePlan();
+
+        assertThat(plan.satisfiesAllModels(), is(true));
+
+        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(plan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 1, "n_2", 1)));
+    }
+
+    public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() {
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Model model = new Model("m_1", 100, 3, 3, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of("z_1"), List.of(node1), List.of("z_2"), List.of(node2)),
+            List.of(model)
+        ).computePlan();
+
+        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(plan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 1, "n_2", 1)));
+    }
+
+    public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() {
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Node node3 = new Node("n_3", 100, 4);
+        Node node4 = new Node("n_4", 100, 4);
+        Node node5 = new Node("n_5", 100, 4);
+        Node node6 = new Node("n_6", 100, 4);
+        Model model1 = new Model("m_1", 25, 4, 1, Map.of(), 0);
+        Model model2 = new Model("m_2", 25, 6, 2, Map.of(), 0);
+        Model model3 = new Model("m_3", 25, 2, 3, Map.of(), 0);
+
+        Map<List<String>, List<Node>> nodesByZone = Map.of(
+            List.of("z_1"),
+            List.of(node1, node2),
+            List.of("z_2"),
+            List.of(node3, node4),
+            List.of("z_3"),
+            List.of(node5, node6)
+        );
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(nodesByZone, List.of(model1, model2, model3)).computePlan();
+
+        assertThat(plan.satisfiesAllModels(), is(true));
+
+        {
+            assertThat(plan.assignments(model1).isPresent(), is(true));
+            Map<Node, Integer> assignments = plan.assignments(model1).get();
+            for (List<Node> zoneNodes : nodesByZone.values()) {
+                assertThat(Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet())), is(true));
+            }
+        }
+        {
+            assertThat(plan.assignments(model2).isPresent(), is(true));
+            Map<Node, Integer> assignments = plan.assignments(model2).get();
+            for (List<Node> zoneNodes : nodesByZone.values()) {
+                assertThat(Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet())), is(true));
+            }
+        }
+        {
+            assertThat(plan.assignments(model3).isPresent(), is(true));
+            Map<Node, Integer> assignments = plan.assignments(model3).get();
+            int zonesWithAllocations = 0;
+            for (List<Node> zoneNodes : nodesByZone.values()) {
+                if (Sets.haveNonEmptyIntersection(assignments.keySet(), zoneNodes.stream().collect(Collectors.toSet()))) {
+                    zonesWithAllocations++;
+                }
+            }
+            assertThat(zonesWithAllocations, equalTo(2));
+        }
+    }
+
+    public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() {
+        Node node1 = new Node("n_1", 100, 4);
+        Node node2 = new Node("n_2", 100, 4);
+        Node node3 = new Node("n_3", 100, 4);
+        Model model1 = new Model("m_1", 25, 1, 1, Map.of(), 0);
+        Model model2 = new Model("m_2", 25, 1, 1, Map.of(), 0);
+
+        AssignmentPlan plan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)),
+            List.of(model1, model2)
+        ).computePlan();
+
+        assertThat(plan.satisfiesAllModels(), is(true));
+    }
+
+    public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewModel() {
+        // Plan once over a random three-zone topology, then re-plan with the original
+        // assignments recorded as "previous" plus one extra model. Models that were
+        // satisfied before must keep at least as many allocations afterwards.
+        int scale = randomIntBetween(0, 10);
+        double load = randomDoubleBetween(0.1, 1.0, true);
+        Map<List<String>, List<Node>> nodesByZone = Map.of(
+            List.of("z_1"),
+            randomNodes(scale, "z_1_"),
+            List.of("z_2"),
+            randomNodes(scale, "z_2_"),
+            List.of("z_3"),
+            randomNodes(scale, "z_3_")
+        );
+        List<Model> models = randomModels(scale, load);
+        AssignmentPlan originalPlan = new ZoneAwareAssignmentPlanner(nodesByZone, models).computePlan();
+
+        // Rebuild each model carrying its assignments from the original plan as its
+        // previous assignments, and append a brand new model.
+        List<Model> previousModelsPlusNew = Stream.concat(models.stream().map(m -> {
+            Map<String, Integer> previousAssignments = originalPlan.assignments(m)
+                .orElse(Map.of())
+                .entrySet()
+                .stream()
+                .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue));
+            return new Model(m.id(), m.memoryBytes(), m.allocations(), m.threadsPerAllocation(), previousAssignments, 0);
+        }), Stream.of(randomModel("new"))).toList();
+
+        AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(nodesByZone, previousModelsPlusNew).computePlan();
+
+        assertPreviousAssignmentsAreSatisfied(previousModelsPlusNew, assignmentPlan);
+    }
+
+    public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() {
+        // Simulates a rolling cluster resize within a single (unnamed) zone: models are
+        // started incrementally, then the original nodes are replaced by larger ones.
+        // Every model must retain at least one allocation throughout.
+        Node node1 = new Node("n_1", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Node node2 = new Node("n_2", ByteSizeValue.ofMb(1200).getBytes(), 2);
+        Model model1 = new Model("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0);
+        Model model2 = new Model("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0);
+        Model model3 = new Model("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0);
+
+        // First only start m_1
+        AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(model1))
+            .computePlan();
+
+        Map<String, Map<String, Integer>> indexedBasedPlan = convertToIdIndexed(assignmentPlan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+
+        // Then start m_2
+        assignmentPlan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of(), List.of(node1, node2)),
+            Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(model2)).toList()
+        ).computePlan();
+
+        indexedBasedPlan = convertToIdIndexed(assignmentPlan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
+
+        // Then start m_3
+        assignmentPlan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of(), List.of(node1, node2)),
+            Stream.concat(createModelsFromPlan(assignmentPlan).stream(), Stream.of(model3)).toList()
+        ).computePlan();
+
+        indexedBasedPlan = convertToIdIndexed(assignmentPlan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
+        assertThat(indexedBasedPlan.get("m_1"), equalTo(Map.of("n_1", 2)));
+        assertThat(indexedBasedPlan.get("m_2"), equalTo(Map.of("n_2", 1)));
+        assertThat(indexedBasedPlan.get("m_3"), equalTo(Map.of("n_2", 1)));
+
+        // Now the cluster starts getting resized.
+        Node node3 = new Node("n_3", ByteSizeValue.ofMb(2400).getBytes(), 2);
+        Node node4 = new Node("n_4", ByteSizeValue.ofMb(2400).getBytes(), 2);
+
+        // First, one node goes away.
+        assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1)), createModelsFromPlan(assignmentPlan))
+            .computePlan();
+
+        // Then, a node double in memory size is added.
+        assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node3)), createModelsFromPlan(assignmentPlan))
+            .computePlan();
+        // And another.
+        assignmentPlan = new ZoneAwareAssignmentPlanner(
+            Map.of(List.of(), List.of(node1, node3, node4)),
+            createModelsFromPlan(assignmentPlan)
+        ).computePlan();
+        // Finally, the remaining smaller node is removed
+        assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node3, node4)), createModelsFromPlan(assignmentPlan))
+            .computePlan();
+
+        indexedBasedPlan = convertToIdIndexed(assignmentPlan);
+        assertThat(indexedBasedPlan.keySet(), hasItems("m_1", "m_2", "m_3"));
+        assertThat(indexedBasedPlan.get("m_1").values().stream().mapToInt(Integer::intValue).sum(), greaterThanOrEqualTo(1));
+        assertThat(indexedBasedPlan.get("m_2").values().stream().mapToInt(Integer::intValue).sum(), greaterThanOrEqualTo(1));
+        assertThat(indexedBasedPlan.get("m_3").values().stream().mapToInt(Integer::intValue).sum(), greaterThanOrEqualTo(1));
+
+        // Assert that all cores are utilized. The final plan was computed over n_3 and n_4
+        // only (n_1/n_2 were removed during the resize), so those are the nodes to check;
+        // the previous assertions on n_1/n_2 were vacuous as they are no longer present.
+        // Demand is 7 threads against 4 available cores, so both nodes should be saturated.
+        assertThat(assignmentPlan.getRemainingNodeCores("n_3"), equalTo(0));
+        assertThat(assignmentPlan.getRemainingNodeCores("n_4"), equalTo(0));
+    }
+}