Browse Source

[ML] updating node memory load for new allocation service (#76046)

Trained model deployment memory usage is no longer determinable via persistent tasks.

The new way is to look into the trained model allocation metadata.

This PR updates this and removes some unused code.

relates: #75778
Benjamin Trent 4 years ago
parent
commit
117e74bbc8
14 changed files with 110 additions and 312 deletions
  1. 3 30
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java
  2. 3 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  3. 21 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java
  4. 0 52
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java
  5. 0 117
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java
  6. 0 53
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java
  7. 0 5
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  8. 4 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java
  9. 5 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java
  10. 20 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java
  11. 13 16
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java
  12. 11 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java
  13. 30 2
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java
  14. 0 17
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

+ 3 - 30
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -15,8 +15,6 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
@@ -30,11 +28,13 @@ import java.util.stream.Collectors;
 
 public final class MlTasks {
 
+    public static final String TRAINED_MODEL_ALLOCATION_TASK_TYPE = "trained_model_allocation";
+    public static final String TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX = "xpack/ml/allocation-";
+
     public static final String JOB_TASK_NAME = "xpack/ml/job";
     public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed";
     public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics";
     public static final String JOB_SNAPSHOT_UPGRADE_TASK_NAME = "xpack/ml/job/snapshot/upgrade";
-    public static final String TRAINED_MODEL_DEPLOYMENT_TASK_NAME = "xpack/ml/trained_model/deployment";
 
     public static final String JOB_TASK_ID_PREFIX = "job-";
     public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-";
@@ -225,31 +225,6 @@ public final class MlTasks {
         return state;
     }
 
-    public static TrainedModelDeploymentState getTrainedModelDeploymentState(PersistentTasksCustomMetadata.PersistentTask<?> task) {
-        if (task == null) {
-            return TrainedModelDeploymentState.STOPPED;
-        }
-        TrainedModelDeploymentTaskState taskState = (TrainedModelDeploymentTaskState) task.getState();
-        if (taskState == null) {
-            return TrainedModelDeploymentState.STARTING;
-        }
-
-        TrainedModelDeploymentState state = taskState.getState();
-        if (taskState.isStatusStale(task)) {
-            if (state == TrainedModelDeploymentState.STOPPING) {
-                // previous executor node failed while the job was stopping - it won't
-                // be restarted on another node, so consider it STOPPED for reassignment purposes
-                return TrainedModelDeploymentState.STOPPED;
-            }
-            if (state != TrainedModelDeploymentState.FAILED) {
-                // we are relocating at the moment
-                // TODO Revisit this in the new allocation framework as there won't necessarily be a concept of relocation.
-                return TrainedModelDeploymentState.STARTING;
-            }
-        }
-        return state;
-    }
-
     /**
      * The job Ids of anomaly detector job tasks.
      * All anomaly detector jobs are returned regardless of the status of the
@@ -435,8 +410,6 @@ public final class MlTasks {
                 return taskState == null ? SnapshotUpgradeState.LOADING_OLD_STATE : taskState.getState();
             case DATA_FRAME_ANALYTICS_TASK_NAME:
                 return getDataFrameAnalyticsState(task);
-            case TRAINED_MODEL_DEPLOYMENT_TASK_NAME:
-                return getTrainedModelDeploymentState(task);
             default:
                 throw new IllegalStateException("unexpected task type [" + task.getTaskName() + "]");
         }

+ 3 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.MasterNodeRequest;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ParseField;
@@ -23,7 +24,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.persistent.PersistentTaskParams;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
@@ -122,13 +122,13 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         }
     }
 
-    public static class TaskParams implements PersistentTaskParams, MlTaskParams {
+    public static class TaskParams implements MlTaskParams, Writeable, ToXContentObject {
 
         // TODO add support for other roles? If so, it may have to be an instance method...
         // NOTE, whatever determines allocation should not be dynamically set on the node
         // Otherwise allocation logic might fail
         public static boolean mayAllocateToNode(DiscoveryNode node) {
-            return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE);
+            return node.getRoles().contains(DiscoveryNodeRole.ML_ROLE) && node.getVersion().onOrAfter(VERSION_INTRODUCED);
         }
 
         public static final Version VERSION_INTRODUCED = Version.V_8_0_0;
@@ -187,12 +187,6 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
         }
 
-        @Override
-        public String getWriteableName() {
-            return MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME;
-        }
-
-        @Override
         public Version getMinimalSupportedVersion() {
             return VERSION_INTRODUCED;
         }

+ 21 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/RoutingState.java

@@ -7,9 +7,12 @@
 
 package org.elasticsearch.xpack.core.ml.inference.allocation;
 
+import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
+
+import java.util.Arrays;
 import java.util.Locale;
 
-public enum RoutingState {
+public enum RoutingState implements MemoryTrackedTaskState {
     STARTING,
     STARTED,
     STOPPING,
@@ -20,8 +23,25 @@ public enum RoutingState {
         return valueOf(value.toUpperCase(Locale.ROOT));
     }
 
+    /**
+     * @return {@code true} if state matches none of the given {@code candidates}
+     */
+    public boolean isNoneOf(RoutingState... candidates) {
+        return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
+    }
+
     @Override
     public String toString() {
         return name().toLowerCase(Locale.ROOT);
     }
+
+    @Override
+    public boolean consumesMemory() {
+        return isNoneOf(FAILED, STOPPED);
+    }
+
+    @Override
+    public boolean isAllocating() {
+        return this == STARTING;
+    }
 }

+ 0 - 52
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentState.java

@@ -1,52 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.ml.inference.deployment;
-
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.Locale;
-
-public enum TrainedModelDeploymentState implements Writeable, MemoryTrackedTaskState {
-
-    STARTING, STARTED, STOPPING, STOPPED, FAILED;
-
-    public static TrainedModelDeploymentState fromString(String name) {
-        return valueOf(name.trim().toUpperCase(Locale.ROOT));
-    }
-
-    public static TrainedModelDeploymentState fromStream(StreamInput in) throws IOException {
-        return in.readEnum(TrainedModelDeploymentState.class);
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        out.writeEnum(this);
-    }
-
-    @Override
-    public String toString() {
-        return name().toLowerCase(Locale.ROOT);
-    }
-
-    /**
-     * @return {@code true} if state matches none of the given {@code candidates}
-     */
-    public boolean isNoneOf(TrainedModelDeploymentState... candidates) {
-        return Arrays.stream(candidates).noneMatch(candidate -> this == candidate);
-    }
-
-    @Override
-    public boolean consumesMemory() {
-        return isNoneOf(FAILED, STOPPED);
-    }
-}

+ 0 - 117
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/deployment/TrainedModelDeploymentTaskState.java

@@ -1,117 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.core.ml.inference.deployment;
-
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.common.xcontent.ParseField;
-import org.elasticsearch.common.io.stream.StreamInput;
-import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ConstructingObjectParser;
-import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.persistent.PersistentTaskState;
-import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
-import org.elasticsearch.xpack.core.ml.MlTasks;
-import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
-
-import java.io.IOException;
-import java.util.Objects;
-
-public class TrainedModelDeploymentTaskState implements PersistentTaskState {
-
-    public static final String NAME = MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME;
-
-    private static ParseField STATE = new ParseField("state");
-    private static ParseField ALLOCATION_ID = new ParseField("allocation_id");
-    private static ParseField REASON = new ParseField("reason");
-
-    private final TrainedModelDeploymentState state;
-    private final long allocationId;
-    private final String reason;
-
-    private static final ConstructingObjectParser<TrainedModelDeploymentTaskState, Void> PARSER =
-        new ConstructingObjectParser<>(NAME, true,
-            a -> new TrainedModelDeploymentTaskState((TrainedModelDeploymentState) a[0], (long) a[1], (String) a[2]));
-
-    static {
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsState::fromString, STATE);
-        PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID);
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REASON);
-    }
-
-    public static TrainedModelDeploymentTaskState fromXContent(XContentParser parser) {
-        try {
-            return PARSER.parse(parser, null);
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    public TrainedModelDeploymentTaskState(TrainedModelDeploymentState state, long allocationId, @Nullable String reason) {
-        this.state = Objects.requireNonNull(state);
-        this.allocationId = allocationId;
-        this.reason = reason;
-    }
-
-    public TrainedModelDeploymentTaskState(StreamInput in) throws IOException {
-        this.state = TrainedModelDeploymentState.fromStream(in);
-        this.allocationId = in.readLong();
-        this.reason = in.readOptionalString();
-    }
-
-    public TrainedModelDeploymentState getState() {
-        return state;
-    }
-
-    public String getReason() {
-        return reason;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        builder.field(STATE.getPreferredName(), state.toString());
-        builder.field(ALLOCATION_ID.getPreferredName(), allocationId);
-        if (reason != null) {
-            builder.field(REASON.getPreferredName(), reason);
-        }
-        builder.endObject();
-        return builder;
-    }
-
-    @Override
-    public String getWriteableName() {
-        return NAME;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        state.writeTo(out);
-        out.writeLong(allocationId);
-        out.writeOptionalString(reason);
-    }
-
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        TrainedModelDeploymentTaskState that = (TrainedModelDeploymentTaskState) o;
-        return allocationId == that.allocationId &&
-            state == that.state &&
-            Objects.equals(reason, that.reason);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(state, allocationId, reason);
-    }
-
-    public boolean isStatusStale(PersistentTasksCustomMetadata.PersistentTask<?> task) {
-        return allocationId != task.getAllocationId();
-    }
-}

