Browse Source

[ML] enable autoscaling for trained model deployments (#77931)

Initial support for autoscaling with model deployments.

This will request more capacity for every model deployment that is in the STARTING state and is not yet allocated to any node.

Once a deployment is allocated to at least one node, we will no longer request extra capacity on behalf of that deployment.
Benjamin Trent 4 years ago
parent
commit
a3bc38ca6b

+ 120 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.integration;
 
 
 import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
 import org.elasticsearch.action.admin.cluster.node.info.NodeInfo;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.TimeValue;
@@ -16,6 +17,15 @@ import org.elasticsearch.xpack.autoscaling.action.GetAutoscalingCapacityAction;
 import org.elasticsearch.xpack.autoscaling.action.PutAutoscalingPolicyAction;
 import org.elasticsearch.xpack.autoscaling.action.PutAutoscalingPolicyAction;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResults;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResults;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -24,6 +34,8 @@ import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingDeciderService;
 import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingDeciderService;
 import org.elasticsearch.xpack.ml.autoscaling.NativeMemoryCapacity;
 import org.elasticsearch.xpack.ml.autoscaling.NativeMemoryCapacity;
+import org.junit.After;
+import org.junit.Before;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
@@ -34,6 +46,7 @@ import java.util.TreeSet;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
 
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.elasticsearch.xpack.ml.integration.PyTorchModelIT.BASE_64_ENCODED_MODEL;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasKey;
@@ -44,6 +57,31 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
     private static final long NATIVE_PROCESS_OVERHEAD_MB = 30;
     private static final long NATIVE_PROCESS_OVERHEAD_MB = 30;
     private static final long BASELINE_OVERHEAD_MB = BASIC_REQUIREMENT_MB + NATIVE_PROCESS_OVERHEAD_MB;
     private static final long BASELINE_OVERHEAD_MB = BASIC_REQUIREMENT_MB + NATIVE_PROCESS_OVERHEAD_MB;
 
 
+    @Before
+    public void putSettings() {
+        client().admin()
+            .cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), 100)
+                .put("logger.org.elasticsearch.xpack.ml", "TRACE")
+            )
+            .get();
+    }
+
+    @After
+    public void removeSettings() {
+        client().admin()
+            .cluster()
+            .prepareUpdateSettings()
+            .setTransientSettings(Settings.builder()
+                .putNull(MachineLearning.MAX_LAZY_ML_NODES.getKey())
+                .putNull("logger.org.elasticsearch.xpack.ml")
+            )
+            .get();
+        cleanUp();
+    }
+
     // This test assumes that xpack.ml.max_machine_memory_percent is 30
     // This test assumes that xpack.ml.max_machine_memory_percent is 30
     // and that xpack.ml.use_auto_machine_memory_percent is false
     // and that xpack.ml.use_auto_machine_memory_percent is false
     public void testMLAutoscalingCapacity() throws Exception {
     public void testMLAutoscalingCapacity() throws Exception {
@@ -145,6 +183,62 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
             0L);
             0L);
     }
     }
 
 
+    public void testMLAutoscalingForLargeModelAllocation() {
+        String modelId = "really_big_model";
+        SortedMap<String, Settings> deciders = new TreeMap<>();
+        deciders.put(
+            MlAutoscalingDeciderService.NAME,
+            Settings.builder().put(MlAutoscalingDeciderService.DOWN_SCALE_DELAY.getKey(), TimeValue.ZERO).build()
+        );
+        final PutAutoscalingPolicyAction.Request request = new PutAutoscalingPolicyAction.Request(
+            "ml_test",
+            new TreeSet<>(Arrays.asList("master", "data", "ingest", "ml")),
+            deciders
+        );
+        assertAcked(client().execute(PutAutoscalingPolicyAction.INSTANCE, request).actionGet());
+        putAndStartModelDeployment("smaller1", ByteSizeValue.ofMb(100).getBytes(), AllocationStatus.State.STARTED);
+        putAndStartModelDeployment("smaller2", ByteSizeValue.ofMb(100).getBytes(), AllocationStatus.State.STARTED);
+        long expectedTierBytes = (long) Math.ceil(
+            ByteSizeValue.ofMb(100 + BASELINE_OVERHEAD_MB + 200 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0
+        );
+        long expectedNodeBytes = (long) Math.ceil(ByteSizeValue.ofMb(200 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0);
+
+        assertMlCapacity(
+            client().execute(GetAutoscalingCapacityAction.INSTANCE, new GetAutoscalingCapacityAction.Request()).actionGet(),
+            "Requesting scale down as tier and/or node size could be smaller",
+            expectedTierBytes,
+            expectedNodeBytes
+        );
+
+        long modelSize = ByteSizeValue.ofMb(50_000).getBytes();
+        putAndStartModelDeployment(modelId, modelSize, AllocationStatus.State.STARTING);
+
+        List<DiscoveryNode> mlNodes = admin()
+            .cluster()
+            .prepareNodesInfo()
+            .all()
+            .get()
+            .getNodes()
+            .stream()
+            .map(NodeInfo::getNode)
+            .filter(MachineLearning::isMlNode)
+            .collect(Collectors.toList());
+        NativeMemoryCapacity currentScale = MlAutoscalingDeciderService.currentScale(mlNodes, 30, false);
+        expectedTierBytes = (long)Math.ceil(
+            (ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB).getBytes() + currentScale.getTier()) * 100 / 30.0
+        );
+        expectedNodeBytes = (long) (ByteSizeValue.ofMb(50_000 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0);
+
+        assertMlCapacity(
+            client().execute(GetAutoscalingCapacityAction.INSTANCE, new GetAutoscalingCapacityAction.Request()).actionGet(),
+            "requesting scale up as number of jobs in queues exceeded configured limit "
+                + "or there is at least one trained model waiting for allocation "
+                + "and current capacity is not large enough for waiting jobs",
+            expectedTierBytes,
+            expectedNodeBytes
+        );
+    }
+
     private void assertMlCapacity(GetAutoscalingCapacityAction.Response capacity, String reason, long tierBytes, long nodeBytes) {
     private void assertMlCapacity(GetAutoscalingCapacityAction.Response capacity, String reason, long tierBytes, long nodeBytes) {
         assertThat(capacity.getResults(), hasKey("ml_test"));
         assertThat(capacity.getResults(), hasKey("ml_test"));
         AutoscalingDeciderResults autoscalingDeciderResults = capacity.getResults().get("ml_test");
         AutoscalingDeciderResults autoscalingDeciderResults = capacity.getResults().get("ml_test");
@@ -175,4 +269,30 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
 
 
         putJob(job);
         putJob(job);
     }
     }
+
+    private void putAndStartModelDeployment(String modelId, long memoryUse, AllocationStatus.State state) {
+        client().execute(
+            PutTrainedModelAction.INSTANCE,
+            new PutTrainedModelAction.Request(
+                TrainedModelConfig.builder()
+                    .setModelType(TrainedModelType.PYTORCH)
+                    .setInferenceConfig(new PassThroughConfig(null, new BertTokenization(null, false, null)))
+                    .setModelId(modelId)
+                    .build(),
+                false
+            )
+        ).actionGet();
+        client().execute(
+            PutTrainedModelDefinitionPartAction.INSTANCE,
+            new PutTrainedModelDefinitionPartAction.Request(modelId, new BytesArray(BASE_64_ENCODED_MODEL), 0, memoryUse, 1)
+        ).actionGet();
+        client().execute(
+            PutTrainedModelVocabularyAction.INSTANCE,
+            new PutTrainedModelVocabularyAction.Request(modelId, List.of("these", "are", "my", "words"))
+        ).actionGet();
+        client().execute(
+            StartTrainedModelDeploymentAction.INSTANCE,
+            new StartTrainedModelDeploymentAction.Request(modelId).setWaitForState(state)
+        ).actionGet();
+    }
 }
 }

