[ML] Make model snapshot upgrade autoscaling friendly (#81123)

Model snapshot upgrade tasks were not taking autoscaling into account
when doing node assignment, and the ML autoscaling decider was not
counting waiting model snapshot upgrades towards the required capacity.

Fixes #81012
David Roberts, 3 years ago
parent commit d031cfde1b
20 changed files with 526 additions and 129 deletions
  1. +22 -2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java
  2. +1 -1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeTaskParams.java
  3. +24 -0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java
  4. +3 -3
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java
  5. +1 -1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  6. +1 -5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpgradeJobModelSnapshotAction.java
  7. +82 -34
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java
  8. +20 -0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReason.java
  9. +48 -29
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/NativeMemoryCapacity.java
  10. +2 -1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java
  11. +1 -0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradePredicate.java
  12. +3 -8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java
  13. +39 -19
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java
  14. +2 -2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/task/AbstractJobPersistentTasksExecutor.java
  15. +1 -1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java
  16. +156 -13
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java
  17. +1 -0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReasonTests.java
  18. +6 -6
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/NativeMemoryCapacityTests.java
  19. +1 -0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradePredicateTests.java
  20. +112 -4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java
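
In short, the change teaches the ML autoscaling decider to count waiting model snapshot upgrade tasks when it works out how much ML native memory the tier needs, treating each upgrade as needing the same memory as the anomaly detection job it belongs to. A minimal standalone sketch of that idea, using hypothetical names rather than the real Elasticsearch classes:

// Illustrative sketch only (hypothetical names, not the real decider API).
import java.util.List;

final class WaitingWorkSketch {

    static long requiredTierBytes(List<Long> waitingAnomalyJobBytes,
                                  List<Long> waitingSnapshotUpgradeBytes,
                                  List<Long> waitingAnalyticsJobBytes) {
        long total = 0;
        for (long bytes : waitingAnomalyJobBytes) total += bytes;
        for (long bytes : waitingSnapshotUpgradeBytes) total += bytes; // previously ignored, now counted
        for (long bytes : waitingAnalyticsJobBytes) total += bytes;
        return total;
    }

    public static void main(String[] args) {
        long jobBytes = 200L * 1024 * 1024; // a 200MB anomaly detection job
        // The upgrade of that job's model snapshot is assumed to need the same 200MB.
        System.out.println(requiredTierBytes(List.of(jobBytes), List.of(jobBytes), List.of()));
    }
}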

+ 22 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -181,6 +181,27 @@ public final class MlTasks {
         return jobState;
     }
 
+    public static SnapshotUpgradeState getSnapshotUpgradeState(
+        String jobId,
+        String snapshotId,
+        @Nullable PersistentTasksCustomMetadata tasks
+    ) {
+        return getSnapshotUpgradeState(getSnapshotUpgraderTask(jobId, snapshotId, tasks));
+    }
+
+    public static SnapshotUpgradeState getSnapshotUpgradeState(@Nullable PersistentTasksCustomMetadata.PersistentTask<?> task) {
+        if (task == null) {
+            return SnapshotUpgradeState.STOPPED;
+        }
+        SnapshotUpgradeTaskState taskState = (SnapshotUpgradeTaskState) task.getState();
+        if (taskState == null) {
+            // If we haven't set a state yet then the task has never been assigned, so
+            // report that it's doing the first thing it does
+            return SnapshotUpgradeState.LOADING_OLD_STATE;
+        }
+        return taskState.getState();
+    }
+
     public static DatafeedState getDatafeedState(String datafeedId, @Nullable PersistentTasksCustomMetadata tasks) {
         PersistentTasksCustomMetadata.PersistentTask<?> task = getDatafeedTask(datafeedId, tasks);
         if (task == null) {
@@ -414,8 +435,7 @@ public final class MlTasks {
             case JOB_TASK_NAME:
                 return getJobStateModifiedForReassignments(task);
             case JOB_SNAPSHOT_UPGRADE_TASK_NAME:
-                SnapshotUpgradeTaskState taskState = (SnapshotUpgradeTaskState) task.getState();
-                return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
+                return getSnapshotUpgradeState(task);
             case DATA_FRAME_ANALYTICS_TASK_NAME:
                 return getDataFrameAnalyticsState(task);
             default:

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskParams.java → x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeTaskParams.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.ml.job.snapshot.upgrader;
+package org.elasticsearch.xpack.core.ml.job.snapshot.upgrade;
 
 import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;

+ 24 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

@@ -21,6 +21,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 
 import java.net.InetAddress;
 
@@ -71,6 +74,27 @@ public class MlTasksTests extends ESTestCase {
         assertEquals(DatafeedState.STARTED, MlTasks.getDatafeedState("foo", tasksBuilder.build()));
     }
 
+    public void testGetSnapshotUpgradeState() {
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        // A missing task is a stopped snapshot upgrade
+        assertEquals(SnapshotUpgradeState.STOPPED, MlTasks.getSnapshotUpgradeState("foo", "1", tasksBuilder.build()));
+
+        tasksBuilder.addTask(
+            MlTasks.snapshotUpgradeTaskId("foo", "1"),
+            MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
+            new SnapshotUpgradeTaskParams("foo", "1"),
+            new PersistentTasksCustomMetadata.Assignment("bar", "test assignment")
+        );
+        // A task with no state means the snapshot upgrade is starting
+        assertEquals(SnapshotUpgradeState.LOADING_OLD_STATE, MlTasks.getSnapshotUpgradeState("foo", "1", tasksBuilder.build()));
+
+        tasksBuilder.updateTaskState(
+            MlTasks.snapshotUpgradeTaskId("foo", "1"),
+            new SnapshotUpgradeTaskState(SnapshotUpgradeState.SAVING_NEW_STATE, tasksBuilder.getLastAllocationId(), null)
+        );
+        assertEquals(SnapshotUpgradeState.SAVING_NEW_STATE, MlTasks.getSnapshotUpgradeState("foo", "1", tasksBuilder.build()));
+    }
+
     public void testGetJobTask() {
         assertNull(MlTasks.getJobTask("foo", null));
 

+ 3 - 3
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java

@@ -137,8 +137,8 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
             .collect(Collectors.toList());
         NativeMemoryCapacity currentScale = MlAutoscalingDeciderService.currentScale(mlNodes, 30, false);
         expectedTierBytes = (long) Math.ceil(
-            (ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB + 60_000 + BASELINE_OVERHEAD_MB).getBytes() + currentScale.getTier()) * 100
-                / 30.0
+            (ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB + 60_000 + BASELINE_OVERHEAD_MB).getBytes() + currentScale
+                .getTierMlNativeMemoryRequirement()) * 100 / 30.0
         );
         expectedNodeBytes = (long) (ByteSizeValue.ofMb(60_000 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0);
 
@@ -215,7 +215,7 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
             .collect(Collectors.toList());
         NativeMemoryCapacity currentScale = MlAutoscalingDeciderService.currentScale(mlNodes, 30, false);
         expectedTierBytes = (long) Math.ceil(
-            (ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB).getBytes() + currentScale.getTier()) * 100 / 30.0
+            (ByteSizeValue.ofMb(50_000 + BASIC_REQUIREMENT_MB).getBytes() + currentScale.getTierMlNativeMemoryRequirement()) * 100 / 30.0
         );
         expectedNodeBytes = (long) (ByteSizeValue.ofMb(50_000 + BASELINE_OVERHEAD_MB).getBytes() * 100 / 30.0);
 

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -178,6 +178,7 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvide
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.template.TemplateUtils;
@@ -324,7 +325,6 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NativeNormalizerProcess
 import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
 import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
 import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
-import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
 import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
 import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

+ 1 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpgradeJobModelSnapshotAction.java

@@ -45,19 +45,15 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.core.ml.job.results.Result;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
 import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradePredicate;
-import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 
 public class TransportUpgradeJobModelSnapshotAction extends TransportMasterNodeAction<Request, Response> {
 
-    // If the snapshot is from any version other than the current major, we consider it for upgrade.
-    // This is to support upgrading to the NEXT major without worry
-    private static final byte UPGRADE_FROM_MAJOR = Version.CURRENT.major;
-
     private static final Logger logger = LogManager.getLogger(TransportUpgradeJobModelSnapshotAction.class);
 
     private final XPackLicenseState licenseState;

+ 82 - 34
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java

@@ -14,7 +14,6 @@ import org.elasticsearch.cluster.LocalNodeMasterListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
@@ -29,12 +28,16 @@ import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderContext;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderService;
 import org.elasticsearch.xpack.core.ml.MlTasks;
+import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
+import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction.DatafeedParams;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
@@ -65,6 +68,7 @@ import java.util.stream.Stream;
 
 import static org.elasticsearch.xpack.core.ml.MlTasks.getDataFrameAnalyticsState;
 import static org.elasticsearch.xpack.core.ml.MlTasks.getJobStateModifiedForReassignments;
+import static org.elasticsearch.xpack.core.ml.MlTasks.getSnapshotUpgradeState;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE;
 import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD;
 import static org.elasticsearch.xpack.ml.job.JobNodeSelector.AWAITING_LAZY_ASSIGNMENT;
@@ -88,7 +92,6 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
     private final LongSupplier timeSupplier;
 
     private volatile boolean isMaster;
-    private volatile boolean running;
     private volatile int maxMachineMemoryPercent;
     private volatile int maxOpenJobs;
     private volatile boolean useAuto;
@@ -117,17 +120,6 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto);
         clusterService.addLocalNodeMasterListener(this);
-        clusterService.addLifecycleListener(new LifecycleListener() {
-            @Override
-            public void afterStart() {
-                running = true;
-            }
-
-            @Override
-            public void beforeStop() {
-                running = false;
-            }
-        });
     }
 
     static OptionalLong getNodeJvmSize(DiscoveryNode node) {
@@ -197,7 +189,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         }
         PriorityQueue<NodeLoad.Builder> mostFreeMemoryFirst = new PriorityQueue<>(
             nodeLoads.size(),
-            // If we have no more remaining jobs, its the same as having no more free memory
+            // If we have no more remaining jobs, it's the same as having no more free memory
             Comparator.<NodeLoad.Builder>comparingLong(v -> v.remainingJobs() == 0 ? 0L : v.getFreeMemory()).reversed()
         );
         for (NodeLoad load : nodeLoads) {
@@ -258,6 +250,14 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         return tasksCustomMetadata.findTasks(MlTasks.JOB_TASK_NAME, t -> taskStateFilter(getJobStateModifiedForReassignments(t)));
     }
 
+    private static Collection<PersistentTask<?>> snapshotUpgradeTasks(PersistentTasksCustomMetadata tasksCustomMetadata) {
+        if (tasksCustomMetadata == null) {
+            return Collections.emptyList();
+        }
+
+        return tasksCustomMetadata.findTasks(MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, t -> taskStateFilter(getSnapshotUpgradeState(t)));
+    }
+
     private static Collection<PersistentTask<?>> dataframeAnalyticsTasks(PersistentTasksCustomMetadata tasksCustomMetadata) {
         if (tasksCustomMetadata == null) {
             return Collections.emptyList();
@@ -352,15 +352,20 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
 
         PersistentTasksCustomMetadata tasks = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
         Collection<PersistentTask<?>> anomalyDetectionTasks = anomalyDetectionTasks(tasks);
+        Collection<PersistentTask<?>> snapshotUpgradeTasks = snapshotUpgradeTasks(tasks);
         Collection<PersistentTask<?>> dataframeAnalyticsTasks = dataframeAnalyticsTasks(tasks);
         Map<String, TrainedModelAllocation> modelAllocations = TrainedModelAllocationMetadata.fromState(clusterState).modelAllocations();
         final List<String> waitingAnomalyJobs = anomalyDetectionTasks.stream()
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
-            .map(t -> MlTasks.jobId(t.getId()))
+            .map(t -> ((OpenJobAction.JobParams) t.getParams()).getJobId())
+            .collect(Collectors.toList());
+        final List<String> waitingSnapshotUpgrades = snapshotUpgradeTasks.stream()
+            .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
+            .map(t -> ((SnapshotUpgradeTaskParams) t.getParams()).getJobId())
             .collect(Collectors.toList());
         final List<String> waitingAnalyticsJobs = dataframeAnalyticsTasks.stream()
             .filter(t -> AWAITING_LAZY_ASSIGNMENT.equals(t.getAssignment()))
-            .map(t -> MlTasks.dataFrameAnalyticsId(t.getId()))
+            .map(t -> ((StartDataFrameAnalyticsAction.TaskParams) t.getParams()).getId())
             .collect(Collectors.toList());
         final List<String> waitingAllocatedModels = modelAllocations.entrySet()
             .stream()
@@ -377,6 +382,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
 
         final MlScalingReason.Builder reasonBuilder = MlScalingReason.builder()
             .setWaitingAnomalyJobs(waitingAnomalyJobs)
+            .setWaitingSnapshotUpgrades(waitingSnapshotUpgrades)
             .setWaitingAnalyticsJobs(waitingAnalyticsJobs)
             .setWaitingModels(waitingAllocatedModels)
             .setCurrentMlCapacity(currentScale.autoscalingCapacity(maxMachineMemoryPercent, useAuto))
@@ -385,9 +391,16 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         // There are no ML nodes, scale up as quick as possible, no matter if memory is stale or not
         if (nodes.isEmpty()
             && (waitingAnomalyJobs.isEmpty() == false
+                || waitingSnapshotUpgrades.isEmpty() == false
                 || waitingAnalyticsJobs.isEmpty() == false
                 || waitingAllocatedModels.isEmpty() == false)) {
-            return scaleUpFromZero(waitingAnomalyJobs, waitingAnalyticsJobs, waitingAllocatedModels, reasonBuilder);
+            return scaleUpFromZero(
+                waitingAnomalyJobs,
+                waitingSnapshotUpgrades,
+                waitingAnalyticsJobs,
+                waitingAllocatedModels,
+                reasonBuilder
+            );
         }
 
         // We don't need to check anything as there are no tasks
@@ -477,6 +490,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
             numAnalyticsJobsInQueue,
             nodeLoads,
             waitingAnomalyJobs,
+            waitingSnapshotUpgrades,
             waitingAnalyticsJobs,
             waitingAllocatedModels,
             futureFreedCapacity.orElse(null),
@@ -488,7 +502,9 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
             resetScaleDownCoolDown();
             return scaleUpDecision.get();
         }
-        if (waitingAnalyticsJobs.isEmpty() == false || waitingAnomalyJobs.isEmpty() == false) {
+        if (waitingAnalyticsJobs.isEmpty() == false
+            || waitingSnapshotUpgrades.isEmpty() == false
+            || waitingAnomalyJobs.isEmpty() == false) {
             // We don't want to continue to consider a scale down if there are now waiting jobs
             resetScaleDownCoolDown();
             return noScaleResultOrRefresh(
@@ -499,9 +515,11 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
                     reasonBuilder.setSimpleReason(
                         String.format(
                             Locale.ROOT,
-                            "Passing currently perceived capacity as there are [%d] analytics and [%d] anomaly jobs in the queue, "
+                            "Passing currently perceived capacity as there are [%d] model snapshot upgrades, "
+                                + "[%d] analytics and [%d] anomaly detection jobs in the queue, "
                                 + "but the number in the queue is less than the configured maximum allowed "
                                 + " or the queued jobs will eventually be assignable at the current size. ",
+                            waitingSnapshotUpgrades.size(),
                             waitingAnalyticsJobs.size(),
                             waitingAnomalyJobs.size()
                         )
@@ -690,6 +708,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
     // can eventually start, and given the current cluster, no job can eventually start.
     AutoscalingDeciderResult scaleUpFromZero(
         List<String> waitingAnomalyJobs,
+        List<String> waitingSnapshotUpgrades,
         List<String> waitingAnalyticsJobs,
         List<String> waitingAllocatedModels,
         MlScalingReason.Builder reasonBuilder
@@ -704,17 +723,23 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
             this::getAnomalyMemoryRequirement,
             0
         );
+        final Optional<NativeMemoryCapacity> snapshotUpgradeCapacity = requiredCapacityForUnassignedJobs(
+            waitingSnapshotUpgrades,
+            this::getAnomalyMemoryRequirement,
+            0
+        );
         final Optional<NativeMemoryCapacity> allocatedModelCapacity = requiredCapacityForUnassignedJobs(
             waitingAllocatedModels,
             this::getAllocatedModelRequirement,
             0
         );
         NativeMemoryCapacity updatedCapacity = NativeMemoryCapacity.ZERO.merge(anomalyCapacity.orElse(NativeMemoryCapacity.ZERO))
+            .merge(snapshotUpgradeCapacity.orElse(NativeMemoryCapacity.ZERO))
             .merge(analyticsCapacity.orElse(NativeMemoryCapacity.ZERO))
             .merge(allocatedModelCapacity.orElse(NativeMemoryCapacity.ZERO));
         // If we still have calculated zero, this means the ml memory tracker does not have the required info.
         // So, request a scale for the default. This is only for the 0 -> N scaling case.
-        if (updatedCapacity.getNode() == 0L) {
+        if (updatedCapacity.getNodeMlNativeMemoryRequirement() == 0L) {
             updatedCapacity.merge(
                 new NativeMemoryCapacity(
                     ByteSizeValue.ofMb(AnalysisLimits.DEFAULT_MODEL_MEMORY_LIMIT_MB).getBytes(),
@@ -744,6 +769,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         int numAnalyticsJobsInQueue,
         List<NodeLoad> nodeLoads,
         List<String> waitingAnomalyJobs,
+        List<String> waitingSnapshotUpgrades,
         List<String> waitingAnalyticsJobs,
         List<String> waitingAllocatedModels,
         @Nullable NativeMemoryCapacity futureFreedCapacity,
@@ -753,11 +779,11 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
 
         // Are we in breach of maximum waiting jobs?
         if (waitingAnalyticsJobs.size() > numAnalyticsJobsInQueue
-            || waitingAnomalyJobs.size() > numAnomalyJobsInQueue
+            || waitingAnomalyJobs.size() + waitingSnapshotUpgrades.size() > numAnomalyJobsInQueue
             || waitingAllocatedModels.size() > 0) {
 
             Tuple<NativeMemoryCapacity, List<NodeLoad>> anomalyCapacityAndNewLoad = determineUnassignableJobs(
-                waitingAnomalyJobs,
+                Stream.concat(waitingAnomalyJobs.stream(), waitingSnapshotUpgrades.stream()).collect(Collectors.toList()),
                 this::getAnomalyMemoryRequirement,
                 numAnomalyJobsInQueue,
                 nodeLoads
@@ -808,16 +834,21 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
 
         // Could the currently waiting jobs ever be assigned?
         // NOTE: the previous predicate catches if an allocated model isn't assigned
-        if (waitingAnalyticsJobs.isEmpty() == false || waitingAnomalyJobs.isEmpty() == false) {
+        if (waitingAnalyticsJobs.isEmpty() == false
+            || waitingSnapshotUpgrades.isEmpty() == false
+            || waitingAnomalyJobs.isEmpty() == false) {
             // we are unable to determine new tier size, but maybe we can see if our nodes are big enough.
             if (futureFreedCapacity == null) {
                 Optional<Long> maxSize = Stream.concat(
                     waitingAnalyticsJobs.stream().map(mlMemoryTracker::getDataFrameAnalyticsJobMemoryRequirement),
-                    waitingAnomalyJobs.stream().map(mlMemoryTracker::getAnomalyDetectorJobMemoryRequirement)
+                    Stream.concat(
+                        waitingAnomalyJobs.stream().map(mlMemoryTracker::getAnomalyDetectorJobMemoryRequirement),
+                        waitingSnapshotUpgrades.stream().map(mlMemoryTracker::getAnomalyDetectorJobMemoryRequirement)
+                    )
                 ).filter(Objects::nonNull).max(Long::compareTo);
-                if (maxSize.isPresent() && maxSize.get() > currentScale.getNode()) {
+                if (maxSize.isPresent() && maxSize.get() > currentScale.getNodeMlNativeMemoryRequirement()) {
                     AutoscalingCapacity requiredCapacity = new NativeMemoryCapacity(
-                        Math.max(currentScale.getTier(), maxSize.get()),
+                        Math.max(currentScale.getTierMlNativeMemoryRequirement(), maxSize.get()),
                         maxSize.get()
                     ).autoscalingCapacity(maxMachineMemoryPercent, useAuto);
                     return Optional.of(
@@ -834,7 +865,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
             }
             long newTierNeeded = 0L;
             // could any of the nodes actually run the job?
-            long newNodeMax = currentScale.getNode();
+            long newNodeMax = currentScale.getNodeMlNativeMemoryRequirement();
             for (String analyticsJob : waitingAnalyticsJobs) {
                 Long requiredMemory = mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(analyticsJob);
                 // it is OK to continue here as we have not breached our queuing limit
@@ -842,7 +873,7 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
                     continue;
                 }
                 // Is there "future capacity" on a node that could run this job? If not, we need that much more in the tier.
-                if (futureFreedCapacity.getNode() < requiredMemory) {
+                if (futureFreedCapacity.getNodeMlNativeMemoryRequirement() < requiredMemory) {
                     newTierNeeded = Math.max(requiredMemory, newTierNeeded);
                 }
                 newNodeMax = Math.max(newNodeMax, requiredMemory);
@@ -854,12 +885,24 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
                     continue;
                 }
                 // Is there "future capacity" on a node that could run this job? If not, we need that much more in the tier.
-                if (futureFreedCapacity.getNode() < requiredMemory) {
+                if (futureFreedCapacity.getNodeMlNativeMemoryRequirement() < requiredMemory) {
                     newTierNeeded = Math.max(requiredMemory, newTierNeeded);
                 }
                 newNodeMax = Math.max(newNodeMax, requiredMemory);
             }
-            if (newNodeMax > currentScale.getNode() || newTierNeeded > 0L) {
+            for (String snapshotUpgrade : waitingSnapshotUpgrades) {
+                Long requiredMemory = mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(snapshotUpgrade);
+                // it is OK to continue here as we have not breached our queuing limit
+                if (requiredMemory == null) {
+                    continue;
+                }
+                // Is there "future capacity" on a node that could run this job? If not, we need that much more in the tier.
+                if (futureFreedCapacity.getNodeMlNativeMemoryRequirement() < requiredMemory) {
+                    newTierNeeded = Math.max(requiredMemory, newTierNeeded);
+                }
+                newNodeMax = Math.max(newNodeMax, requiredMemory);
+            }
+            if (newNodeMax > currentScale.getNodeMlNativeMemoryRequirement() || newTierNeeded > 0L) {
                 NativeMemoryCapacity newCapacity = new NativeMemoryCapacity(newTierNeeded, newNodeMax);
                 AutoscalingCapacity requiredCapacity = NativeMemoryCapacity.from(currentScale)
                     .merge(newCapacity)
@@ -983,15 +1026,16 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         }
         // We consider a scale down if we are not fully utilizing the tier
         // Or our largest job could be on a smaller node (meaning the same size tier but smaller nodes are possible).
-        if (currentlyNecessaryTier < currentCapacity.getTier() || currentlyNecessaryNode < currentCapacity.getNode()) {
+        if (currentlyNecessaryTier < currentCapacity.getTierMlNativeMemoryRequirement()
+            || currentlyNecessaryNode < currentCapacity.getNodeMlNativeMemoryRequirement()) {
             NativeMemoryCapacity nativeMemoryCapacity = new NativeMemoryCapacity(
                 // Since we are in the `scaleDown` branch, we know jobs are running and we could be smaller
                 // If we have some weird rounding errors, it may be that the `currentlyNecessary` values are larger than
                 // current capacity. We never want to accidentally say "scale up" via a scale down.
-                Math.min(currentlyNecessaryTier, currentCapacity.getTier()),
-                Math.min(currentlyNecessaryNode, currentCapacity.getNode()),
+                Math.min(currentlyNecessaryTier, currentCapacity.getTierMlNativeMemoryRequirement()),
+                Math.min(currentlyNecessaryNode, currentCapacity.getNodeMlNativeMemoryRequirement()),
                 // If our newly suggested native capacity is the same, we can use the previously stored jvm size
-                currentlyNecessaryNode == currentCapacity.getNode() ? currentCapacity.getJvmSize() : null
+                currentlyNecessaryNode == currentCapacity.getNodeMlNativeMemoryRequirement() ? currentCapacity.getJvmSize() : null
             );
             AutoscalingCapacity requiredCapacity = nativeMemoryCapacity.autoscalingCapacity(maxMachineMemoryPercent, useAuto);
             return Optional.of(
@@ -1035,6 +1079,10 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService, L
         return jobState == null || jobState.isNoneOf(JobState.CLOSED, JobState.FAILED);
     }
 
+    private static boolean taskStateFilter(SnapshotUpgradeState snapshotUpgradeState) {
+        return snapshotUpgradeState == null || snapshotUpgradeState.isNoneOf(SnapshotUpgradeState.STOPPED, SnapshotUpgradeState.FAILED);
+    }
+
     private static boolean taskStateFilter(DataFrameAnalyticsState dataFrameAnalyticsState) {
         // Don't count stopped and failed df-analytics tasks as they don't consume native memory
         return dataFrameAnalyticsState == null
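
The behavioural core of the decider changes above is that waiting model snapshot upgrades now count against the anomaly detection queue limit and reuse the anomaly detection job's memory requirement. A hedged standalone sketch of that check (hypothetical helper, not the real decider method signature):

// Illustrative sketch only; the real logic lives in checkForScaleUp above.
import java.util.List;

final class QueueBreachSketch {
    static boolean breachesQueueLimits(int numAnomalyJobsInQueue,
                                       int numAnalyticsJobsInQueue,
                                       List<String> waitingAnomalyJobs,
                                       List<String> waitingSnapshotUpgrades,
                                       List<String> waitingAnalyticsJobs,
                                       List<String> waitingAllocatedModels) {
        return waitingAnalyticsJobs.size() > numAnalyticsJobsInQueue
            // Snapshot upgrades share the anomaly detection queue, as they run the same native process.
            || waitingAnomalyJobs.size() + waitingSnapshotUpgrades.size() > numAnomalyJobsInQueue
            || waitingAllocatedModels.size() > 0;
    }

    public static void main(String[] args) {
        // One anomaly job plus one snapshot upgrade breaches a queue limit of 1, so a scale up is considered.
        System.out.println(breachesQueueLimits(1, 1, List.of("job-1"), List.of("job-1"), List.of(), List.of()));
    }
}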

+ 20 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReason.java

@@ -36,6 +36,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
 
     private final List<String> waitingAnalyticsJobs;
     private final List<String> waitingAnomalyJobs;
+    private final List<String> waitingSnapshotUpgrades;
     private final List<String> waitingModels;
     private final Settings passedConfiguration;
     private final Long largestWaitingAnalyticsJob;
@@ -47,6 +48,12 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     public MlScalingReason(StreamInput in) throws IOException {
         this.waitingAnalyticsJobs = in.readStringList();
         this.waitingAnomalyJobs = in.readStringList();
+        // TODO: change on backport
+        if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
+            this.waitingSnapshotUpgrades = in.readStringList();
+        } else {
+            this.waitingSnapshotUpgrades = List.of();
+        }
         if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
             this.waitingModels = in.readStringList();
         } else {
@@ -63,6 +70,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     MlScalingReason(
         List<String> waitingAnalyticsJobs,
         List<String> waitingAnomalyJobs,
+        List<String> waitingSnapshotUpgrades,
         List<String> waitingModels,
         Settings passedConfiguration,
         Long largestWaitingAnalyticsJob,
@@ -73,6 +81,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     ) {
         this.waitingAnalyticsJobs = waitingAnalyticsJobs == null ? Collections.emptyList() : waitingAnalyticsJobs;
         this.waitingAnomalyJobs = waitingAnomalyJobs == null ? Collections.emptyList() : waitingAnomalyJobs;
+        this.waitingSnapshotUpgrades = waitingSnapshotUpgrades == null ? Collections.emptyList() : waitingSnapshotUpgrades;
         this.waitingModels = waitingModels == null ? List.of() : waitingModels;
         this.passedConfiguration = ExceptionsHelper.requireNonNull(passedConfiguration, CONFIGURATION);
         this.largestWaitingAnalyticsJob = largestWaitingAnalyticsJob;
@@ -131,6 +140,10 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeStringCollection(this.waitingAnalyticsJobs);
         out.writeStringCollection(this.waitingAnomalyJobs);
+        // TODO: change version on backport
+        if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
+            out.writeStringCollection(this.waitingSnapshotUpgrades);
+        }
         if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
             out.writeStringCollection(this.waitingModels);
         }
@@ -172,6 +185,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
     static class Builder {
         private List<String> waitingAnalyticsJobs = Collections.emptyList();
         private List<String> waitingAnomalyJobs = Collections.emptyList();
+        private List<String> waitingSnapshotUpgrades = Collections.emptyList();
         private List<String> waitingModels = Collections.emptyList();
         private Settings passedConfiguration;
         private Long largestWaitingAnalyticsJob;
@@ -190,6 +204,11 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
             return this;
         }
 
+        public Builder setWaitingSnapshotUpgrades(List<String> waitingSnapshotUpgrades) {
+            this.waitingSnapshotUpgrades = waitingSnapshotUpgrades;
+            return this;
+        }
+
         public Builder setWaitingModels(List<String> waitingModels) {
             this.waitingModels = waitingModels;
             return this;
@@ -229,6 +248,7 @@ public class MlScalingReason implements AutoscalingDeciderResult.Reason {
             return new MlScalingReason(
                 waitingAnalyticsJobs,
                 waitingAnomalyJobs,
+                waitingSnapshotUpgrades,
                 waitingModels,
                 passedConfiguration,
                 largestWaitingAnalyticsJob,
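
The new waitingSnapshotUpgrades field follows the usual wire-compatibility pattern seen above: it is only written and read when the remote side is on a version that knows about it (gated on Version.V_8_1_0 until the backport changes it), with an empty list as the default otherwise. A minimal generic sketch of that pattern, assuming a hypothetical transport version exchanged out of band; this is not the Elasticsearch StreamInput/StreamOutput API:

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

final class VersionGatedFieldSketch {
    static final int V_WITH_SNAPSHOT_UPGRADES = 8_010_099; // stand-in for the real version constant

    static void write(DataOutputStream out, int remoteVersion, List<String> waitingSnapshotUpgrades) throws IOException {
        if (remoteVersion >= V_WITH_SNAPSHOT_UPGRADES) {
            out.writeInt(waitingSnapshotUpgrades.size());
            for (String s : waitingSnapshotUpgrades) out.writeUTF(s);
        } // older receivers simply never see the field
    }

    static List<String> read(DataInputStream in, int remoteVersion) throws IOException {
        if (remoteVersion < V_WITH_SNAPSHOT_UPGRADES) {
            return List.of(); // default when talking to an older sender
        }
        int n = in.readInt();
        List<String> result = new ArrayList<>(n);
        for (int i = 0; i < n; i++) result.add(in.readUTF());
        return result;
    }
}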

+ 48 - 29
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/NativeMemoryCapacity.java

@@ -23,28 +23,28 @@ public class NativeMemoryCapacity {
     static final NativeMemoryCapacity ZERO = new NativeMemoryCapacity(0L, 0L);
 
     static NativeMemoryCapacity from(NativeMemoryCapacity capacity) {
-        return new NativeMemoryCapacity(capacity.tier, capacity.node, capacity.jvmSize);
+        return new NativeMemoryCapacity(capacity.tierMlNativeMemoryRequirement, capacity.nodeMlNativeMemoryRequirement, capacity.jvmSize);
     }
 
-    private long tier;
-    private long node;
+    private long tierMlNativeMemoryRequirement;
+    private long nodeMlNativeMemoryRequirement;
     private Long jvmSize;
 
-    public NativeMemoryCapacity(long tier, long node, Long jvmSize) {
-        this.tier = tier;
-        this.node = node;
+    public NativeMemoryCapacity(long tierMlNativeMemoryRequirement, long nodeMlNativeMemoryRequirement, Long jvmSize) {
+        this.tierMlNativeMemoryRequirement = tierMlNativeMemoryRequirement;
+        this.nodeMlNativeMemoryRequirement = nodeMlNativeMemoryRequirement;
         this.jvmSize = jvmSize;
     }
 
-    NativeMemoryCapacity(long tier, long node) {
-        this.tier = tier;
-        this.node = node;
+    NativeMemoryCapacity(long tierMlNativeMemoryRequirement, long nodeMlNativeMemoryRequirement) {
+        this.tierMlNativeMemoryRequirement = tierMlNativeMemoryRequirement;
+        this.nodeMlNativeMemoryRequirement = nodeMlNativeMemoryRequirement;
     }
 
     NativeMemoryCapacity merge(NativeMemoryCapacity nativeMemoryCapacity) {
-        this.tier += nativeMemoryCapacity.tier;
-        if (nativeMemoryCapacity.node > this.node) {
-            this.node = nativeMemoryCapacity.node;
+        this.tierMlNativeMemoryRequirement += nativeMemoryCapacity.tierMlNativeMemoryRequirement;
+        if (nativeMemoryCapacity.nodeMlNativeMemoryRequirement > this.nodeMlNativeMemoryRequirement) {
+            this.nodeMlNativeMemoryRequirement = nativeMemoryCapacity.nodeMlNativeMemoryRequirement;
             // If the new node size is bigger, we have no way of knowing if the JVM size would stay the same
             // So null out
             this.jvmSize = null;
@@ -55,32 +55,49 @@ public class NativeMemoryCapacity {
     public AutoscalingCapacity autoscalingCapacity(int maxMemoryPercent, boolean useAuto) {
         // We calculate the JVM size here first to ensure it stays the same given the rest of the calculations
         final Long jvmSize = useAuto
-            ? Optional.ofNullable(this.jvmSize).orElse(dynamicallyCalculateJvmSizeFromNativeMemorySize(node))
+            ? Optional.ofNullable(this.jvmSize).orElse(dynamicallyCalculateJvmSizeFromNativeMemorySize(nodeMlNativeMemoryRequirement))
             : null;
-        // We first need to calculate the actual node size given the current native memory size.
+        // We first need to calculate the required node size given the required native ML memory size.
         // This way we can accurately determine the required node size AND what the overall memory percentage will be
-        long actualNodeSize = NativeMemoryCalculator.calculateApproxNecessaryNodeSize(node, jvmSize, maxMemoryPercent, useAuto);
+        long requiredNodeSize = NativeMemoryCalculator.calculateApproxNecessaryNodeSize(
+            nodeMlNativeMemoryRequirement,
+            jvmSize,
+            maxMemoryPercent,
+            useAuto
+        );
         // We make the assumption that the JVM size is the same across the entire tier
         // This simplifies calculating the tier as it means that each node in the tier
         // will have the same dynamic memory calculation. And thus the tier is simply the sum of the memory necessary
         // times that scaling factor.
-        double memoryPercentForMl = NativeMemoryCalculator.modelMemoryPercent(actualNodeSize, jvmSize, maxMemoryPercent, useAuto);
+        //
+        // Since this is a _minimum_ node size, the memory percent calculated here is not
+        // necessarily what "auto" will imply after scaling. Because the JVM occupies a
+        // smaller proportion of memory the bigger the node, the memory percent might be
+        // higher than we calculate here, potentially resulting in a bigger ML tier than
+        // required. The effect is most pronounced when the minimum node size is small and
+        // the minimum tier size is a lot bigger. For example, if the minimum node size is
+        // 1GB and the total native ML memory requirement for the tier is 32GB then the memory
+        // percent will be 41%, implying a minimum ML tier size of 78GB. But in reality a
+        // single 64GB ML node would have an auto memory percent of 90%, and 90% of 64GB is
+        // plenty big enough for the 32GB of ML native memory required.
+        // TODO: improve this in the next refactoring
+        double memoryPercentForMl = NativeMemoryCalculator.modelMemoryPercent(requiredNodeSize, jvmSize, maxMemoryPercent, useAuto);
         double inverseScale = memoryPercentForMl <= 0 ? 0 : 100.0 / memoryPercentForMl;
-        long actualTier = Math.round(tier * inverseScale);
+        long requiredTierSize = Math.round(Math.ceil(tierMlNativeMemoryRequirement * inverseScale));
         return new AutoscalingCapacity(
             // Tier should always be AT LEAST the largest node size.
             // This Math.max catches any strange rounding errors or weird input.
-            new AutoscalingCapacity.AutoscalingResources(null, ByteSizeValue.ofBytes(Math.max(actualTier, actualNodeSize))),
-            new AutoscalingCapacity.AutoscalingResources(null, ByteSizeValue.ofBytes(actualNodeSize))
+            new AutoscalingCapacity.AutoscalingResources(null, ByteSizeValue.ofBytes(Math.max(requiredTierSize, requiredNodeSize))),
+            new AutoscalingCapacity.AutoscalingResources(null, ByteSizeValue.ofBytes(requiredNodeSize))
         );
     }
 
-    public long getTier() {
-        return tier;
+    public long getTierMlNativeMemoryRequirement() {
+        return tierMlNativeMemoryRequirement;
     }
 
-    public long getNode() {
-        return node;
+    public long getNodeMlNativeMemoryRequirement() {
+        return nodeMlNativeMemoryRequirement;
     }
 
     public Long getJvmSize() {
@@ -90,10 +107,10 @@ public class NativeMemoryCapacity {
     @Override
     public String toString() {
         return "NativeMemoryCapacity{"
-            + "total bytes="
-            + ByteSizeValue.ofBytes(tier)
-            + ", largest node bytes="
-            + ByteSizeValue.ofBytes(node)
+            + "total ML native bytes="
+            + ByteSizeValue.ofBytes(tierMlNativeMemoryRequirement)
+            + ", largest node ML native bytes="
+            + ByteSizeValue.ofBytes(nodeMlNativeMemoryRequirement)
             + '}';
     }
 
@@ -102,11 +119,13 @@ public class NativeMemoryCapacity {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         NativeMemoryCapacity that = (NativeMemoryCapacity) o;
-        return tier == that.tier && node == that.node && Objects.equals(jvmSize, that.jvmSize);
+        return tierMlNativeMemoryRequirement == that.tierMlNativeMemoryRequirement
+            && nodeMlNativeMemoryRequirement == that.nodeMlNativeMemoryRequirement
+            && Objects.equals(jvmSize, that.jvmSize);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(tier, node, jvmSize);
+        return Objects.hash(tierMlNativeMemoryRequirement, nodeMlNativeMemoryRequirement, jvmSize);
     }
 }
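
As a worked example of the inverse-scale arithmetic described in the long comment above, using the same illustrative numbers (1GB minimum node size, 32GB of ML native memory required for the tier, roughly 41% of the small node usable for ML):

// Standalone sketch mirroring the comment's example; the numbers are illustrative only.
final class InverseScaleExample {
    public static void main(String[] args) {
        long tierMlNativeMemoryRequirement = 32L * 1024 * 1024 * 1024; // 32GB of ML native memory needed
        double memoryPercentForMl = 41.0;                              // derived from the minimum (1GB) node size
        double inverseScale = 100.0 / memoryPercentForMl;
        long requiredTierSize = Math.round(Math.ceil(tierMlNativeMemoryRequirement * inverseScale));
        // Prints roughly 78GB, even though a single 64GB node (auto memory percent ~90%) would be
        // big enough for the 32GB of ML native memory, which is the overshoot the comment flags
        // for a future refactoring.
        System.out.println("required tier bytes: " + requiredTierSize);
    }
}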

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java

@@ -99,7 +99,8 @@ public class NodeLoadDetector {
             );
             for (PersistentTasksCustomMetadata.PersistentTask<?> task : memoryTrackedTasks) {
                 MemoryTrackedTaskState state = MlTasks.getMemoryTrackedTaskState(task);
-                if (state == null || state.consumesMemory()) {
+                assert state != null : "null MemoryTrackedTaskState for memory tracked task with params " + task.getParams();
+                if (state != null && state.consumesMemory()) {
                     MlTaskParams taskParams = (MlTaskParams) task.getParams();
                     nodeLoad.addTask(task.getTaskName(), taskParams.getMlId(), state.isAllocating(), mlMemoryTracker);
                 }

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradePredicate.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 

+ 3 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskExecutor.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.core.ml.job.results.Result;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
@@ -107,16 +108,10 @@ public class SnapshotUpgradeTaskExecutor extends AbstractJobPersistentTasksExecu
             // Use the job_task_name for the appropriate job size
             MlTasks.JOB_TASK_NAME,
             memoryTracker,
-            0,
+            maxLazyMLNodes,
             node -> null
         );
-        return jobNodeSelector.selectNode(
-            Integer.MAX_VALUE,
-            Integer.MAX_VALUE,
-            maxMachineMemoryPercent,
-            Long.MAX_VALUE,
-            useAutoMemoryPercentage
-        );
+        return jobNodeSelector.selectNode(maxOpenJobs, Integer.MAX_VALUE, maxMachineMemoryPercent, maxNodeMemory, useAutoMemoryPercentage);
     }
 
     @Override

+ 39 - 19
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java

@@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
@@ -49,17 +50,18 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Phaser;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * This class keeps track of the memory requirement of ML jobs.
  * It only functions on the master node - for this reason it should only be used by master node actions.
  * The memory requirement for ML jobs can be updated in 4 ways:
- * 1. For all open ML data frame analytics jobs and anomaly detector jobs (via {@link #asyncRefresh})
- * 2. For all open/started ML jobs, plus one named ML anomaly detector job that is not open
+ * 1. For all open ML data frame analytics jobs, anomaly detector jobs and model snapshot upgrades (via {@link #asyncRefresh})
+ * 2. For all open/started ML jobs and model snapshot upgrades, plus one named ML anomaly detector job that may not be open
  *    (via {@link #refreshAnomalyDetectorJobMemoryAndAllOthers})
- * 3. For all open/started ML jobs, plus one named ML data frame analytics job that is not started
+ * 3. For all open/started ML jobs and model snapshot upgrades, plus one named ML data frame analytics job that is not started
  *    (via {@link #addDataFrameAnalyticsJobMemoryAndRefreshAllOthers})
- * 4. For one named ML anomaly detector job (via {@link #refreshAnomalyDetectorJobMemory})
+ * 4. For one named ML anomaly detector job or model snapshot upgrade (via {@link #refreshAnomalyDetectorJobMemory})
  * In cases 2, 3 and 4 a listener informs the caller when the requested updates are complete.
  */
 public class MlMemoryTracker implements LocalNodeMasterListener {
@@ -78,7 +80,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
     private final JobResultsProvider jobResultsProvider;
     private final DataFrameAnalyticsConfigProvider configProvider;
     private final Phaser stopPhaser;
-    private volatile AtomicInteger phase = new AtomicInteger(0);
+    private final AtomicInteger phase = new AtomicInteger(0);
     private volatile boolean isMaster;
     private volatile boolean stopped;
     private volatile Instant lastUpdateTime;
@@ -102,6 +104,8 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
         Map<String, Map<String, Long>> memoryRequirementByTaskName = new TreeMap<>();
         memoryRequirementByTaskName.put(MlTasks.JOB_TASK_NAME, memoryRequirementByAnomalyDetectorJob);
         memoryRequirementByTaskName.put(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryRequirementByDataFrameAnalyticsJob);
+        // We don't add snapshot upgrade tasks here - instead, we assume they
+        // have the same memory requirement as the job they correspond to.
         this.memoryRequirementByTaskName = Collections.unmodifiableMap(memoryRequirementByTaskName);
 
         setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
@@ -260,6 +264,11 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
             return null;
         }
 
+        // Assume snapshot upgrade tasks have the same memory requirement as the job they correspond to.
+        if (MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME.equals(taskName)) {
+            taskName = MlTasks.JOB_TASK_NAME;
+        }
+
         Map<String, Long> memoryRequirementByJob = memoryRequirementByTaskName.get(taskName);
         if (memoryRequirementByJob == null) {
             assert false : "Unknown taskName type [" + taskName + "]";
@@ -327,8 +336,13 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
             return;
         }
 
-        PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
-        refresh(persistentTasks, ActionListener.wrap(aVoid -> refreshAnomalyDetectorJobMemory(jobId, listener), listener::onFailure));
+        // Skip the provided job ID in the main refresh, as we unconditionally do it at the end.
+        // Otherwise it might get refreshed twice, because it could have both a job task and a snapshot upgrade task.
+        refresh(
+            clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE),
+            Collections.singleton(jobId),
+            ActionListener.wrap(aVoid -> refreshAnomalyDetectorJobMemory(jobId, listener), listener::onFailure)
+        );
     }
 
     /**
@@ -359,6 +373,10 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
      * jobs are deleted.)
      */
     public void refresh(PersistentTasksCustomMetadata persistentTasks, ActionListener<Void> onCompletion) {
+        refresh(persistentTasks, Collections.emptySet(), onCompletion);
+    }
+
+    void refresh(PersistentTasksCustomMetadata persistentTasks, Set<String> jobIdsToSkip, ActionListener<Void> onCompletion) {
 
         synchronized (fullRefreshCompletionListeners) {
             fullRefreshCompletionListeners.add(onCompletion);
@@ -411,29 +429,31 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
                 refreshComplete::onFailure
             );
 
-            List<PersistentTasksCustomMetadata.PersistentTask<?>> mlAnomalyDetectorJobTasks = persistentTasks.tasks()
-                .stream()
-                .filter(task -> MlTasks.JOB_TASK_NAME.equals(task.getTaskName()))
-                .collect(Collectors.toList());
-            iterateAnomalyDetectorJobTasks(mlAnomalyDetectorJobTasks.iterator(), refreshDataFrameAnalyticsJobs);
+            Set<String> mlAnomalyDetectorJobTasks = Stream.concat(
+                persistentTasks.tasks()
+                    .stream()
+                    .filter(task -> MlTasks.JOB_TASK_NAME.equals(task.getTaskName()))
+                    .map(task -> ((OpenJobAction.JobParams) task.getParams()).getJobId()),
+                persistentTasks.tasks()
+                    .stream()
+                    .filter(task -> MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME.equals(task.getTaskName()))
+                    .map(task -> ((SnapshotUpgradeTaskParams) task.getParams()).getJobId())
+            ).filter(jobId -> jobIdsToSkip.contains(jobId) == false).collect(Collectors.toSet());
+            iterateAnomalyDetectorJobs(mlAnomalyDetectorJobTasks.iterator(), refreshDataFrameAnalyticsJobs);
         }
     }
 
-    private void iterateAnomalyDetectorJobTasks(
-        Iterator<PersistentTasksCustomMetadata.PersistentTask<?>> iterator,
-        ActionListener<Void> refreshComplete
-    ) {
+    private void iterateAnomalyDetectorJobs(Iterator<String> iterator, ActionListener<Void> refreshComplete) {
         if (iterator.hasNext()) {
-            OpenJobAction.JobParams jobParams = (OpenJobAction.JobParams) iterator.next().getParams();
             refreshAnomalyDetectorJobMemory(
-                jobParams.getJobId(),
+                iterator.next(),
                 ActionListener.wrap(
                     // Do the next iteration in a different thread, otherwise stack overflow
                     // can occur if the searches happen to be on the local node, as the huge
                     // chain of listeners are all called in the same thread if only one node
                     // is involved
                     mem -> threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)
-                        .execute(() -> iterateAnomalyDetectorJobTasks(iterator, refreshComplete)),
+                        .execute(() -> iterateAnomalyDetectorJobs(iterator, refreshComplete)),
                     refreshComplete::onFailure
                 )
             );
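
The extra jobIdsToSkip parameter exists because the same anomaly detection job can now appear twice in the persistent tasks, once as an open job task and once as a model snapshot upgrade task, and refreshAnomalyDetectorJobMemoryAndAllOthers already refreshes the named job unconditionally at the end. A hedged sketch of the de-duplication (hypothetical helper, not the real method):

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class RefreshDedupSketch {
    static Set<String> jobIdsToRefresh(List<String> openJobIds, List<String> snapshotUpgradeJobIds, Set<String> jobIdsToSkip) {
        return Stream.concat(openJobIds.stream(), snapshotUpgradeJobIds.stream())
            .filter(jobId -> jobIdsToSkip.contains(jobId) == false)
            .collect(Collectors.toSet()); // the Set also de-duplicates a job that has both task types
    }

    public static void main(String[] args) {
        // job-2 is skipped (it will be refreshed separately); job-1 is refreshed once despite two tasks.
        System.out.println(jobIdsToRefresh(List.of("job-1", "job-2"), List.of("job-1"), Set.of("job-2")));
    }
}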

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/task/AbstractJobPersistentTasksExecutor.java

@@ -141,8 +141,8 @@ public abstract class AbstractJobPersistentTasksExecutor<Params extends Persiste
                             JOB_AUDIT_REQUIRES_MORE_MEMORY_TO_RUN,
                             ByteSizeValue.ofBytes(memoryTracker.getJobMemoryRequirement(getTaskName(), jobId)),
                             ByteSizeValue.ofBytes(capacityAndFreeMemory.v2()),
-                            ByteSizeValue.ofBytes(capacityAndFreeMemory.v1().getTier()),
-                            ByteSizeValue.ofBytes(capacityAndFreeMemory.v1().getNode())
+                            ByteSizeValue.ofBytes(capacityAndFreeMemory.v1().getTierMlNativeMemoryRequirement()),
+                            ByteSizeValue.ofBytes(capacityAndFreeMemory.v1().getNodeMlNativeMemoryRequirement())
                         )
                     );
                     auditedJobCapacity.put(getUniqueId(jobId), capacityAndFreeMemory.v2());

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java

@@ -26,11 +26,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedRunner;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager;
 import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
-import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.process.MlController;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.junit.Before;

+ 156 - 13
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderServiceTests.java

@@ -103,7 +103,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
     private static final long DEFAULT_NODE_SIZE = ByteSizeValue.ofGb(20).getBytes();
     private static final long DEFAULT_JVM_SIZE = ByteSizeValue.ofMb((long) (DEFAULT_NODE_SIZE * 0.25)).getBytes();
     private static final long DEFAULT_JOB_SIZE = ByteSizeValue.ofMb(200).getBytes();
-    private static final long OVERHEAD = ByteSizeValue.ofMb(30).getBytes();
+    private static final long OVERHEAD = MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
     private NodeLoadDetector nodeLoadDetector;
     private ClusterService clusterService;
     private Settings settings;
@@ -170,6 +170,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             jobTasks,
             List.of(),
             Collections.emptyList(),
+            Collections.emptyList(),
             null,
             new NativeMemoryCapacity(432013312, 432013312, 432013312L),
             reasonBuilder
@@ -258,6 +259,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                     waitingJobs,
                     List.of(),
                     Collections.emptyList(),
+                    Collections.emptyList(),
                     null,
                     new NativeMemoryCapacity(memoryForMl, memoryForMl, lowerTier.v2()),
                     new MlScalingReason.Builder().setPassedConfiguration(Settings.EMPTY)
@@ -323,6 +325,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 Collections.emptyList(),
                 Collections.emptyList(),
                 Collections.emptyList(),
+                Collections.emptyList(),
                 null,
                 NativeMemoryCapacity.ZERO,
                 MlScalingReason.builder()
@@ -332,8 +335,10 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
     }
 
     public void testScaleUp_withWaitingJobsAndAutoMemoryAndNoRoomInNodes() {
-        when(mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(any())).thenReturn(ByteSizeValue.ofGb(2).getBytes());
-        when(mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(any())).thenReturn(ByteSizeValue.ofGb(2).getBytes());
+        ByteSizeValue anomalyDetectorJobSize = ByteSizeValue.ofGb(randomIntBetween(2, 4));
+        ByteSizeValue analyticsJobSize = ByteSizeValue.ofGb(randomIntBetween(2, 4));
+        when(mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(any())).thenReturn(anomalyDetectorJobSize.getBytes());
+        when(mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(any())).thenReturn(analyticsJobSize.getBytes());
         List<String> jobTasks = Arrays.asList("waiting_job", "waiting_job_2");
         List<String> analytics = Arrays.asList("analytics_waiting");
         List<NodeLoad> fullyLoadedNode = Arrays.asList(
@@ -353,6 +358,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 0,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -366,13 +372,21 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 30,
                 true
             );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
             long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
                 result.requiredCapacity().total().memory().getBytes(),
                 30,
                 true
             );
-            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() + OVERHEAD));
-            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() * 3 + OVERHEAD));
+            assertThat(
+                allowedBytesForMlNode,
+                greaterThanOrEqualTo(Math.max(anomalyDetectorJobSize.getBytes(), analyticsJobSize.getBytes()) + OVERHEAD)
+            );
+            assertThat(
+                allowedBytesForMlTier,
+                greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() * 2 + analyticsJobSize.getBytes() + OVERHEAD)
+            );
         }
         { // we allow one job in the analytics queue
             Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
@@ -380,6 +394,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -393,13 +408,15 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 30,
                 true
             );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
             long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
                 result.requiredCapacity().total().memory().getBytes(),
                 30,
                 true
             );
-            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() + OVERHEAD));
-            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() * 2 + OVERHEAD));
+            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() * 2 + OVERHEAD));
         }
         { // we allow one job in the anomaly queue and analytics queue
             Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
@@ -407,6 +424,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -420,13 +438,124 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 30,
                 true
             );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
+            long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().total().memory().getBytes(),
+                30,
+                true
+            );
+            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+        }
+    }
+
+    public void testScaleUp_withWaitingSnapshotUpgradesAndAutoMemoryAndNoRoomInNodes() {
+        ByteSizeValue anomalyDetectorJobSize = ByteSizeValue.ofGb(randomIntBetween(2, 8));
+        ByteSizeValue analyticsJobSize = ByteSizeValue.ofGb(randomIntBetween(2, 8));
+        when(mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(any())).thenReturn(anomalyDetectorJobSize.getBytes());
+        when(mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(any())).thenReturn(analyticsJobSize.getBytes());
+        List<String> snapshotUpgradeTasks = Arrays.asList("waiting_upgrade", "waiting_upgrade_2");
+        List<NodeLoad> fullyLoadedNode = Arrays.asList(
+            NodeLoad.builder("any")
+                .setMaxMemory(ByteSizeValue.ofGb(1).getBytes())
+                .setUseMemory(true)
+                .incAssignedJobMemory(ByteSizeValue.ofGb(1).getBytes())
+                .build()
+        );
+        MlScalingReason.Builder reasonBuilder = new MlScalingReason.Builder().setPassedConfiguration(Settings.EMPTY)
+            .setCurrentMlCapacity(AutoscalingCapacity.ZERO);
+        MlAutoscalingDeciderService service = buildService();
+        service.setUseAuto(true);
+        { // No time in queue
+            Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
+                0,
+                0,
+                fullyLoadedNode,
+                Collections.emptyList(),
+                snapshotUpgradeTasks,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                null,
+                NativeMemoryCapacity.ZERO,
+                reasonBuilder
+            );
+            assertFalse(decision.isEmpty());
+            AutoscalingDeciderResult result = decision.get();
+            long allowedBytesForMlNode = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().node().memory().getBytes(),
+                30,
+                true
+            );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
+            long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().total().memory().getBytes(),
+                30,
+                true
+            );
+            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() * 2 + OVERHEAD));
+        }
+        { // we allow one job in the analytics queue
+            Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
+                0,
+                1,
+                fullyLoadedNode,
+                Collections.emptyList(),
+                snapshotUpgradeTasks,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                null,
+                NativeMemoryCapacity.ZERO,
+                reasonBuilder
+            );
+            assertFalse(decision.isEmpty());
+            AutoscalingDeciderResult result = decision.get();
+            long allowedBytesForMlNode = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().node().memory().getBytes(),
+                30,
+                true
+            );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
             long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
                 result.requiredCapacity().total().memory().getBytes(),
                 30,
                 true
             );
-            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() + OVERHEAD));
-            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(ByteSizeValue.ofGb(2).getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() * 2 + OVERHEAD));
+        }
+        { // we allow one job in the anomaly queue and analytics queue
+            Optional<AutoscalingDeciderResult> decision = service.checkForScaleUp(
+                1,
+                1,
+                fullyLoadedNode,
+                Collections.emptyList(),
+                snapshotUpgradeTasks,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                null,
+                NativeMemoryCapacity.ZERO,
+                reasonBuilder
+            );
+            assertFalse(decision.isEmpty());
+            AutoscalingDeciderResult result = decision.get();
+            long allowedBytesForMlNode = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().node().memory().getBytes(),
+                30,
+                true
+            );
+            // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+            // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
+            long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
+                result.requiredCapacity().total().memory().getBytes(),
+                30,
+                true
+            );
+            assertThat(allowedBytesForMlNode, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
+            assertThat(allowedBytesForMlTier, greaterThanOrEqualTo(anomalyDetectorJobSize.getBytes() + OVERHEAD));
         }
     }
 
@@ -453,6 +582,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 0,
                 nodesWithRoom,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -469,6 +599,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 nodesWithRoom,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -483,6 +614,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 0,
                 nodesWithRoom,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -513,6 +645,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 0,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -529,6 +662,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -545,6 +679,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -577,6 +712,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 null,
@@ -593,6 +729,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 new NativeMemoryCapacity(ByteSizeValue.ofGb(3).getBytes(), ByteSizeValue.ofGb(1).getBytes()),
@@ -607,6 +744,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 1,
                 fullyLoadedNode,
                 jobTasks,
+                Collections.emptyList(),
                 analytics,
                 Collections.emptyList(),
                 new NativeMemoryCapacity(ByteSizeValue.ofMb(1).getBytes(), ByteSizeValue.ofMb(1).getBytes()),
@@ -638,6 +776,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             fullyLoadedNode,
             Collections.emptyList(),
             Collections.emptyList(),
+            Collections.emptyList(),
             List.of("foo"),
             null,
             NativeMemoryCapacity.ZERO,
@@ -650,6 +789,8 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             30,
             true
         );
+        // Note: with more than 1 job involved this calculation can be a wild overestimate, because
+        // NativeMemoryCapacity.autoscalingCapacity() is assuming the memory percent is the same regardless of node size
         long allowedBytesForMlTier = NativeMemoryCalculator.allowedBytesForMl(
             result.requiredCapacity().total().memory().getBytes(),
             30,
@@ -680,6 +821,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             nodesWithRoom,
             Collections.emptyList(),
             Collections.emptyList(),
+            Collections.emptyList(),
             List.of("foo", "bar", "baz"),
             null,
             NativeMemoryCapacity.ZERO,
@@ -695,6 +837,7 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
                 nodesWithRoom,
                 Collections.emptyList(),
                 Collections.emptyList(),
+                Collections.emptyList(),
                 List.of("foo", "bar"),
                 null,
                 NativeMemoryCapacity.ZERO,
@@ -841,17 +984,17 @@ public class MlAutoscalingDeciderServiceTests extends ESTestCase {
             clusterState
         );
         assertThat(nativeMemoryCapacity.isEmpty(), is(false));
-        assertThat(nativeMemoryCapacity.get().getNode(), greaterThanOrEqualTo(DEFAULT_JOB_SIZE));
+        assertThat(nativeMemoryCapacity.get().getNodeMlNativeMemoryRequirement(), greaterThanOrEqualTo(DEFAULT_JOB_SIZE));
         assertThat(
-            nativeMemoryCapacity.get().getNode(),
+            nativeMemoryCapacity.get().getNodeMlNativeMemoryRequirement(),
             lessThanOrEqualTo(NativeMemoryCalculator.allowedBytesForMl(DEFAULT_NODE_SIZE, 20, true))
         );
         assertThat(
-            nativeMemoryCapacity.get().getTier(),
+            nativeMemoryCapacity.get().getTierMlNativeMemoryRequirement(),
             greaterThanOrEqualTo(DEFAULT_JOB_SIZE * (assignedAnalyticsJobs.size() + batchAnomalyJobs.size()))
         );
         assertThat(
-            nativeMemoryCapacity.get().getTier(),
+            nativeMemoryCapacity.get().getTierMlNativeMemoryRequirement(),
             lessThanOrEqualTo(3 * (NativeMemoryCalculator.allowedBytesForMl(DEFAULT_NODE_SIZE, 20, true)))
         );
     }

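One note on the comment repeated throughout the hunks above, using illustrative numbers that are assumptions rather than the project's actual formula: autoscalingCapacity() inverts a single ML memory percentage to turn a native-memory requirement into a node or tier size. If a tier needs 6 GB of ML native memory and the inversion uses an effective 40 percent, the decider asks for 6 GB / 0.40 = 15 GB of tier memory; but if nodes of that size would in practice let ML use a larger fraction of memory (say 90 percent, again an assumed number), roughly 6 GB / 0.90 ≈ 6.7 GB would already suffice, so the request overshoots. That is why the assertions above are written as greaterThanOrEqualTo lower bounds on job sizes plus overhead rather than as exact expected values.
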
+ 1 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlScalingReasonTests.java

@@ -35,6 +35,7 @@ public class MlScalingReasonTests extends AbstractWireSerializingTestCase<MlScal
     @Override
     protected MlScalingReason createTestInstance() {
         return new MlScalingReason(
+            randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),
             randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(5).collect(Collectors.toList()),

+ 6 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/NativeMemoryCapacityTests.java

@@ -31,14 +31,14 @@ public class NativeMemoryCapacityTests extends ESTestCase {
             ByteSizeValue.ofMb(50).getBytes()
         );
         capacity.merge(new NativeMemoryCapacity(ByteSizeValue.ofGb(1).getBytes(), ByteSizeValue.ofMb(100).getBytes()));
-        assertThat(capacity.getTier(), equalTo(ByteSizeValue.ofGb(1).getBytes() * 2L));
-        assertThat(capacity.getNode(), equalTo(ByteSizeValue.ofMb(200).getBytes()));
+        assertThat(capacity.getTierMlNativeMemoryRequirement(), equalTo(ByteSizeValue.ofGb(1).getBytes() * 2L));
+        assertThat(capacity.getNodeMlNativeMemoryRequirement(), equalTo(ByteSizeValue.ofMb(200).getBytes()));
         assertThat(capacity.getJvmSize(), equalTo(ByteSizeValue.ofMb(50).getBytes()));
 
         capacity.merge(new NativeMemoryCapacity(ByteSizeValue.ofGb(1).getBytes(), ByteSizeValue.ofMb(300).getBytes()));
 
-        assertThat(capacity.getTier(), equalTo(ByteSizeValue.ofGb(1).getBytes() * 3L));
-        assertThat(capacity.getNode(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
+        assertThat(capacity.getTierMlNativeMemoryRequirement(), equalTo(ByteSizeValue.ofGb(1).getBytes() * 3L));
+        assertThat(capacity.getNodeMlNativeMemoryRequirement(), equalTo(ByteSizeValue.ofMb(300).getBytes()));
         assertThat(capacity.getJvmSize(), is(nullValue()));
     }
 
@@ -71,8 +71,8 @@ public class NativeMemoryCapacityTests extends ESTestCase {
     public void testAutoscalingCapacityConsistency() {
         final BiConsumer<NativeMemoryCapacity, Integer> consistentAutoAssertions = (nativeMemory, memoryPercentage) -> {
             AutoscalingCapacity autoscalingCapacity = nativeMemory.autoscalingCapacity(25, true);
-            assertThat(autoscalingCapacity.total().memory().getBytes(), greaterThan(nativeMemory.getTier()));
-            assertThat(autoscalingCapacity.node().memory().getBytes(), greaterThan(nativeMemory.getNode()));
+            assertThat(autoscalingCapacity.total().memory().getBytes(), greaterThan(nativeMemory.getTierMlNativeMemoryRequirement()));
+            assertThat(autoscalingCapacity.node().memory().getBytes(), greaterThan(nativeMemory.getNodeMlNativeMemoryRequirement()));
             assertThat(
                 autoscalingCapacity.total().memory().getBytes(),
                 greaterThanOrEqualTo(autoscalingCapacity.node().memory().getBytes())

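The merge assertions in the first hunk of this file's diff pin down behaviour without showing the implementation: tier requirements add, the node requirement takes the maximum, and the remembered JVM size is dropped once the node requirement grows. A rough sketch consistent with those assertions — inferred from the test, not the real NativeMemoryCapacity code — is:

    // Inferred from the merge assertions above; not the actual implementation.
    final class CapacitySketch {
        long tier;      // summed across merges
        long node;      // maximum across merges
        Long jvmSize;   // forgotten once the per-node requirement changes

        CapacitySketch(long tier, long node, Long jvmSize) {
            this.tier = tier;
            this.node = node;
            this.jvmSize = jvmSize;
        }

        CapacitySketch merge(CapacitySketch other) {
            this.tier += other.tier;
            if (other.node > this.node) {
                this.node = other.node;
                this.jvmSize = null; // a JVM size measured for the smaller node no longer applies
            }
            return this;
        }
    }
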
+ 1 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradePredicateTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
 
 import static org.hamcrest.Matchers.containsString;

+ 112 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

@@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
@@ -35,6 +36,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
@@ -72,7 +74,6 @@ public class MlMemoryTrackerTests extends ESTestCase {
         ThreadPool threadPool = mock(ThreadPool.class);
         ExecutorService executorService = mock(ExecutorService.class);
         doAnswer(invocation -> {
-            @SuppressWarnings("unchecked")
             Runnable r = (Runnable) invocation.getArguments()[0];
             r.run();
             return null;
@@ -102,6 +103,15 @@ public class MlMemoryTrackerTests extends ESTestCase {
             tasks.put(task.getId(), task);
         }
 
+        // One snapshot upgrade is for a running job, one for a job that isn't running
+        int numSnapshotUpgradeTasks = 2;
+        for (int i = numAnomalyDetectorJobTasks; i < numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks; ++i) {
+            String jobId = "job" + i;
+            String snapshotId = Long.toString(randomLongBetween(1000000000L, 9999999999L));
+            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestSnapshotUpgradeTask(jobId, snapshotId);
+            tasks.put(task.getId(), task);
+        }
+
         List<String> allIds = new ArrayList<>();
         int numDataFrameAnalyticsTasks = randomIntBetween(2, 5);
         for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) {
@@ -112,7 +122,7 @@ public class MlMemoryTrackerTests extends ESTestCase {
         }
 
         PersistentTasksCustomMetadata persistentTasks = new PersistentTasksCustomMetadata(
-            numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks,
+            numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks + numDataFrameAnalyticsTasks,
             tasks
         );
 
@@ -132,8 +142,9 @@ public class MlMemoryTrackerTests extends ESTestCase {
         }
 
         if (isMaster) {
-            for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) {
+            for (int i = 1; i < numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks; ++i) {
                 String jobId = "job" + i;
+                // This should only be called once even for the job where there's both a job task and a snapshot upgrade task
                 verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any());
             }
             verify(configProvider, times(1)).getConfigsForJobsWithTasksLeniently(eq(new HashSet<>(allIds)), any());
@@ -142,6 +153,83 @@ public class MlMemoryTrackerTests extends ESTestCase {
         }
     }
 
+    public void testRefreshWithSkips() {
+
+        boolean isMaster = randomBoolean();
+        if (isMaster) {
+            memoryTracker.onMaster();
+        } else {
+            memoryTracker.offMaster();
+        }
+
+        Map<String, PersistentTasksCustomMetadata.PersistentTask<?>> tasks = new HashMap<>();
+
+        Set<String> toSkip = new HashSet<>();
+
+        int numAnomalyDetectorJobTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) {
+            String jobId = "job" + i;
+            if (randomBoolean()) {
+                toSkip.add(jobId);
+            }
+            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestAnomalyDetectorTask(jobId);
+            tasks.put(task.getId(), task);
+        }
+
+        // One snapshot upgrade is for a running job, one for a job that isn't running
+        int numSnapshotUpgradeTasks = 2;
+        for (int i = numAnomalyDetectorJobTasks; i < numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks; ++i) {
+            String jobId = "job" + i;
+            String snapshotId = Long.toString(randomLongBetween(1000000000L, 9999999999L));
+            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestSnapshotUpgradeTask(jobId, snapshotId);
+            tasks.put(task.getId(), task);
+        }
+
+        List<String> allIds = new ArrayList<>();
+        int numDataFrameAnalyticsTasks = randomIntBetween(2, 5);
+        for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) {
+            String id = "analytics" + i;
+            allIds.add(id);
+            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestDataFrameAnalyticsTask(id, false);
+            tasks.put(task.getId(), task);
+        }
+
+        PersistentTasksCustomMetadata persistentTasks = new PersistentTasksCustomMetadata(
+            numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks + numDataFrameAnalyticsTasks,
+            tasks
+        );
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            Consumer<Long> listener = (Consumer<Long>) invocation.getArguments()[3];
+            listener.accept(randomLongBetween(1000, 1000000));
+            return null;
+        }).when(jobResultsProvider).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any());
+
+        if (isMaster) {
+            memoryTracker.refresh(persistentTasks, toSkip, ActionListener.wrap(aVoid -> {}, ESTestCase::assertNull));
+        } else {
+            AtomicReference<Exception> exception = new AtomicReference<>();
+            memoryTracker.refresh(persistentTasks, toSkip, ActionListener.wrap(e -> fail("Expected failure response"), exception::set));
+            assertEquals("Request to refresh anomaly detector memory requirement on non-master node", exception.get().getMessage());
+        }
+
+        if (isMaster) {
+            for (int i = 1; i < numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks; ++i) {
+                String jobId = "job" + i;
+                if (toSkip.contains(jobId)) {
+                    verify(jobResultsProvider, never()).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any());
+                } else {
+                    // This should only be called once even for the job where there's both a job task and a snapshot upgrade task
+                    verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any());
+                }
+            }
+            verify(configProvider, times(1)).getConfigsForJobsWithTasksLeniently(eq(new HashSet<>(allIds)), any());
+        } else {
+            verify(jobResultsProvider, never()).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any());
+        }
+    }
+
     public void testRefreshAllFailure() {
 
         Map<String, PersistentTasksCustomMetadata.PersistentTask<?>> tasks = new HashMap<>();
@@ -153,6 +241,13 @@ public class MlMemoryTrackerTests extends ESTestCase {
             tasks.put(task.getId(), task);
         }
 
+        int numSnapshotUpgradeTasks = randomIntBetween(1, 3);
+        for (int i = numAnomalyDetectorJobTasks; i < numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks; ++i) {
+            String jobId = "job" + i;
+            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestAnomalyDetectorTask(jobId);
+            tasks.put(task.getId(), task);
+        }
+
         int numDataFrameAnalyticsTasks = randomIntBetween(2, 5);
         for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) {
             String id = "analytics" + i;
@@ -161,7 +256,7 @@ public class MlMemoryTrackerTests extends ESTestCase {
         }
 
         PersistentTasksCustomMetadata persistentTasks = new PersistentTasksCustomMetadata(
-            numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks,
+            numAnomalyDetectorJobTasks + numSnapshotUpgradeTasks + numDataFrameAnalyticsTasks,
             tasks
         );
 
@@ -294,6 +389,19 @@ public class MlMemoryTrackerTests extends ESTestCase {
         );
     }
 
+    private PersistentTasksCustomMetadata.PersistentTask<SnapshotUpgradeTaskParams> makeTestSnapshotUpgradeTask(
+        String jobId,
+        String snapshotId
+    ) {
+        return new PersistentTasksCustomMetadata.PersistentTask<>(
+            MlTasks.snapshotUpgradeTaskId(jobId, snapshotId),
+            MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
+            new SnapshotUpgradeTaskParams(jobId, snapshotId),
+            0,
+            PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT
+        );
+    }
+
     private PersistentTasksCustomMetadata.PersistentTask<StartDataFrameAnalyticsAction.TaskParams> makeTestDataFrameAnalyticsTask(
         String id,
         boolean allowLazyStart