+ 0 - 53
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

@@ -16,17 +16,13 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
-import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentState;
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 
 import java.net.InetAddress;
-import java.util.Arrays;
 
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -308,41 +304,6 @@ public class MlTasksTests extends ESTestCase {
         assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
     }
 
-    public void testGetTrainedModelDeploymentState_GivenNull() {
-        assertThat(MlTasks.getTrainedModelDeploymentState(null), equalTo(TrainedModelDeploymentState.STOPPED));
-    }
-
-    public void testGetTrainedModelDeploymentState_GivenTaskStateIsNull() {
-        PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(null, false);
-        assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
-    }
-
-    public void testGetTrainedModelDeploymentState_GivenTaskStateIsNotNullAndNotStale() {
-        TrainedModelDeploymentState state = randomFrom(TrainedModelDeploymentState.values());
-        PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, false);
-        assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(state));
-    }
-
-    public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndStopping() {
-        PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.STOPPING, true);
-        assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STOPPED));
-    }
-
-    public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndFailed() {
-        PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(TrainedModelDeploymentState.FAILED, true);
-        assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.FAILED));
-    }
-
-    public void testGetTrainedModelDeploymentState_GivenTaskStateIsStaleAndNotFailedNorStopping() {
-        TrainedModelDeploymentState state = randomFrom(
-            Arrays.stream(TrainedModelDeploymentState.values())
-                .filter(s -> s != TrainedModelDeploymentState.FAILED && s != TrainedModelDeploymentState.STOPPING)
-                .toArray(TrainedModelDeploymentState[]::new)
-        );
-        PersistentTasksCustomMetadata.PersistentTask<?> task = createTrainedModelTask(state, true);
-        assertThat(MlTasks.getTrainedModelDeploymentState(task), equalTo(TrainedModelDeploymentState.STARTING));
-    }
-
     private static PersistentTasksCustomMetadata.PersistentTask<?> createDataFrameAnalyticsTask(String jobId, String nodeId,
                                                                                                 DataFrameAnalyticsState state,
                                                                                                 boolean isStale) {
@@ -358,18 +319,4 @@ public class MlTasksTests extends ESTestCase {
         return tasks.getTask(MlTasks.dataFrameAnalyticsTaskId(jobId));
     }
 
-    private static PersistentTasksCustomMetadata.PersistentTask<?> createTrainedModelTask(TrainedModelDeploymentState state,
-                                                                                          boolean isStale) {
-        String id = randomAlphaOfLength(10);
-        PersistentTasksCustomMetadata.Builder builder = PersistentTasksCustomMetadata.builder();
-        builder.addTask(MlTasks.trainedModelDeploymentTaskId(id), MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
-            new StartTrainedModelDeploymentAction.TaskParams(id, randomAlphaOfLength(10), randomNonNegativeLong()),
-            new PersistentTasksCustomMetadata.Assignment(randomAlphaOfLength(10), "test assignment"));
-        if (state != null) {
-            builder.updateTaskState(MlTasks.trainedModelDeploymentTaskId(id),
-                new TrainedModelDeploymentTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0), null));
-        }
-        PersistentTasksCustomMetadata tasks = builder.build();
-        return tasks.getTask(MlTasks.trainedModelDeploymentTaskId(id));
-    }
 }