+ 21 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -25,6 +25,7 @@ import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.LicenseUtils;
@@ -74,11 +75,12 @@ public class TransportStartTrainedModelDeploymentAction
     private final TrainedModelAllocationService trainedModelAllocationService;
     private final TrainedModelAllocationService trainedModelAllocationService;
     private final NamedXContentRegistry xContentRegistry;
     private final NamedXContentRegistry xContentRegistry;
     private final MlMemoryTracker memoryTracker;
     private final MlMemoryTracker memoryTracker;
+    protected volatile int maxLazyMLNodes;
 
 
     @Inject
     @Inject
     public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService,
     public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService,
                                                       ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState,
                                                       ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState,
-                                                      IndexNameExpressionResolver indexNameExpressionResolver,
+                                                      IndexNameExpressionResolver indexNameExpressionResolver, Settings settings,
                                                       TrainedModelAllocationService trainedModelAllocationService,
                                                       TrainedModelAllocationService trainedModelAllocationService,
                                                       NamedXContentRegistry xContentRegistry, MlMemoryTracker memoryTracker) {
                                                       NamedXContentRegistry xContentRegistry, MlMemoryTracker memoryTracker) {
         super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters,
         super(StartTrainedModelDeploymentAction.NAME, transportService, clusterService, threadPool, actionFilters,
@@ -89,6 +91,12 @@ public class TransportStartTrainedModelDeploymentAction
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
         this.memoryTracker = Objects.requireNonNull(memoryTracker);
         this.memoryTracker = Objects.requireNonNull(memoryTracker);
         this.trainedModelAllocationService = Objects.requireNonNull(trainedModelAllocationService);
         this.trainedModelAllocationService = Objects.requireNonNull(trainedModelAllocationService);
+        this.maxLazyMLNodes = MachineLearning.MAX_LAZY_ML_NODES.get(settings);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
+    }
+
+    private void setMaxLazyMLNodes(int value) {
+        this.maxLazyMLNodes = value;
     }
     }
 
 
     @Override
     @Override
