Browse Source

[ML] Simplify node load memory calculation for various tasks (#75186)

Refactors calculation of node memory load so that there is a
framework for supporting various different ML tasks. This results
in simpler code and it is a step towards making it easier to
add memory tracking for future tasks.
Dimitris Athanasiou 4 years ago
parent
commit
3a76a398cb
16 changed files with 177 additions and 101 deletions
  1. 18 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java
  2. 8 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/OpenJobAction.java
  3. 9 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java
  4. 7 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDatafeedAction.java
  5. 8 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java
  6. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java
  7. 13 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobState.java
  8. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobTaskState.java
  9. 14 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeState.java
  10. 3 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeTaskState.java
  11. 17 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MemoryTrackedTaskState.java
  12. 16 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlTaskParams.java
  13. 15 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java
  14. 13 56
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java
  15. 25 22
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java
  16. 7 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskParams.java

+ 18 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -17,6 +17,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
+import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -331,4 +334,19 @@ public final class MlTasks {
 
         return tasks.findTasks(DATAFEED_TASK_NAME, task -> PersistentTasksClusterService.needsReassignment(task.getAssignment(), nodes));
     }
+
+    public static MemoryTrackedTaskState getMemoryTrackedTaskState(PersistentTasksCustomMetadata.PersistentTask<?> task) {
+        String taskName = task.getTaskName();
+        switch (taskName) {
+            case JOB_TASK_NAME:
+                return getJobStateModifiedForReassignments(task);
+            case JOB_SNAPSHOT_UPGRADE_TASK_NAME:
+                SnapshotUpgradeTaskState taskState = (SnapshotUpgradeTaskState) task.getState();
+                return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
+            case DATA_FRAME_ANALYTICS_TASK_NAME:
+                return getDataFrameAnalyticsState(task);
+            default:
+                throw new IllegalStateException("unexpected task type [" + task.getTaskName() + "]");
+        }
+    }
 }

+ 8 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/OpenJobAction.java