+ 0 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -171,7 +171,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNam
 import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStatsNamedWriteablesProvider;
 import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
-import org.elasticsearch.xpack.core.ml.inference.deployment.TrainedModelDeploymentTaskState;
 import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@@ -1287,8 +1286,6 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
             StartDataFrameAnalyticsAction.TaskParams::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME,
             SnapshotUpgradeTaskParams::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
-            StartTrainedModelDeploymentAction.TaskParams::new));
 
         // Persistent task states
         namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new));
@@ -1298,8 +1295,6 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class,
             SnapshotUpgradeTaskState.NAME,
             SnapshotUpgradeTaskState::new));
-        namedWriteables.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class,
-            TrainedModelDeploymentTaskState.NAME, TrainedModelDeploymentTaskState::new));
 
         namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables());
         namedWriteables.addAll(new AnalysisStatsNamedWriteablesProvider().getNamedWriteables());

+ 4 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadata.java

@@ -151,7 +151,7 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return modelRoutingEntries.containsKey(modelId);
         }
 
-        Builder addNewAllocation(StartTrainedModelDeploymentAction.TaskParams taskParams) {
+        public Builder addNewAllocation(StartTrainedModelDeploymentAction.TaskParams taskParams) {
             if (modelRoutingEntries.containsKey(taskParams.getModelId())) {
                 return this;
             }
@@ -160,7 +160,7 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return this;
         }
 
-        Builder updateAllocation(String modelId, String nodeId, RoutingStateAndReason state) {
+        public Builder updateAllocation(String modelId, String nodeId, RoutingStateAndReason state) {
             TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
             if (allocation == null) {
                 return this;
@@ -169,7 +169,7 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return this;
         }
 
-        Builder addNode(String modelId, String nodeId) {
+        public Builder addNode(String modelId, String nodeId) {
             TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
             if (allocation == null) {
                 throw new ResourceNotFoundException(
@@ -182,7 +182,7 @@ public class TrainedModelAllocationMetadata implements Metadata.Custom {
             return this;
         }
 
-        Builder addFailedNode(String modelId, String nodeId, String reason) {
+        public Builder addFailedNode(String modelId, String nodeId, String reason) {
             TrainedModelAllocation.Builder allocation = modelRoutingEntries.get(modelId);
             if (allocation == null) {
                 throw new ResourceNotFoundException(

+ 5 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -44,9 +44,11 @@ import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedDeque;
 
+import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX;
+import static org.elasticsearch.xpack.core.ml.MlTasks.TRAINED_MODEL_ALLOCATION_TASK_TYPE;
+
 public class TrainedModelAllocationNodeService implements ClusterStateListener {
 
-    private static final String TASK_NAME = "trained_model_allocation";
     private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1);
     private static final Logger logger = LogManager.getLogger(TrainedModelAllocationNodeService.class);
     private final TrainedModelAllocationService trainedModelAllocationService;
@@ -286,8 +288,8 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
 
     void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
         TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register(
-            TASK_NAME,
-            taskParams.getModelId(),
+            TRAINED_MODEL_ALLOCATION_TASK_TYPE,
+            TRAINED_MODEL_ALLOCATION_TASK_NAME_PREFIX + taskParams.getModelId(),
             taskAwareRequest(taskParams)
         );
         // threadsafe check to verify we are not loading/loaded the model

+ 20 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java

@@ -11,9 +11,13 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
 import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;
 
@@ -21,6 +25,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.stream.Collectors;
 
@@ -75,6 +80,7 @@ public class NodeLoadDetector {
             return nodeLoad.setError(Strings.collectionToCommaDelimitedString(errors)).build();
         }
         updateLoadGivenTasks(nodeLoad, persistentTasks);
+        updateLoadGivenModelAllocations(nodeLoad, TrainedModelAllocationMetadata.fromState(clusterState));
         return nodeLoad.build();
     }
 
@@ -98,6 +104,19 @@ public class NodeLoadDetector {
         }
     }
 
+    private void updateLoadGivenModelAllocations(NodeLoad.Builder nodeLoad, TrainedModelAllocationMetadata trainedModelAllocationMetadata) {
+        if (trainedModelAllocationMetadata != null && trainedModelAllocationMetadata.modelAllocations().isEmpty() == false) {
+            for (TrainedModelAllocation allocation : trainedModelAllocationMetadata.modelAllocations().values()) {
+                if (Optional.ofNullable(allocation.getNodeRoutingTable().get(nodeLoad.getNodeId()))
+                    .map(RoutingStateAndReason::getState)
+                    .orElse(RoutingState.STOPPED)
+                    .consumesMemory()) {
+                    nodeLoad.incAssignedJobMemory(allocation.getTaskParams().estimateMemoryUsageBytes());
+                }
+            }
+        }
+    }
+
     private static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> findAllMemoryTrackedTasks(
         PersistentTasksCustomMetadata persistentTasks, String nodeId) {
         return persistentTasks.tasks().stream()
@@ -109,8 +128,7 @@ public class NodeLoadDetector {
     private static boolean isMemoryTrackedTask(PersistentTasksCustomMetadata.PersistentTask<?> task) {
         return MlTasks.JOB_TASK_NAME.equals(task.getTaskName())
             || MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME.equals(task.getTaskName())
-            || MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME.equals(task.getTaskName())
-            || MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME.equals(task.getTaskName());
+            || MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME.equals(task.getTaskName());
     }
 
 }

+ 13 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java

@@ -26,10 +26,12 @@ import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
 
@@ -40,6 +42,7 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.TreeMap;
 import java.util.concurrent.ConcurrentHashMap;
@@ -222,18 +225,15 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
      * @return The memory requirement of the trained model task specified by {@code modelId},
      *         or <code>null</code> if it cannot be found.
      */
-    public Long getTrainedModelTaskMemoryRequirement(String modelId) {
+    public Long getTrainedModelAllocationMemoryRequirement(String modelId) {
         if (isMaster == false) {
             return null;
         }
 
-        PersistentTasksCustomMetadata tasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
-        PersistentTasksCustomMetadata.PersistentTask<?> task = MlTasks.getTrainedModelDeploymentTask(modelId, tasks);
-        if (task == null) {
-            return null;
-        }
-        StartTrainedModelDeploymentAction.TaskParams taskParams = (StartTrainedModelDeploymentAction.TaskParams) task.getParams();
-        return taskParams.estimateMemoryUsageBytes();
+        return Optional.ofNullable(TrainedModelAllocationMetadata.fromState(clusterService.state()).modelAllocations().get(modelId))
+            .map(TrainedModelAllocation::getTaskParams)
+            .map(StartTrainedModelDeploymentAction.TaskParams::estimateMemoryUsageBytes)
+            .orElse(null);
     }
 
     /**
@@ -250,15 +250,12 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
             return null;
         }
 
-        if (MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME.equals(taskName)) {
-            return getTrainedModelTaskMemoryRequirement(id);
-        } else {
-            Map<String, Long> memoryRequirementByJob = memoryRequirementByTaskName.get(taskName);
-            if (memoryRequirementByJob == null) {
-                return null;
-            }
-            return memoryRequirementByJob.get(id);
+        Map<String, Long> memoryRequirementByJob = memoryRequirementByTaskName.get(taskName);
+        if (memoryRequirementByJob == null) {
+            assert false: "Unknown taskName type [" + taskName +"]";
+            return null;
         }
+        return memoryRequirementByJob.get(id);
     }
 
     /**

+ 11 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

@@ -185,6 +185,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                     .add(buildNode("ml-node-without-room", true, 1000L))
                     .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes()))
                     .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .add(buildOldNode("old-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes()))
                     .build()
             )
             .metadata(Metadata.builder().putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata("ml-node-shutting-down")))
@@ -220,6 +221,7 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                     .add(buildNode("ml-node-without-room", true, 1000L))
                     .add(buildNode("not-ml-node", false, ByteSizeValue.ofGb(4).getBytes()))
                     .add(buildNode("ml-node-shutting-down", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .add(buildOldNode("old-versioned-ml-node-with-room", true, ByteSizeValue.ofGb(4).getBytes()))
                     .build()
             )
             .metadata(
@@ -632,6 +634,10 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
     }
 
     private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory) {
+        return buildNode(name, isML, nativeMemory, Version.CURRENT);
+    }
+
+    private static DiscoveryNode buildNode(String name, boolean isML, long nativeMemory, Version version) {
         return new DiscoveryNode(
             name,
             name,
@@ -642,10 +648,14 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
                 .put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, String.valueOf(10))
                 .map(),
             isML ? DiscoveryNodeRole.roles() : Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE),
-            Version.CURRENT
+            version
         );
     }
 
+    private static DiscoveryNode buildOldNode(String name, boolean isML, long nativeMemory) {
+        return buildNode(name, isML, nativeMemory, Version.V_7_15_0);
+    }
+
     private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) {
         return new StartTrainedModelDeploymentAction.TaskParams(modelId, "test-index", modelSize);
     }

+ 30 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java

@@ -16,8 +16,12 @@ import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
+import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.junit.Before;
@@ -39,6 +43,8 @@ public class NodeLoadDetectorTests extends ESTestCase {
     // To simplify the logic in this class all jobs have the same memory requirement
     private static final ByteSizeValue JOB_MEMORY_REQUIREMENT = ByteSizeValue.ofMb(10);
 
+    private static final long MODEL_MEMORY_REQUIREMENT = ByteSizeValue.ofMb(50).getBytes();
+
     private NodeLoadDetector nodeLoadDetector;
 
     @Before
@@ -75,7 +81,29 @@ public class NodeLoadDetectorTests extends ESTestCase {
         PersistentTasksCustomMetadata tasks = tasksBuilder.build();
 
         final ClusterState cs = ClusterState.builder(new ClusterName("_name")).nodes(nodes)
-                .metadata(Metadata.builder().putCustom(PersistentTasksCustomMetadata.TYPE, tasks)).build();
+                .metadata(
+                    Metadata.builder()
+                        .putCustom(PersistentTasksCustomMetadata.TYPE, tasks)
+                        .putCustom(
+                            TrainedModelAllocationMetadata.NAME,
+                            TrainedModelAllocationMetadata.Builder.empty()
+                                .addNewAllocation(
+                                    new StartTrainedModelDeploymentAction.TaskParams("model1", "any-index", MODEL_MEMORY_REQUIREMENT)
+                                )
+                                .addNode("model1", "_node_id4")
+                                .addFailedNode("model1", "_node_id2", "test")
+                                .addNode("model1", "_node_id1")
+                                .updateAllocation(
+                                    "model1",
+                                    "_node_id1",
+                                    new RoutingStateAndReason(
+                                        randomFrom(RoutingState.STOPPED, RoutingState.FAILED),
+                                        "test"
+                                    )
+                                )
+                                .build()
+                        )
+                ).build();
 
         NodeLoad load = nodeLoadDetector.detectNodeLoad(cs, true, nodes.get("_node_id1"), 10, 30, false);
         assertThat(load.getAssignedJobMemory(), equalTo(52428800L));
@@ -99,7 +127,7 @@ public class NodeLoadDetectorTests extends ESTestCase {
         assertThat(load.getMaxMlMemory(), equalTo(0L));
 
         load = nodeLoadDetector.detectNodeLoad(cs, true, nodes.get("_node_id4"), 5, 30, false);
-        assertThat(load.getAssignedJobMemory(), equalTo(41943040L));
+        assertThat(load.getAssignedJobMemory(), equalTo(429916160L));
         assertThat(load.getNumAllocatingJobs(), equalTo(0L));
         assertThat(load.getNumAssignedJobs(), equalTo(1L));
         assertThat(load.getMaxJobs(), equalTo(5));

+ 0 - 17
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java

@@ -21,7 +21,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
-import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
@@ -110,13 +109,6 @@ public class MlMemoryTrackerTests extends ESTestCase {
             tasks.put(task.getId(), task);
         }
 
-        int numTrainedModelTasks = randomIntBetween(2, 5);
-        for (int i = 1; i <= numTrainedModelTasks; ++i) {
-            String id = "trained_model_" + i;
-            PersistentTasksCustomMetadata.PersistentTask<?> task = makeTestTrainedModelTask(id, randomLongBetween(1000, 1000000));
-            tasks.put(task.getId(), task);
-        }
-
         PersistentTasksCustomMetadata persistentTasks =
             new PersistentTasksCustomMetadata(numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks, tasks);
 
@@ -289,13 +281,4 @@ public class MlMemoryTrackerTests extends ESTestCase {
             0, PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT);
     }
 
-    private PersistentTasksCustomMetadata.PersistentTask<StartTrainedModelDeploymentAction.TaskParams> makeTestTrainedModelTask(
-        String id, long memUsage) {
-        return new PersistentTasksCustomMetadata.PersistentTask<>(MlTasks.trainedModelDeploymentTaskId(id),
-            MlTasks.TRAINED_MODEL_DEPLOYMENT_TASK_NAME,
-            new StartTrainedModelDeploymentAction.TaskParams(id, randomAlphaOfLength(10), memUsage),
-            0,
-            PersistentTasksCustomMetadata.INITIAL_ASSIGNMENT
-        );
-    }
 }