@@ -198,7 +206,7 @@ public class TransportStartTrainedModelDeploymentAction
         AllocationStatus.State state,
         AllocationStatus.State state,
         ActionListener<CreateTrainedModelAllocationAction.Response> listener
         ActionListener<CreateTrainedModelAllocationAction.Response> listener
     ) {
     ) {
-        DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(modelId, state);
+        DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(modelId, state, maxLazyMLNodes);
         trainedModelAllocationService.waitForAllocationCondition(modelId, predicate, timeout,
         trainedModelAllocationService.waitForAllocationCondition(modelId, predicate, timeout,
             new TrainedModelAllocationService.WaitForAllocationListener() {
             new TrainedModelAllocationService.WaitForAllocationListener() {
                 @Override
                 @Override
@@ -254,10 +262,12 @@ public class TransportStartTrainedModelDeploymentAction
         // for logging
         // for logging
         private final String modelId;
         private final String modelId;
         private final AllocationStatus.State waitForState;
         private final AllocationStatus.State waitForState;
+        private final int maxLazyMLNodes;
 
 
-        DeploymentStartedPredicate(String modelId, AllocationStatus.State waitForState) {
+        DeploymentStartedPredicate(String modelId, AllocationStatus.State waitForState, int maxLazyMLNodes) {
             this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
             this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
             this.waitForState = waitForState;
             this.waitForState = waitForState;
+            this.maxLazyMLNodes = maxLazyMLNodes;
         }
         }
 
 
         @Override
         @Override
@@ -292,10 +302,16 @@ public class TransportStartTrainedModelDeploymentAction
                 );
                 );
                 return true;
                 return true;
             }
             }
+            Set<String> nodesShuttingDown = nodesShuttingDown(clusterState);
+            List<DiscoveryNode> nodes = clusterState.nodes()
+                .getAllNodes()
+                .stream()
+                .filter(d -> nodesShuttingDown.contains(d.getId()) == false)
+                .filter(TaskParams::mayAllocateToNode)
+                .collect(Collectors.toList());
 
 
             // No nodes allocated at all!
             // No nodes allocated at all!
-            // TODO when we support autoscaling for this, check for `maxLazyNodes` setting
-            if (nodesAndState.isEmpty()) {
+            if (nodesAndState.isEmpty() && maxLazyMLNodes <= nodes.size()) {
                 String msg = "Could not start deployment because no suitable nodes were found, allocation explanation ["
                 String msg = "Could not start deployment because no suitable nodes were found, allocation explanation ["
                     + trainedModelAllocation.getReason()
                     + trainedModelAllocation.getReason()
                     + "]";
                     + "]";
@@ -309,13 +325,6 @@ public class TransportStartTrainedModelDeploymentAction
                 return true;
                 return true;
             }
             }
 
 
-            Set<String> nodesShuttingDown = nodesShuttingDown(clusterState);
-            List<DiscoveryNode> nodes = clusterState.nodes()
-                .getAllNodes()
-                .stream()
-                .filter(d -> nodesShuttingDown.contains(d.getId()) == false)
-                .filter(TaskParams::mayAllocateToNode)
-                .collect(Collectors.toList());
             AllocationStatus allocationStatus = trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null);
             AllocationStatus allocationStatus = trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null);
             if (allocationStatus == null || allocationStatus.calculateState().compareTo(waitForState) >= 0) {
             if (allocationStatus == null || allocationStatus.calculateState().compareTo(waitForState) >= 0) {
                 return true;
                 return true;

+ 57 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java

@@ -31,9 +31,12 @@ import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderService;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction.DatafeedParams;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction.DatafeedParams;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
@@ -348,6 +351,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
         PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
         PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
         Collection<PersistentTask<?>> anomalyDetectionTasks = anomalyDetectionTasks(tasks);
         Collection<PersistentTask<?>> anomalyDetectionTasks = anomalyDetectionTasks(tasks);
         Collection<PersistentTask<?>> dataframeAnalyticsTasks = dataframeAnalyticsTasks(tasks);
         Collection<PersistentTask<?>> dataframeAnalyticsTasks = dataframeAnalyticsTasks(tasks);
+        Map<String, TrainedModelAllocation> modelAllocations = TrainedModelAllocationMetadata.fromState(clusterState).modelAllocations();
         final List<String> waitingAnomalyJobs = anomalyDetectionTasks.stream()
         final List<String> waitingAnomalyJobs = anomalyDetectionTasks.stream()
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
             .map(t -> MlTasks.jobId(t.getId()))
             .map(t -> MlTasks.jobId(t.getId()))
@@ -356,6 +360,13 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
             .map(t -> MlTasks.dataFrameAnalyticsId(t.getId()))
             .map(t -> MlTasks.dataFrameAnalyticsId(t.getId()))
             .collect(Collectors.toList());
             .collect(Collectors.toList());
+        final List<String> waitingAllocatedModels = modelAllocations
+            .entrySet()
+            .stream()
+            // TODO: Eventually care about those that are STARTED but not FULLY_ALLOCATED
+            .filter(e -> e.getValue().getAllocationState().equals(AllocationState.STARTING) && e.getValue().getNodeRoutingTable().isEmpty())
+            .map(Map.Entry::getKey)
+            .collect(Collectors.toList());
 
 
         final int numAnalyticsJobsInQueue = NUM_ANALYTICS_JOBS_IN_QUEUE.get(configuration);
         final int numAnalyticsJobsInQueue = NUM_ANALYTICS_JOBS_IN_QUEUE.get(configuration);
         final int numAnomalyJobsInQueue = NUM_ANOMALY_JOBS_IN_QUEUE.get(configuration);
         final int numAnomalyJobsInQueue = NUM_ANOMALY_JOBS_IN_QUEUE.get(configuration);
@@ -366,20 +377,22 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
         final MlScalingReason.Builder reasonBuilder = MlScalingReason.builder()
         final MlScalingReason.Builder reasonBuilder = MlScalingReason.builder()
             .setWaitingAnomalyJobs(waitingAnomalyJobs)
             .setWaitingAnomalyJobs(waitingAnomalyJobs)
             .setWaitingAnalyticsJobs(waitingAnalyticsJobs)
             .setWaitingAnalyticsJobs(waitingAnalyticsJobs)
+            .setWaitingModels(waitingAllocatedModels)
             .setCurrentMlCapacity(currentScale.autoscalingCapacity(maxMachineMemoryPercent, useAuto))
             .setCurrentMlCapacity(currentScale.autoscalingCapacity(maxMachineMemoryPercent, useAuto))
             .setPassedConfiguration(configuration);
             .setPassedConfiguration(configuration);
 
 
         // There are no ML nodes, scale up as quick as possible, no matter if memory is stale or not
         // There are no ML nodes, scale up as quick as possible, no matter if memory is stale or not
         if (nodes.isEmpty()
         if (nodes.isEmpty()
             && (waitingAnomalyJobs.isEmpty() == false
             && (waitingAnomalyJobs.isEmpty() == false
-            || waitingAnalyticsJobs.isEmpty() == false)) {
-            return scaleUpFromZero(waitingAnomalyJobs, waitingAnalyticsJobs, reasonBuilder);
+            || waitingAnalyticsJobs.isEmpty() == false
+            || waitingAllocatedModels.isEmpty() == false)) {
+            return scaleUpFromZero(waitingAnomalyJobs, waitingAnalyticsJobs, waitingAllocatedModels, reasonBuilder);
         }
         }
 
 
         // We don't need to check anything as there are no tasks
         // We don't need to check anything as there are no tasks
         // This is a quick path to downscale.
         // This is a quick path to downscale.
         // simply return `0` for scale down if delay is satisfied
         // simply return `0` for scale down if delay is satisfied
-        if (anomalyDetectionTasks.isEmpty() && dataframeAnalyticsTasks.isEmpty()) {
+        if (anomalyDetectionTasks.isEmpty() && dataframeAnalyticsTasks.isEmpty() && modelAllocations.isEmpty()) {
             long msLeftToScale = msLeftToDownScale(configuration);
             long msLeftToScale = msLeftToDownScale(configuration);
             if (msLeftToScale > 0) {
             if (msLeftToScale > 0) {
                 return new AutoscalingDeciderResult(
                 return new AutoscalingDeciderResult(
@@ -462,6 +475,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
             nodeLoads,
             nodeLoads,
             waitingAnomalyJobs,
             waitingAnomalyJobs,
             waitingAnalyticsJobs,
             waitingAnalyticsJobs,
+            waitingAllocatedModels,
             futureFreedCapacity.orElse(null),
             futureFreedCapacity.orElse(null),
             currentScale,
             currentScale,
             reasonBuilder
             reasonBuilder
@@ -492,7 +506,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                     .build()));
                     .build()));
         }
         }
 
 
-        long largestJob = Math.max(
+        long largestJobOrModel = Math.max(
             anomalyDetectionTasks.stream()
             anomalyDetectionTasks.stream()
                 .filter(PersistentTask::isAssigned)
                 .filter(PersistentTask::isAssigned)
                 // Memory SHOULD be recently refreshed, so in our current state, we should at least have an idea of the memory used
                 // Memory SHOULD be recently refreshed, so in our current state, we should at least have an idea of the memory used
@@ -513,15 +527,20 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                 })
                 })
                 .max()
                 .max()
                 .orElse(0L));
                 .orElse(0L));