@@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -57,7 +58,7 @@ public class OpenJobAction extends ActionType<NodeAcknowledgedResponse> {
         private JobParams jobParams;
 
         public Request(JobParams jobParams) {
-            this.jobParams = jobParams;
+            this.jobParams = Objects.requireNonNull(jobParams);
         }
 
         public Request(String jobId) {
@@ -113,7 +114,7 @@ public class OpenJobAction extends ActionType<NodeAcknowledgedResponse> {
         }
     }
 
-    public static class JobParams implements PersistentTaskParams {
+    public static class JobParams implements PersistentTaskParams, MlTaskParams {
 
         public static final ParseField TIMEOUT = new ParseField("timeout");
         public static final ParseField JOB = new ParseField("job");
@@ -235,6 +236,11 @@ public class OpenJobAction extends ActionType<NodeAcknowledgedResponse> {
         public Version getMinimalSupportedVersion() {
             return Version.CURRENT.minimumCompatibilityVersion();
         }
+
+        @Override
+        public String getMlId() {
+            return jobId;
+        }
     }
 
     public interface JobTaskMatcher {

+ 9 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java

@@ -10,22 +10,23 @@ import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.MasterNodeRequest;
-import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.persistent.PersistentTaskParams;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -139,7 +140,7 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
         }
     }
 
-    public static class TaskParams implements PersistentTaskParams {
+    public static class TaskParams implements PersistentTaskParams, MlTaskParams {
 
         public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
         public static final Version VERSION_DESTINATION_INDEX_MAPPINGS_CHANGED = Version.V_7_10_0;
@@ -232,6 +233,11 @@ public class StartDataFrameAnalyticsAction extends ActionType<NodeAcknowledgedRe
                 && Objects.equals(version, other.version)
                 && Objects.equals(allowLazyStart, other.allowLazyStart);
         }
+
+        @Override
+        public String getMlId() {
+            return id;
+        }
     }
 
     public interface TaskMatcher {

+ 7 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDatafeedAction.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -129,7 +130,7 @@ public class StartDatafeedAction extends ActionType<NodeAcknowledgedResponse> {
         }
     }
 
-    public static class DatafeedParams implements PersistentTaskParams {
+    public static class DatafeedParams implements PersistentTaskParams, MlTaskParams {
 
         public static final ParseField INDICES = new ParseField("indices");
 
@@ -326,6 +327,11 @@ public class StartDatafeedAction extends ActionType<NodeAcknowledgedResponse> {
                     Objects.equals(indicesOptions, other.indicesOptions) &&
                     Objects.equals(datafeedIndices, other.datafeedIndices);
         }
+
+        @Override
+        public String getMlId() {
+            return datafeedId;
+        }
     }
 
     public interface DatafeedTaskMatcher {

+ 8 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java

@@ -9,12 +9,13 @@ package org.elasticsearch.xpack.core.ml.dataframe;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
 
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Locale;
 
-public enum DataFrameAnalyticsState implements Writeable {
+public enum DataFrameAnalyticsState implements Writeable, MemoryTrackedTaskState {
 
     // States reindexing and analyzing are no longer used.
     // However, we need to keep them for BWC as tasks may be
@@ -47,9 +48,14 @@ public enum DataFrameAnalyticsState implements Writeable {
     }
 
     /**
-     * @return {@code false} if state matches any of the given {@code candidates}
+     * @return {@code true} if state matches none of the given {@code candidates}
      */
     public boolean isNoneOf(DataFrameAnalyticsState... candidates) {
         return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
     }
+
+    @Override
+    public boolean consumesMemory() {
+        return isNoneOf(FAILED, STOPPED);
+    }
 }

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java

@@ -6,13 +6,13 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe;
 
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.persistent.PersistentTaskState;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;

+ 13 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobState.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.core.ml.job.config;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -19,7 +20,7 @@ import java.util.Locale;
  * When a job is created it is initialised in to the state closed
  * i.e. it is not running.
  */
-public enum JobState implements Writeable {
+public enum JobState implements Writeable, MemoryTrackedTaskState {
 
     CLOSING, CLOSED, OPENED, FAILED, OPENING;
 
@@ -50,7 +51,7 @@ public enum JobState implements Writeable {
     }
 
     /**
-     * @return {@code false} if state matches any of the given {@code candidates}
+     * @return {@code true} if state matches none of the given {@code candidates}
      */
     public boolean isNoneOf(JobState... candidates) {
         return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
@@ -60,4 +61,14 @@ public enum JobState implements Writeable {
     public String toString() {
         return name().toLowerCase(Locale.ROOT);
     }
+
+    @Override
+    public boolean consumesMemory() {
+        return isNoneOf(CLOSED, FAILED);
+    }
+
+    @Override
+    public boolean isAllocating() {
+        return this == OPENING;
+    }
 }

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobTaskState.java

@@ -6,13 +6,13 @@
  */
 package org.elasticsearch.xpack.core.ml.job.config;
 
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.persistent.PersistentTaskState;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
 import org.elasticsearch.xpack.core.ml.MlTasks;

+ 14 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeState.java

@@ -9,11 +9,13 @@ package org.elasticsearch.xpack.core.ml.job.snapshot.upgrade;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Locale;
 
-public enum SnapshotUpgradeState implements Writeable {
+public enum SnapshotUpgradeState implements Writeable, MemoryTrackedTaskState {
 
     LOADING_OLD_STATE, SAVING_NEW_STATE, STOPPED, FAILED;
 
@@ -35,4 +37,15 @@ public enum SnapshotUpgradeState implements Writeable {
         return name().toLowerCase(Locale.ROOT);
     }
 
+    /**
+     * @return {@code true} if state matches none of the given {@code candidates}
+     */
+    public boolean isNoneOf(SnapshotUpgradeState... candidates) {
+        return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
+    }
+
+    @Override
+    public boolean consumesMemory() {
+        return isNoneOf(FAILED, STOPPED);
+    }
 }

+ 3 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/snapshot/upgrade/SnapshotUpgradeTaskState.java

@@ -7,13 +7,13 @@
 
 package org.elasticsearch.xpack.core.ml.job.snapshot.upgrade;
 
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ParseField;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.persistent.PersistentTaskState;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 
@@ -21,7 +21,7 @@ import java.io.IOException;
 import java.io.UncheckedIOException;
 import java.util.Objects;
 
-public class SnapshotUpgradeTaskState implements PersistentTaskState {
+public class SnapshotUpgradeTaskState implements PersistentTaskState{
 
     public static final String NAME = MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME;
 

+ 17 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MemoryTrackedTaskState.java

@@ -0,0 +1,17 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.utils;
+
+public interface MemoryTrackedTaskState {
+
+    boolean consumesMemory();
+
+    default boolean isAllocating() {
+        return false;
+    }
+}

+ 16 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MlTaskParams.java

@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.utils;
+
+public interface MlTaskParams {
+
+    /**
+     * The id of the ML config this task is executing
+     */
+    String getMlId();
+}

+ 15 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingDeciderService.java

@@ -14,14 +14,14 @@ import org.elasticsearch.cluster.LocalNodeMasterListener;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.common.xcontent.XContentElasticsearchExtension;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingCapacity;
@@ -30,7 +30,9 @@ import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderService;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction.DatafeedParams;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
+import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.job.NodeLoad;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
@@ -62,7 +64,6 @@ import static org.elasticsearch.xpack.core.ml.MlTasks.getJobStateModifiedForReas
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE;
 import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD;
 import static org.elasticsearch.xpack.ml.job.JobNodeSelector.AWAITING_LAZY_ASSIGNMENT;
-import static org.elasticsearch.xpack.ml.job.NodeLoad.taskStateFilter;
 
 public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
     LocalNodeMasterListener {
@@ -928,4 +929,14 @@ public class MlAutoscalingDeciderService implements AutoscalingDeciderService,
     public List<DiscoveryNodeRole> roles() {
         return List.of(DiscoveryNodeRole.ML_ROLE);
     }
+
+    private static boolean taskStateFilter(JobState jobState) {
+        return jobState == null || jobState.isNoneOf(JobState.CLOSED, JobState.FAILED);
+    }
+
+    private static boolean taskStateFilter(DataFrameAnalyticsState dataFrameAnalyticsState) {
+        // Don't count stopped and failed df-analytics tasks as they don't consume native memory
+        return dataFrameAnalyticsState == null
+            || dataFrameAnalyticsState.isNoneOf(DataFrameAnalyticsState.STOPPED, DataFrameAnalyticsState.FAILED);
+    }
 }

+ 13 - 56
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoad.java

@@ -11,27 +11,12 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
-import org.elasticsearch.xpack.core.ml.MlTasks;
-import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
-import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
-import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 
 import java.util.Objects;
 
 public class NodeLoad {
 
-    public static boolean taskStateFilter(JobState jobState) {
-        return jobState == null || jobState.isNoneOf(JobState.CLOSED, JobState.FAILED);
-    }
-
-    public static boolean taskStateFilter(DataFrameAnalyticsState dataFrameAnalyticsState) {
-        // Don't count stopped and failed df-analytics tasks as they don't consume native memory
-        return dataFrameAnalyticsState == null
-            || dataFrameAnalyticsState.isNoneOf(DataFrameAnalyticsState.STOPPED, DataFrameAnalyticsState.FAILED);
-    }
-
     private static final Logger logger = LogManager.getLogger(NodeLoadDetector.class);
 
     private final long maxMemory;
@@ -236,48 +221,20 @@ public class NodeLoad {
             return this;
         }
 
-        void adjustForAnomalyJob(JobState jobState,
-                                 String jobId,
-                                 MlMemoryTracker mlMemoryTracker) {
-            if (taskStateFilter(jobState) && jobId != null) {
-                // Don't count CLOSED or FAILED jobs, as they don't consume native memory
-                ++numAssignedJobs;
-                if (jobState == JobState.OPENING) {
-                    ++numAllocatingJobs;
-                }
-                Long jobMemoryRequirement = mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId);
-                if (jobMemoryRequirement == null) {
-                    useMemory = false;
-                    logger.debug(() -> new ParameterizedMessage(
-                        "[{}] memory requirement was not available. Calculating load by number of assigned jobs.",
-                        jobId
-                    ));
-                } else {
-                    assignedJobMemory += jobMemoryRequirement;
-                }
+        void addTask(String taskName, String taskId, boolean isAllocating, MlMemoryTracker memoryTracker) {
+            ++numAssignedJobs;
+            if (isAllocating) {
+                ++numAllocatingJobs;
             }
-        }
-
-        void adjustForAnalyticsJob(PersistentTasksCustomMetadata.PersistentTask<?> assignedTask,
-                                   MlMemoryTracker mlMemoryTracker) {
-            DataFrameAnalyticsState dataFrameAnalyticsState = MlTasks.getDataFrameAnalyticsState(assignedTask);
-
-            if (taskStateFilter(dataFrameAnalyticsState)) {
-                // The native process is only running in the ANALYZING and STOPPING states, but in the STARTED
-                // and REINDEXING states we're committed to using the memory soon, so account for it here
-                ++numAssignedJobs;
-                StartDataFrameAnalyticsAction.TaskParams params =
-                    (StartDataFrameAnalyticsAction.TaskParams) assignedTask.getParams();
-                Long jobMemoryRequirement = mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(params.getId());
-                if (jobMemoryRequirement == null) {
-                    useMemory = false;
-                    logger.debug(() -> new ParameterizedMessage(
-                        "[{}] memory requirement was not available. Calculating load by number of assigned jobs.",
-                        params.getId()
-                    ));
-                } else {
-                    assignedJobMemory += jobMemoryRequirement;
-                }
+            Long jobMemoryRequirement = memoryTracker.getJobMemoryRequirement(taskName, taskId);
+            if (jobMemoryRequirement == null) {
+                useMemory = false;
+                logger.debug(() -> new ParameterizedMessage(
+                    "[{}] memory requirement was not available. Calculating load by number of assigned jobs.",
+                    taskId
+                ));
+            } else {
+                assignedJobMemory += jobMemoryRequirement;
             }
         }
 

+ 25 - 22
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java

@@ -11,10 +11,9 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
-import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
-import org.elasticsearch.xpack.core.ml.job.config.JobState;
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
+import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 import org.elasticsearch.xpack.ml.MachineLearning;
-import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskParams;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;
 
@@ -23,6 +22,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.OptionalLong;
+import java.util.stream.Collectors;
 
 
 public class NodeLoadDetector {
@@ -80,27 +80,16 @@ public class NodeLoadDetector {
 
     private void updateLoadGivenTasks(NodeLoad.Builder nodeLoad, PersistentTasksCustomMetadata persistentTasks) {
         if (persistentTasks != null) {
-            // find all the anomaly detector job tasks assigned to this node
-            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> assignedAnomalyDetectorTasks = persistentTasks.findTasks(
-                MlTasks.JOB_TASK_NAME, task -> nodeLoad.getNodeId().equals(task.getExecutorNode()));
-            for (PersistentTasksCustomMetadata.PersistentTask<?> assignedTask : assignedAnomalyDetectorTasks) {
-                JobState jobState = MlTasks.getJobStateModifiedForReassignments(assignedTask);
-                OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams();
-                nodeLoad.adjustForAnomalyJob(jobState, params == null ? null : params.getJobId(), mlMemoryTracker);
-            }
-            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> assignedShapshotUpgraderTasks = persistentTasks.findTasks(
-                MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME, task -> nodeLoad.getNodeId().equals(task.getExecutorNode()));
-            for (PersistentTasksCustomMetadata.PersistentTask<?> assignedTask : assignedShapshotUpgraderTasks) {
-                SnapshotUpgradeTaskParams params = (SnapshotUpgradeTaskParams) assignedTask.getParams();
-                nodeLoad.adjustForAnomalyJob(JobState.OPENED, params == null ? null : params.getJobId(), mlMemoryTracker);
+            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> memoryTrackedTasks = findAllMemoryTrackedTasks(
+                persistentTasks, nodeLoad.getNodeId());
+            for (PersistentTasksCustomMetadata.PersistentTask<?> task : memoryTrackedTasks) {
+                MemoryTrackedTaskState state = MlTasks.getMemoryTrackedTaskState(task);
+                if (state == null || state.consumesMemory()) {
+                    MlTaskParams taskParams = (MlTaskParams) task.getParams();
+                    nodeLoad.addTask(task.getTaskName(), taskParams.getMlId(), state.isAllocating(), mlMemoryTracker);
+                }
             }
 
-            // find all the data frame analytics job tasks assigned to this node
-            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> assignedAnalyticsTasks = persistentTasks.findTasks(
-                MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, task -> nodeLoad.getNodeId().equals(task.getExecutorNode()));
-            for (PersistentTasksCustomMetadata.PersistentTask<?> assignedTask : assignedAnalyticsTasks) {
-                nodeLoad.adjustForAnalyticsJob(assignedTask, mlMemoryTracker);
-            }
             // if any jobs are running then the native code will be loaded, but shared between all jobs,
             // so increase the total memory usage of the assigned jobs to account for this
             if (nodeLoad.getNumAssignedJobs() > 0) {
@@ -109,4 +98,18 @@ public class NodeLoadDetector {
         }
     }
 
+    private static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> findAllMemoryTrackedTasks(
+        PersistentTasksCustomMetadata persistentTasks, String nodeId) {
+        return persistentTasks.tasks().stream()
+            .filter(NodeLoadDetector::isMemoryTrackedTask)
+            .filter(task -> nodeId.equals(task.getExecutorNode()))
+            .collect(Collectors.toList());
+    }
+
+    private static boolean isMemoryTrackedTask(PersistentTasksCustomMetadata.PersistentTask<?> task) {
+        return MlTasks.JOB_TASK_NAME.equals(task.getTaskName())
+            || MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME.equals(task.getTaskName())
+            || MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME.equals(task.getTaskName());
+    }
+
 }

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/snapshot/upgrader/SnapshotUpgradeTaskParams.java

@@ -15,13 +15,14 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.persistent.PersistentTaskParams;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 
 import java.io.IOException;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ml.MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME;
 
-public class SnapshotUpgradeTaskParams implements PersistentTaskParams {
+public class SnapshotUpgradeTaskParams implements PersistentTaskParams, MlTaskParams {
 
     public static final ParseField SNAPSHOT_ID = new ParseField("snapshot_id");
 
@@ -96,6 +97,11 @@ public class SnapshotUpgradeTaskParams implements PersistentTaskParams {
     public int hashCode() {
         return Objects.hash(jobId, snapshotId);
     }
+
+    @Override
+    public String getMlId() {
+        return jobId;
+    }
 }