|
@@ -5,6 +5,7 @@
|
|
|
*/
|
|
|
package org.elasticsearch.xpack.ml.dataframe;
|
|
|
|
|
|
+import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.index.IndexAction;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
@@ -12,16 +13,25 @@ import org.elasticsearch.action.index.IndexResponse;
|
|
|
import org.elasticsearch.action.search.SearchAction;
|
|
|
import org.elasticsearch.action.search.SearchResponse;
|
|
|
import org.elasticsearch.client.Client;
|
|
|
+import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
|
|
+import org.elasticsearch.persistent.PersistentTasksService;
|
|
|
+import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
|
|
|
import org.elasticsearch.search.SearchHit;
|
|
|
import org.elasticsearch.search.SearchHits;
|
|
|
+import org.elasticsearch.tasks.TaskManager;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
|
|
|
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
|
|
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
|
|
|
+import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
|
|
|
+import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
|
|
import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.InOrder;
|
|
|
import org.mockito.stubbing.Answer;
|
|
@@ -34,9 +44,13 @@ import java.util.Map;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.mockito.Matchers.any;
|
|
|
import static org.mockito.Matchers.eq;
|
|
|
+import static org.mockito.Matchers.same;
|
|
|
+import static org.mockito.Mockito.atLeastOnce;
|
|
|
import static org.mockito.Mockito.doAnswer;
|
|
|
import static org.mockito.Mockito.inOrder;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
|
|
|
public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
|
@@ -156,6 +170,56 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
|
|
".ml-state-dummy");
|
|
|
}
|
|
|
|
|
|
+ public void testSetFailed() {
|
|
|
+ testSetFailed(false);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testSetFailedDuringNodeShutdown() {
|
|
|
+ testSetFailed(true);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void testSetFailed(boolean nodeShuttingDown) {
|
|
|
+ ThreadPool threadPool = mock(ThreadPool.class);
|
|
|
+ when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
|
|
+ Client client = mock(Client.class);
|
|
|
+ when(client.threadPool()).thenReturn(threadPool);
|
|
|
+ ClusterService clusterService = mock(ClusterService.class);
|
|
|
+ DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
|
|
|
+ when(analyticsManager.isNodeShuttingDown()).thenReturn(nodeShuttingDown);
|
|
|
+ DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
|
|
|
+ PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
|
|
|
+ TaskManager taskManager = mock(TaskManager.class);
|
|
|
+
|
|
|
+ StartDataFrameAnalyticsAction.TaskParams taskParams =
|
|
|
+ new StartDataFrameAnalyticsAction.TaskParams(
|
|
|
+ "job-id",
|
|
|
+ Version.CURRENT,
|
|
|
+ List.of(
|
|
|
+ new PhaseProgress(ProgressTracker.REINDEXING, 0),
|
|
|
+ new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
|
|
|
+ new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
|
|
|
+ false);
|
|
|
+ DataFrameAnalyticsTask task =
|
|
|
+ new DataFrameAnalyticsTask(
|
|
|
+ 123, "type", "action", null, Map.of(), client, clusterService, analyticsManager, auditor, taskParams);
|
|
|
+ task.init(persistentTasksService, taskManager, "task-id", 42);
|
|
|
+ Exception exception = new Exception("some exception");
|
|
|
+
|
|
|
+ task.setFailed(exception);
|
|
|
+
|
|
|
+ verify(analyticsManager).isNodeShuttingDown();
|
|
|
+ verify(client, atLeastOnce()).settings();
|
|
|
+ verify(client, atLeastOnce()).threadPool();
|
|
|
+ if (nodeShuttingDown == false) {
|
|
|
+ verify(client).execute(
|
|
|
+ same(UpdatePersistentTaskStatusAction.INSTANCE),
|
|
|
+ eq(new UpdatePersistentTaskStatusAction.Request(
|
|
|
+ "task-id", 42, new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, 42, "some exception"))),
|
|
|
+ any());
|
|
|
+ }
|
|
|
+ verifyNoMoreInteractions(client, clusterService, analyticsManager, auditor, taskManager);
|
|
|
+ }
|
|
|
+
|
|
|
@SuppressWarnings("unchecked")
|
|
|
private static <Response> Answer<Response> withResponse(Response response) {
|
|
|
return invocationOnMock -> {
|