+        largestJobOrModel = Math.max(
+            largestJobOrModel,
+            modelAllocations.values().stream().mapToLong(t -> t.getTaskParams().estimateMemoryUsageBytes()).max().orElse(0L)
+        );
 
 
         // This is an exceptionally weird state
         // This is an exceptionally weird state
         // Our view of the memory is stale or we have tasks where the required job memory is 0, which should be impossible
         // Our view of the memory is stale or we have tasks where the required job memory is 0, which should be impossible
-        if (largestJob == 0L && (dataframeAnalyticsTasks.size() + anomalyDetectionTasks.size() > 0)) {
+        if (largestJobOrModel == 0L && (dataframeAnalyticsTasks.size() + anomalyDetectionTasks.size() + modelAllocations.size() > 0)) {
             logger.warn(
             logger.warn(
                 "The calculated minimum required node size was unexpectedly [0] as there are "
                 "The calculated minimum required node size was unexpectedly [0] as there are "
-                    + "[{}] anomaly job tasks and [{}] data frame analytics tasks",
+                    + "[{}] anomaly job tasks, [{}] data frame analytics tasks and [{}] model allocations",
                 anomalyDetectionTasks.size(),
                 anomalyDetectionTasks.size(),
-                dataframeAnalyticsTasks.size()
+                dataframeAnalyticsTasks.size(),
+                modelAllocations.size()
             );
             );
             return noScaleResultOrRefresh(reasonBuilder, true, new AutoscalingDeciderResult(
             return noScaleResultOrRefresh(reasonBuilder, true, new AutoscalingDeciderResult(
                 context.currentCapacity(),
                 context.currentCapacity(),
@@ -531,7 +550,12 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                     .build()));
                     .build()));
         }
         }
 
 
-        final Optional<AutoscalingDeciderResult> maybeScaleDown = checkForScaleDown(nodeLoads, largestJob, currentScale, reasonBuilder)
+        final Optional<AutoscalingDeciderResult> maybeScaleDown = checkForScaleDown(
+            nodeLoads,
+            largestJobOrModel,
+            currentScale,
+            reasonBuilder
+        )
             // Due to weird rounding errors, it may be that a scale down result COULD cause a scale up
             // Due to weird rounding errors, it may be that a scale down result COULD cause a scale up
             // Ensuring the scaleDown here forces the scale down result to always be lower than the current capacity.
             // Ensuring the scaleDown here forces the scale down result to always be lower than the current capacity.
             // This is safe as we know that ALL jobs are assigned at the current capacity
             // This is safe as we know that ALL jobs are assigned at the current capacity
