1
0
Эх сурвалжийг харах

[ML] Avoid NPE when node load is calculated on job assignment (#49186)

This commit fixes a NPE problem as reported in #49150.
But this problem uncovered that we never added proper handling
of state for data frame analytics tasks.

In this commit we improve the `MlTasks.getDataFrameAnalyticsState`
method to handle null tasks and state tasks properly.

Closes #49150
Dimitris Athanasiou 6 жил өмнө
parent
commit
5857184611

+ 22 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java

@@ -143,14 +143,31 @@ public final class MlTasks {
 
     public static DataFrameAnalyticsState getDataFrameAnalyticsState(String analyticsId, @Nullable PersistentTasksCustomMetaData tasks) {
         PersistentTasksCustomMetaData.PersistentTask<?> task = getDataFrameAnalyticsTask(analyticsId, tasks);
-        if (task != null) {
-            DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) task.getState();
-            if (taskState == null) {
+        return getDataFrameAnalyticsState(task);
+    }
+
+    public static DataFrameAnalyticsState getDataFrameAnalyticsState(@Nullable PersistentTasksCustomMetaData.PersistentTask<?> task) {
+        if (task == null) {
+            return DataFrameAnalyticsState.STOPPED;
+        }
+        DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) task.getState();
+        if (taskState == null) {
+            return DataFrameAnalyticsState.STARTING;
+        }
+
+        DataFrameAnalyticsState state = taskState.getState();
+        if (taskState.isStatusStale(task)) {
+            if (state == DataFrameAnalyticsState.STOPPING) {
+                // previous executor node failed while the job was stopping - it won't
+                // be restarted on another node, so consider it STOPPED for reassignment purposes
+                return DataFrameAnalyticsState.STOPPED;
+            }
+            if (state != DataFrameAnalyticsState.FAILED) {
+                // we are relocating at the moment
                 return DataFrameAnalyticsState.STARTING;
             }
-            return taskState.getState();
         }
-        return DataFrameAnalyticsState.STOPPED;
+        return state;
     }
 
     /**

+ 134 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java

@@ -13,18 +13,23 @@ import org.elasticsearch.common.transport.TransportAddress;
 import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
+import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.config.JobTaskState;
 
 import java.net.InetAddress;
+import java.util.Collections;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 
 public class MlTasksTests extends ESTestCase {
+
     public void testGetJobState() {
         PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
         // A missing task is a closed job
@@ -168,4 +173,133 @@ public class MlTasksTests extends ESTestCase {
         assertThat(taskId, equalTo("data_frame_analytics-foo"));
         assertThat(MlTasks.dataFrameAnalyticsIdFromTaskId(taskId), equalTo("foo"));
     }
+
+    public void testGetDataFrameAnalyticsState_GivenNullTask() {
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(null);
+        assertThat(state, equalTo(DataFrameAnalyticsState.STOPPED));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithNullState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node", null, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithStartedState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.STARTED, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STARTED));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenStaleTaskWithStartedState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.STARTED, true);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithReindexingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.REINDEXING, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.REINDEXING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenStaleTaskWithReindexingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.REINDEXING, true);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithAnalyzingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.ANALYZING, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.ANALYZING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenStaleTaskWithAnalyzingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.ANALYZING, true);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STARTING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithStoppingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.STOPPING, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STOPPING));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenStaleTaskWithStoppingState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.STOPPING, true);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.STOPPED));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenTaskWithFailedState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.FAILED, false);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
+    }
+
+    public void testGetDataFrameAnalyticsState_GivenStaleTaskWithFailedState() {
+        String jobId = "foo";
+        PersistentTasksCustomMetaData.PersistentTask<?> task = createDataFrameAnalyticsTask(jobId, "test_node",
+            DataFrameAnalyticsState.FAILED, true);
+
+        DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(task);
+
+        assertThat(state, equalTo(DataFrameAnalyticsState.FAILED));
+    }
+
+    private static PersistentTasksCustomMetaData.PersistentTask<?> createDataFrameAnalyticsTask(String jobId, String nodeId,
+                                                                                                DataFrameAnalyticsState state,
+                                                                                                boolean isStale) {
+        PersistentTasksCustomMetaData.Builder builder = PersistentTasksCustomMetaData.builder();
+        builder.addTask(MlTasks.dataFrameAnalyticsTaskId(jobId), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME,
+            new StartDataFrameAnalyticsAction.TaskParams(jobId, Version.CURRENT, Collections.emptyList(), false),
+            new PersistentTasksCustomMetaData.Assignment(nodeId, "test assignment"));
+        if (state != null) {
+            builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(jobId),
+                new DataFrameAnalyticsTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0), null));
+        }
+        PersistentTasksCustomMetaData tasks = builder.build();
+        return tasks.getTask(MlTasks.dataFrameAnalyticsTaskId(jobId));
+    }
 }

+ 1 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java

@@ -16,7 +16,6 @@ import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
-import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
@@ -269,7 +268,7 @@ public class JobNodeSelector {
             Collection<PersistentTasksCustomMetaData.PersistentTask<?>> assignedAnalyticsTasks = persistentTasks.findTasks(
                 MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, task -> node.getId().equals(task.getExecutorNode()));
             for (PersistentTasksCustomMetaData.PersistentTask<?> assignedTask : assignedAnalyticsTasks) {
-                DataFrameAnalyticsState dataFrameAnalyticsState = ((DataFrameAnalyticsTaskState) assignedTask.getState()).getState();
+                DataFrameAnalyticsState dataFrameAnalyticsState = MlTasks.getDataFrameAnalyticsState(assignedTask);
 
                 // Don't count stopped and failed df-analytics tasks as they don't consume native memory
                 if (dataFrameAnalyticsState.isAnyOf(DataFrameAnalyticsState.STOPPED, DataFrameAnalyticsState.FAILED) == false) {

+ 29 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java

@@ -200,6 +200,27 @@ public class JobNodeSelectorTests extends ESTestCase {
             + currentlyRunningJobMemory + "], estimated memory required for this job [" + JOB_MEMORY_REQUIREMENT.getBytes() + "]"));
     }
 
+    public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_givenTaskHasNullState() {
+        int numNodes = randomIntBetween(1, 10);
+        int maxRunningJobsPerNode = 10;
+        int maxMachineMemoryPercent = 30;
+
+        Map<String, String> nodeAttr = new HashMap<>();
+        nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode));
+        nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "-1");
+
+        ClusterState.Builder cs = fillNodesWithRunningJobs(nodeAttr, numNodes, 1, JobState.OPENED, null);
+
+        String dataFrameAnalyticsId = "data_frame_analytics_id_new";
+
+        JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId,
+            MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0,
+            node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, dataFrameAnalyticsId));
+        PersistentTasksCustomMetaData.Assignment result =
+            jobNodeSelector.selectNode(maxRunningJobsPerNode, 2, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed);
+        assertNotNull(result.getExecutorNode());
+    }
+
     public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_firstJobTooBigMemoryLimiting() {
         int numNodes = randomIntBetween(1, 10);
         int maxRunningJobsPerNode = randomIntBetween(1, 100);
@@ -579,6 +600,12 @@ public class JobNodeSelectorTests extends ESTestCase {
 
     private ClusterState.Builder fillNodesWithRunningJobs(Map<String, String> nodeAttr, int numNodes, int numRunningJobsPerNode) {
 
+        return fillNodesWithRunningJobs(nodeAttr, numNodes, numRunningJobsPerNode, JobState.OPENED, DataFrameAnalyticsState.STARTED);
+    }
+
+    private ClusterState.Builder fillNodesWithRunningJobs(Map<String, String> nodeAttr, int numNodes, int numRunningJobsPerNode,
+                                                          JobState anomalyDetectionJobState, DataFrameAnalyticsState dfAnalyticsJobState) {
+
         DiscoveryNodes.Builder nodes = DiscoveryNodes.builder();
         PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
         String[] jobIds = new String[numNodes * numRunningJobsPerNode];
@@ -591,10 +618,10 @@ public class JobNodeSelectorTests extends ESTestCase {
                 // Both anomaly detector jobs and data frame analytics jobs should count towards the limit
                 if (randomBoolean()) {
                     jobIds[id] = "job_id" + id;
-                    TransportOpenJobActionTests.addJobTask(jobIds[id], nodeId, JobState.OPENED, tasksBuilder);
+                    TransportOpenJobActionTests.addJobTask(jobIds[id], nodeId, anomalyDetectionJobState, tasksBuilder);
                 } else {
                     jobIds[id] = "data_frame_analytics_id" + id;
-                    addDataFrameAnalyticsJobTask(jobIds[id], nodeId, DataFrameAnalyticsState.STARTED, tasksBuilder);
+                    addDataFrameAnalyticsJobTask(jobIds[id], nodeId, dfAnalyticsJobState, tasksBuilder);
                 }
             }
         }