@@ -643,6 +667,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
     // can eventually start, and given the current cluster, no job can eventually start.
     // can eventually start, and given the current cluster, no job can eventually start.
     AutoscalingDeciderResult scaleUpFromZero(List<String> waitingAnomalyJobs,
     AutoscalingDeciderResult scaleUpFromZero(List<String> waitingAnomalyJobs,
                                              List<String> waitingAnalyticsJobs,
                                              List<String> waitingAnalyticsJobs,
+                                             List<String> waitingAllocatedModels,
                                              MlScalingReason.Builder reasonBuilder) {
                                              MlScalingReason.Builder reasonBuilder) {
         final Optional<NativeMemoryCapacity> analyticsCapacity = requiredCapacityForUnassignedJobs(waitingAnalyticsJobs,
         final Optional<NativeMemoryCapacity> analyticsCapacity = requiredCapacityForUnassignedJobs(waitingAnalyticsJobs,
             this::getAnalyticsMemoryRequirement,
             this::getAnalyticsMemoryRequirement,
@@ -650,9 +675,13 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
         final Optional<NativeMemoryCapacity> anomalyCapacity = requiredCapacityForUnassignedJobs(waitingAnomalyJobs,
         final Optional<NativeMemoryCapacity> anomalyCapacity = requiredCapacityForUnassignedJobs(waitingAnomalyJobs,
             this::getAnomalyMemoryRequirement,
             this::getAnomalyMemoryRequirement,
             0);
             0);
+        final Optional<NativeMemoryCapacity> allocatedModelCapacity = requiredCapacityForUnassignedJobs(waitingAllocatedModels,
+            this::getAllocatedModelRequirement,
+            0);
         NativeMemoryCapacity updatedCapacity = NativeMemoryCapacity.ZERO
         NativeMemoryCapacity updatedCapacity = NativeMemoryCapacity.ZERO
             .merge(anomalyCapacity.orElse(NativeMemoryCapacity.ZERO))
             .merge(anomalyCapacity.orElse(NativeMemoryCapacity.ZERO))
-            .merge(analyticsCapacity.orElse(NativeMemoryCapacity.ZERO));
+            .merge(analyticsCapacity.orElse(NativeMemoryCapacity.ZERO))
+            .merge(allocatedModelCapacity.orElse(NativeMemoryCapacity.ZERO));
         // If we still have calculated zero, this means the ml memory tracker does not have the required info.
         // If we still have calculated zero, this means the ml memory tracker does not have the required info.
         // So, request a scale for the default. This is only for the 0 -> N scaling case.
         // So, request a scale for the default. This is only for the 0 -> N scaling case.
         if (updatedCapacity.getNode() == 0L) {
         if (updatedCapacity.getNode() == 0L) {
@@ -681,13 +710,15 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                                                        List<NodeLoad> nodeLoads,
                                                        List<NodeLoad> nodeLoads,
                                                        List<String> waitingAnomalyJobs,
                                                        List<String> waitingAnomalyJobs,
                                                        List<String> waitingAnalyticsJobs,
                                                        List<String> waitingAnalyticsJobs,
+                                                       List<String> waitingAllocatedModels,
                                                        @Nullable NativeMemoryCapacity futureFreedCapacity,
                                                        @Nullable NativeMemoryCapacity futureFreedCapacity,
                                                        NativeMemoryCapacity currentScale,
                                                        NativeMemoryCapacity currentScale,
                                                        MlScalingReason.Builder reasonBuilder) {
                                                        MlScalingReason.Builder reasonBuilder) {
 
 
         // Are we in breach of maximum waiting jobs?
         // Are we in breach of maximum waiting jobs?
         if (waitingAnalyticsJobs.size() > numAnalyticsJobsInQueue
         if (waitingAnalyticsJobs.size() > numAnalyticsJobsInQueue
-            || waitingAnomalyJobs.size() > numAnomalyJobsInQueue) {
+            || waitingAnomalyJobs.size() > numAnomalyJobsInQueue
+            || waitingAllocatedModels.size() > 0) {
 
 
             Tuple<NativeMemoryCapacity, List<NodeLoad>> anomalyCapacityAndNewLoad = determineUnassignableJobs(
             Tuple<NativeMemoryCapacity, List<NodeLoad>> anomalyCapacityAndNewLoad = determineUnassignableJobs(
                 waitingAnomalyJobs,
                 waitingAnomalyJobs,
@@ -701,8 +732,15 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                 numAnalyticsJobsInQueue,
                 numAnalyticsJobsInQueue,
                 anomalyCapacityAndNewLoad.v2()).orElse(Tuple.tuple(NativeMemoryCapacity.ZERO, anomalyCapacityAndNewLoad.v2()));
                 anomalyCapacityAndNewLoad.v2()).orElse(Tuple.tuple(NativeMemoryCapacity.ZERO, anomalyCapacityAndNewLoad.v2()));
 
 
+            Tuple<NativeMemoryCapacity, List<NodeLoad>> modelCapacityAndNewLoad = determineUnassignableJobs(
+                waitingAllocatedModels,
+                this::getAllocatedModelRequirement,
+                0,
+                analyticsCapacityAndNewLoad.v2()).orElse(Tuple.tuple(NativeMemoryCapacity.ZERO, analyticsCapacityAndNewLoad.v2()));
+
             if (analyticsCapacityAndNewLoad.v1().equals(NativeMemoryCapacity.ZERO)
             if (analyticsCapacityAndNewLoad.v1().equals(NativeMemoryCapacity.ZERO)
-                && anomalyCapacityAndNewLoad.v1().equals(NativeMemoryCapacity.ZERO)) {
+                && anomalyCapacityAndNewLoad.v1().equals(NativeMemoryCapacity.ZERO)
+                && modelCapacityAndNewLoad.v1().equals(NativeMemoryCapacity.ZERO)) {
                 logger.debug("no_scale event as current capacity, even though there are waiting jobs, is adequate to run the queued jobs");
                 logger.debug("no_scale event as current capacity, even though there are waiting jobs, is adequate to run the queued jobs");
                 return Optional.empty();
                 return Optional.empty();
             }
             }
@@ -710,6 +748,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
             NativeMemoryCapacity updatedCapacity = NativeMemoryCapacity.from(currentScale)
             NativeMemoryCapacity updatedCapacity = NativeMemoryCapacity.from(currentScale)
                 .merge(analyticsCapacityAndNewLoad.v1())
                 .merge(analyticsCapacityAndNewLoad.v1())
                 .merge(anomalyCapacityAndNewLoad.v1())
                 .merge(anomalyCapacityAndNewLoad.v1())
+                .merge(modelCapacityAndNewLoad.v1())
                 // Since we require new capacity, it COULD be we require a brand new node
                 // Since we require new capacity, it COULD be we require a brand new node
                 // We should account for overhead in the tier capacity just in case.
                 // We should account for overhead in the tier capacity just in case.
                 .merge(new NativeMemoryCapacity(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), 0));
                 .merge(new NativeMemoryCapacity(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes(), 0));
@@ -720,13 +759,15 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
                     .setRequiredCapacity(requiredCapacity)
                     .setRequiredCapacity(requiredCapacity)
                     .setSimpleReason(
                     .setSimpleReason(
                         "requesting scale up as number of jobs in queues exceeded configured limit "
                         "requesting scale up as number of jobs in queues exceeded configured limit "
-                            + "and current capacity is not large enough for waiting jobs"
+                            + "or there is at least one trained model waiting for allocation "
+                            + "and current capacity is not large enough for waiting jobs or models"
                     )
                     )
                     .build()
                     .build()
             ));
             ));
         }
         }
 
 
         // Could the currently waiting jobs ever be assigned?
         // Could the currently waiting jobs ever be assigned?
+        // NOTE: the previous predicate catches if an allocated model isn't assigned
         if (waitingAnalyticsJobs.isEmpty() == false || waitingAnomalyJobs.isEmpty() == false) {
         if (waitingAnalyticsJobs.isEmpty() == false || waitingAnomalyJobs.isEmpty() == false) {
             // we are unable to determine new tier size, but maybe we can see if our nodes are big enough.
             // we are unable to determine new tier size, but maybe we can see if our nodes are big enough.
             if (futureFreedCapacity == null) {
             if (futureFreedCapacity == null) {
@@ -861,6 +902,10 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
         return mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(analyticsId);
         return mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(analyticsId);
     }
     }
 
 
+    /**
+     * Looks up the memory requirement of a trained model allocation by delegating to
+     * {@code mlMemoryTracker.getTrainedModelAllocationMemoryRequirement}.
+     *
+     * @param modelId id of the trained model whose allocation is being sized
+     * @return the value the memory tracker reports for {@code modelId}
+     *         (NOTE(review): boxed {@code Long}, so presumably {@code null} when the tracker has no figure - confirm)
+     */
+    private Long getAllocatedModelRequirement(String modelId) {
+        return mlMemoryTracker.getTrainedModelAllocationMemoryRequirement(modelId);
+    }
+
     private Long getAnalyticsMemoryRequirement(PersistentTask<?> task) {
     private Long getAnalyticsMemoryRequirement(PersistentTask<?> task) {
         return getAnalyticsMemoryRequirement(MlTasks.dataFrameAnalyticsId(task.getId()));
         return getAnalyticsMemoryRequirement(MlTasks.dataFrameAnalyticsId(task.getId()));
     }
     }

+ 23 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReason.java

@@ -7,6 +7,7 @@
 
 
 package org.elasticsearch.xpack.ml.autoscaling;
 package org.elasticsearch.xpack.ml.autoscaling;
 
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
@@ -25,6 +26,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     public static final String NAME = MlAutoscalingDeciderService.NAME;
     public static final String NAME = MlAutoscalingDeciderService.NAME;
     static final String WAITING_ANALYTICS_JOBS = "waiting_analytics_jobs";
     static final String WAITING_ANALYTICS_JOBS = "waiting_analytics_jobs";
     static final String WAITING_ANOMALY_JOBS = "waiting_anomaly_jobs";
     static final String WAITING_ANOMALY_JOBS = "waiting_anomaly_jobs";
+    static final String WAITING_MODELS = "waiting_models";
     static final String CONFIGURATION = "configuration";
     static final String CONFIGURATION = "configuration";
     static final String LARGEST_WAITING_ANALYTICS_JOB = "largest_waiting_analytics_job";
     static final String LARGEST_WAITING_ANALYTICS_JOB = "largest_waiting_analytics_job";
     static final String LARGEST_WAITING_ANOMALY_JOB = "largest_waiting_anomaly_job";
     static final String LARGEST_WAITING_ANOMALY_JOB = "largest_waiting_anomaly_job";
@@ -34,6 +36,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
 
 
     private final List<String> waitingAnalyticsJobs;
     private final List<String> waitingAnalyticsJobs;
     private final List<String> waitingAnomalyJobs;
     private final List<String> waitingAnomalyJobs;
+    private final List<String> waitingModels;
     private final Settings passedConfiguration;
     private final Settings passedConfiguration;
     private final Long largestWaitingAnalyticsJob;
     private final Long largestWaitingAnalyticsJob;
     private final Long largestWaitingAnomalyJob;
     private final Long largestWaitingAnomalyJob;
@@ -44,6 +47,11 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     public MlScalingReason(StreamInput in) throws IOException {
     public MlScalingReason(StreamInput in) throws IOException {
         this.waitingAnalyticsJobs = in.readStringList();
         this.waitingAnalyticsJobs = in.readStringList();
         this.waitingAnomalyJobs = in.readStringList();
         this.waitingAnomalyJobs = in.readStringList();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            this.waitingModels = in.readStringList();
+        } else {
+            this.waitingModels = List.of();
+        }
         this.passedConfiguration = Settings.readSettingsFromStream(in);;
         this.passedConfiguration = Settings.readSettingsFromStream(in);;
         this.currentMlCapacity = new AutoscalingCapacity(in);
         this.currentMlCapacity = new AutoscalingCapacity(in);
         this.requiredCapacity = in.readOptionalWriteable(AutoscalingCapacity::new);
         this.requiredCapacity = in.readOptionalWriteable(AutoscalingCapacity::new);
@@ -54,6 +62,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
 
 
     MlScalingReason(List<String> waitingAnalyticsJobs,
     MlScalingReason(List<String> waitingAnalyticsJobs,
                     List<String> waitingAnomalyJobs,
                     List<String> waitingAnomalyJobs,
+                    List<String> waitingModels,
                     Settings passedConfiguration,
                     Settings passedConfiguration,
                     Long largestWaitingAnalyticsJob,
                     Long largestWaitingAnalyticsJob,
                     Long largestWaitingAnomalyJob,
                     Long largestWaitingAnomalyJob,
@@ -62,6 +71,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
                     String simpleReason) {
                     String simpleReason) {
         this.waitingAnalyticsJobs = waitingAnalyticsJobs == null ? Collections.emptyList() : waitingAnalyticsJobs;
         this.waitingAnalyticsJobs = waitingAnalyticsJobs == null ? Collections.emptyList() : waitingAnalyticsJobs;
         this.waitingAnomalyJobs = waitingAnomalyJobs == null ? Collections.emptyList() : waitingAnomalyJobs;
         this.waitingAnomalyJobs = waitingAnomalyJobs == null ? Collections.emptyList() : waitingAnomalyJobs;
+        this.waitingModels = waitingModels == null ? List.of() : waitingModels;
         this.passedConfiguration = ExceptionsHelper.requireNonNull(passedConfiguration, CONFIGURATION);
         this.passedConfiguration = ExceptionsHelper.requireNonNull(passedConfiguration, CONFIGURATION);
         this.largestWaitingAnalyticsJob = largestWaitingAnalyticsJob;
         this.largestWaitingAnalyticsJob = largestWaitingAnalyticsJob;
         this.largestWaitingAnomalyJob = largestWaitingAnomalyJob;
         this.largestWaitingAnomalyJob = largestWaitingAnomalyJob;
@@ -81,6 +91,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
         MlScalingReason that = (MlScalingReason) o;
         MlScalingReason that = (MlScalingReason) o;
         return Objects.equals(waitingAnalyticsJobs, that.waitingAnalyticsJobs) &&
         return Objects.equals(waitingAnalyticsJobs, that.waitingAnalyticsJobs) &&
             Objects.equals(waitingAnomalyJobs, that.waitingAnomalyJobs) &&
             Objects.equals(waitingAnomalyJobs, that.waitingAnomalyJobs) &&
+            Objects.equals(waitingModels, that.waitingModels) &&
             Objects.equals(passedConfiguration, that.passedConfiguration) &&
             Objects.equals(passedConfiguration, that.passedConfiguration) &&
             Objects.equals(largestWaitingAnalyticsJob, that.largestWaitingAnalyticsJob) &&
             Objects.equals(largestWaitingAnalyticsJob, that.largestWaitingAnalyticsJob) &&
             Objects.equals(largestWaitingAnomalyJob, that.largestWaitingAnomalyJob) &&
             Objects.equals(largestWaitingAnomalyJob, that.largestWaitingAnomalyJob) &&
@@ -95,6 +106,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
             waitingAnomalyJobs,
             waitingAnomalyJobs,
             passedConfiguration,
             passedConfiguration,
             largestWaitingAnalyticsJob,
             largestWaitingAnalyticsJob,
+            waitingModels,
             largestWaitingAnomalyJob,
             largestWaitingAnomalyJob,
             currentMlCapacity,
             currentMlCapacity,
             requiredCapacity,
             requiredCapacity,
@@ -115,6 +127,9 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     public void writeTo(StreamOutput out) throws IOException {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeStringCollection(this.waitingAnalyticsJobs);
         out.writeStringCollection(this.waitingAnalyticsJobs);
         out.writeStringCollection(this.waitingAnomalyJobs);
         out.writeStringCollection(this.waitingAnomalyJobs);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeStringCollection(this.waitingModels);
+        }
         Settings.writeSettingsToStream(this.passedConfiguration, out);
         Settings.writeSettingsToStream(this.passedConfiguration, out);
         this.currentMlCapacity.writeTo(out);
         this.currentMlCapacity.writeTo(out);
         out.writeOptionalWriteable(this.requiredCapacity);
         out.writeOptionalWriteable(this.requiredCapacity);
@@ -128,6 +143,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
         builder.startObject();
         builder.startObject();
         builder.field(WAITING_ANALYTICS_JOBS, waitingAnalyticsJobs);
         builder.field(WAITING_ANALYTICS_JOBS, waitingAnalyticsJobs);
         builder.field(WAITING_ANOMALY_JOBS, waitingAnomalyJobs);
         builder.field(WAITING_ANOMALY_JOBS, waitingAnomalyJobs);
+        builder.field(WAITING_MODELS, waitingModels);
         builder.startObject(CONFIGURATION).value(passedConfiguration).endObject();
         builder.startObject(CONFIGURATION).value(passedConfiguration).endObject();
         if (largestWaitingAnalyticsJob != null) {
         if (largestWaitingAnalyticsJob != null) {
             builder.field(LARGEST_WAITING_ANALYTICS_JOB, largestWaitingAnalyticsJob);
             builder.field(LARGEST_WAITING_ANALYTICS_JOB, largestWaitingAnalyticsJob);
@@ -152,6 +168,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     static class Builder {
     static class Builder {
         private List<String> waitingAnalyticsJobs = Collections.emptyList();
         private List<String> waitingAnalyticsJobs = Collections.emptyList();
         private List<String> waitingAnomalyJobs = Collections.emptyList();
         private List<String> waitingAnomalyJobs = Collections.emptyList();
+        private List<String> waitingModels = Collections.emptyList();
         private Settings passedConfiguration;
         private Settings passedConfiguration;
         private Long largestWaitingAnalyticsJob;
         private Long largestWaitingAnalyticsJob;
         private Long largestWaitingAnomalyJob;
         private Long largestWaitingAnomalyJob;
@@ -169,6 +186,11 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
             return this;
             return this;
         }
         }
 
 
+        /**
+         * Sets the list that the built reason stores as its waiting models and reports under the
+         * {@code waiting_models} field.
+         *
+         * @param waitingModels the waiting model entries; {@code null} is accepted because the
+         *                      {@code MlScalingReason} constructor replaces it with an empty list
+         * @return this builder, for call chaining
+         */
+        public Builder setWaitingModels(List<String> waitingModels) {
+            this.waitingModels = waitingModels;
+            return this;
+        }
+
         public Builder setPassedConfiguration(Settings passedConfiguration) {
         public Builder setPassedConfiguration(Settings passedConfiguration) {
             this.passedConfiguration = passedConfiguration;
             this.passedConfiguration = passedConfiguration;
             return this;
             return this;
@@ -203,6 +225,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
             return new MlScalingReason(
             return new MlScalingReason(
                 waitingAnalyticsJobs,
                 waitingAnalyticsJobs,
                 waitingAnomalyJobs,
                 waitingAnomalyJobs,
+                waitingModels,
                 passedConfiguration,
                 passedConfiguration,
                 largestWaitingAnalyticsJob,
                 largestWaitingAnalyticsJob,
                 largestWaitingAnomalyJob,
                 largestWaitingAnomalyJob,

+ 99 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java

@@ -70,6 +70,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
     private static final long DEFAULT_NODE_SIZE = ByteSizeValue.ofGb(20).getBytes();
     private static final long DEFAULT_NODE_SIZE = ByteSizeValue.ofGb(20).getBytes();
     private static final long DEFAULT_JVM_SIZE = ByteSizeValue.ofMb((long)(DEFAULT_NODE_SIZE * 0.25)).getBytes();
     private static final long DEFAULT_JVM_SIZE = ByteSizeValue.ofMb((long)(DEFAULT_NODE_SIZE * 0.25)).getBytes();
     private static final long DEFAULT_JOB_SIZE = ByteSizeValue.ofMb(200).getBytes();
     private static final long DEFAULT_JOB_SIZE = ByteSizeValue.ofMb(200).getBytes();
+    private static final long DEFAULT_MODEL_SIZE = ByteSizeValue.ofMb(200).getBytes();
     private static final long OVERHEAD = ByteSizeValue.ofMb(30).getBytes();
     private static final long OVERHEAD = ByteSizeValue.ofMb(30).getBytes();
     private NodeLoadDetector nodeLoadDetector;
     private NodeLoadDetector nodeLoadDetector;
     private ClusterService clusterService;
     private ClusterService clusterService;
@@ -84,6 +85,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
         when(mlMemoryTracker.asyncRefresh()).thenReturn(true);
         when(mlMemoryTracker.asyncRefresh()).thenReturn(true);
         when(mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(any())).thenReturn(DEFAULT_JOB_SIZE);
         when(mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(any())).thenReturn(DEFAULT_JOB_SIZE);
         when(mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(any())).thenReturn(DEFAULT_JOB_SIZE);
         when(mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(any())).thenReturn(DEFAULT_JOB_SIZE);
+        when(mlMemoryTracker.getTrainedModelAllocationMemoryRequirement(any())).thenReturn(DEFAULT_JOB_SIZE);
         nodeLoadDetector = mock(NodeLoadDetector.class);
         nodeLoadDetector = mock(NodeLoadDetector.class);
         when(nodeLoadDetector.getMlMemoryTracker()).thenReturn(mlMemoryTracker);
         when(nodeLoadDetector.getMlMemoryTracker()).thenReturn(mlMemoryTracker);
         when(nodeLoadDetector.detectNodeLoad(any(), anyBoolean(), any(), anyInt(), anyInt(), anyBoolean()))
         when(nodeLoadDetector.detectNodeLoad(any(), anyBoolean(), any(), anyInt(), anyInt(), anyBoolean()))
@@ -121,6 +123,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             Collections.emptyList(),
             Collections.emptyList(),
             Collections.emptyList(),
             Collections.emptyList(),
             Collections.emptyList(),
             Collections.emptyList(),
+            Collections.emptyList(),
             null,
             null,
             NativeMemoryCapacity.ZERO,
             NativeMemoryCapacity.ZERO,
             MlScalingReason.builder()),
             MlScalingReason.builder()),
@@ -148,6 +151,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -171,6 +175,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -194,6 +199,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -237,6 +243,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 nodesWithRoom,
                 nodesWithRoom,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -250,6 +257,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 nodesWithRoom,
                 nodesWithRoom,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -260,6 +268,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 nodesWithRoom,
                 nodesWithRoom,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -284,6 +293,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -297,6 +307,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -310,6 +321,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -336,6 +348,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 null,
                 null,
                 NativeMemoryCapacity.ZERO,
                 NativeMemoryCapacity.ZERO,
                 reasonBuilder);
                 reasonBuilder);
@@ -348,6 +361,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(3).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(3).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(2).getBytes(), ByteSizeValue.ofGb(2).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(2).getBytes(), ByteSizeValue.ofGb(2).getBytes()),
                 reasonBuilder);
                 reasonBuilder);
@@ -358,6 +372,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 fullyLoadedNode,
                 fullyLoadedNode,
                 jobTasks,
                 jobTasks,
                 analytics,
                 analytics,
+                Collections.emptyList(),
                 new NativeMemoryCapacity(ByteSizeValue.ofMb(1).getBytes(), ByteSizeValue.ofMb(1).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofMb(1).getBytes(), ByteSizeValue.ofMb(1).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(2).getBytes(), ByteSizeValue.ofGb(2).getBytes()),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(2).getBytes(), ByteSizeValue.ofGb(2).getBytes()),
                 reasonBuilder);
                 reasonBuilder);
@@ -367,6 +382,90 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
         }
         }
     }
     }
 
 
+    /**
+     * One waiting model that needs 2GB, and a single node whose 1GB of memory is already fully
+     * assigned: a scale-up decision must be produced, and the ML-usable share of both the requested
+     * node size and the requested tier size must cover the model plus {@code OVERHEAD}.
+     */
+    public void testScaleUp_withWaitingModelAndAutoMemoryAndNoRoomInNodes() {
+        final long modelBytes = ByteSizeValue.ofGb(2).getBytes();
+        final long nodeBytes = ByteSizeValue.ofGb(1).getBytes();
+        when(mlMemoryTracker.getTrainedModelAllocationMemoryRequirement(any())).thenReturn(modelBytes);
+
+        MlScalingReason.Builder reasonBuilder = new MlScalingReason.Builder()
+            .setPassedConfiguration(Settings.EMPTY)
+            .setCurrentMlCapacity(AutoscalingCapacity.ZERO);
+
+        // The only node has every byte of its memory assigned already.
+        NodeLoad saturatedNode = NodeLoad.builder("any")
+            .setMaxMemory(nodeBytes)
+            .setUseMemory(true)
+            .incAssignedJobMemory(nodeBytes)
+            .build();
+        List<NodeLoad> fullyLoadedNode = Arrays.asList(saturatedNode);
+
+        MlAutoscalingDeciderService service = buildService();
+        service.setUseAuto(true);
+        Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
+            0,
+            0,
+            fullyLoadedNode,
+            Collections.emptyList(),
+            Collections.emptyList(),
+            List.of("foo"),
+            null,
+            NativeMemoryCapacity.ZERO,
+            reasonBuilder
+        );
+        assertTrue(decision.isPresent());
+        AutoscalingDeciderResult result = decision.get();
+
+        // NOTE(review): 30 and true are presumably the max machine memory percent and the
+        // use-auto flag - confirm against NativeMemoryCalculator.allowedBytesForMl.
+        long requestedNodeBytes = result.requiredCapacity().node().memory().getBytes();
+        long requestedTierBytes = result.requiredCapacity().total().memory().getBytes();
+        long allowedBytesForMlNode = NativeMemoryCalculator.allowedBytesForMl(requestedNodeBytes, 30, true);
+        long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(requestedTierBytes, 30, true);
+
+        assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(modelBytes + OVERHEAD));
+        assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(modelBytes + OVERHEAD));
+    }
+
+    /**
+     * Waiting trained models against two nodes that still have some free memory, with the max
+     * machine memory percent set to 25: three waiting models must yield a scale-up decision whose
+     * node and tier memory both equal four times one model plus {@code OVERHEAD}, while two
+     * waiting models must yield no decision at all.
+     */
+    public void testScaleUp_withWaitingModelsAndRoomInNodes() {
+        // Stub the model size in the test itself so the expected capacity below is tied to
+        // DEFAULT_MODEL_SIZE rather than to the job-size constant used by the shared setup.
+        when(mlMemoryTracker.getTrainedModelAllocationMemoryRequirement(any())).thenReturn(DEFAULT_MODEL_SIZE);
+        MlScalingReason.Builder reasonBuilder = new MlScalingReason.Builder().setPassedConfiguration(Settings.EMPTY)
+            .setCurrentMlCapacity(AutoscalingCapacity.ZERO);
+        List<NodeLoad> nodesWithRoom = Arrays.asList(
+            NodeLoad.builder("partially_filled")
+                .setMaxMemory(ByteSizeValue.ofMb(430).getBytes())
+                .setUseMemory(true)
+                .setMaxJobs(10)
+                .incNumAssignedJobs()
+                .incAssignedJobMemory(ByteSizeValue.ofMb(230).getBytes())
+                .build(),
+            NodeLoad.builder("not_filled").setMaxMemory(ByteSizeValue.ofMb(230).getBytes()).setMaxJobs(10).setUseMemory(true).build()
+        );
+        MlAutoscalingDeciderService service = buildService();
+        service.setMaxMachineMemoryPercent(25);
+        Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
+            0,
+            0,
+            nodesWithRoom,
+            Collections.emptyList(),
+            Collections.emptyList(),
+            List.of("foo", "bar", "baz"),
+            null,
+            NativeMemoryCapacity.ZERO,
+            reasonBuilder
+        );
+        assertTrue(decision.isPresent());
+        // NOTE(review): the factor 4 presumably follows from the 25 percent setting above - confirm.
+        final long expectedBytes = 4 * (DEFAULT_MODEL_SIZE + OVERHEAD);
+        assertThat(decision.get().requiredCapacity().node().memory().getBytes(), equalTo(expectedBytes));
+        assertThat(decision.get().requiredCapacity().total().memory().getBytes(), equalTo(expectedBytes));
+        // One waiting model fewer (and a different first queue argument): no scale-up is requested.
+        assertFalse(
+            service.checkForScaleUp(
+                1,
+                0,
+                nodesWithRoom,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                List.of("foo", "bar"),
+                null,
+                NativeMemoryCapacity.ZERO,
+                reasonBuilder
+            ).isPresent()
+        );
+    }
+
     public void testScaleDown() {
     public void testScaleDown() {
         List<NodeLoad> nodeLoads = Arrays.asList(
         List<NodeLoad> nodeLoads = Arrays.asList(
             NodeLoad.builder("foo").setMaxMemory(DEFAULT_NODE_SIZE).incAssignedJobMemory(ByteSizeValue.ofGb(1).getBytes()).build(),
             NodeLoad.builder("foo").setMaxMemory(DEFAULT_NODE_SIZE).incAssignedJobMemory(ByteSizeValue.ofGb(1).getBytes()).build(),

+ 1 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReasonTests.java

@@ -35,6 +35,7 @@ public class MlScalingReasonTests extends AbstractWireSerializingTestCase<MlScal
     @Override
     @Override
     protected MlScalingReason createTestInstance() {
     protected MlScalingReason createTestInstance() {
         return new MlScalingReason(
         return new MlScalingReason(
+            randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomConfiguration(),
             randomConfiguration(),