1
0
Эх сурвалжийг харах

[ML] Fix master node deadlock during ML daily maintenance (#31836)

This is the implementation for master and 6.x of #31691.
Native tests are changed to use multi-node clusters in #31757.

Relates #31683
Dimitris Athanasiou 7 жил өмнө
parent
commit
49ba271bd8

+ 12 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteExpiredDataAction.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.action;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
@@ -57,8 +58,8 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
         Auditor auditor = new Auditor(client, clusterService.nodeName());
         List<MlDataRemover> dataRemovers = Arrays.asList(
                 new ExpiredResultsRemover(client, clusterService, auditor),
-                new ExpiredForecastsRemover(client),
-                new ExpiredModelSnapshotsRemover(client, clusterService),
+                new ExpiredForecastsRemover(client, threadPool),
+                new ExpiredModelSnapshotsRemover(client, threadPool, clusterService),
                 new UnusedStateRemover(client, clusterService)
         );
         Iterator<MlDataRemover> dataRemoversIterator = new VolatileCursorIterator<>(dataRemovers);
@@ -69,9 +70,15 @@ public class TransportDeleteExpiredDataAction extends HandledTransportAction<Del
                                    ActionListener<DeleteExpiredDataAction.Response> listener) {
         if (mlDataRemoversIterator.hasNext()) {
             MlDataRemover remover = mlDataRemoversIterator.next();
-            remover.remove(ActionListener.wrap(
-                    booleanResponse -> deleteExpiredData(mlDataRemoversIterator, listener),
-                    listener::onFailure));
+            ActionListener<Boolean> nextListener = ActionListener.wrap(
+                    booleanResponse -> deleteExpiredData(mlDataRemoversIterator, listener), listener::onFailure);
+            // Removing expired ML data and artifacts requires multiple operations.
+            // These are queued up and executed sequentially in the action listener,
+            // the chained calls must all run the ML utility thread pool NOT the thread
+            // the previous action returned in which in the case of a transport_client_boss
+            // thread is a disaster.
+            remover.remove(new ThreadedActionListener<>(logger, threadPool, MachineLearning.UTILITY_THREAD_POOL_NAME, nextListener,
+                    false));
         } else {
             logger.info("Completed deletion of expired data");
             listener.onResponse(new DeleteExpiredDataAction.Response(true));

+ 8 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredForecastsRemover.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
@@ -27,11 +28,13 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.results.Forecast;
 import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats;
 import org.elasticsearch.xpack.core.ml.job.results.Result;
+import org.elasticsearch.xpack.ml.MachineLearning;
 import org.joda.time.DateTime;
 import org.joda.time.chrono.ISOChronology;
 
@@ -57,10 +60,12 @@ public class ExpiredForecastsRemover implements MlDataRemover {
     private static final String RESULTS_INDEX_PATTERN =  AnomalyDetectorsIndex.jobResultsIndexPrefix() + "*";
 
     private final Client client;
+    private final ThreadPool threadPool;
     private final long cutoffEpochMs;
 
-    public ExpiredForecastsRemover(Client client) {
+    public ExpiredForecastsRemover(Client client, ThreadPool threadPool) {
         this.client = Objects.requireNonNull(client);
+        this.threadPool = Objects.requireNonNull(threadPool);
         this.cutoffEpochMs = DateTime.now(ISOChronology.getInstance()).getMillis();
     }
 
@@ -79,7 +84,8 @@ public class ExpiredForecastsRemover implements MlDataRemover {
 
         SearchRequest searchRequest = new SearchRequest(RESULTS_INDEX_PATTERN);
         searchRequest.source(source);
-        client.execute(SearchAction.INSTANCE, searchRequest, forecastStatsHandler);
+        client.execute(SearchAction.INSTANCE, searchRequest, new ThreadedActionListener<>(LOGGER, threadPool,
+                MachineLearning.UTILITY_THREAD_POOL_NAME, forecastStatsHandler, false));
     }
 
     private void deleteForecasts(SearchResponse searchResponse, ActionListener<Boolean> listener) {

+ 14 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemover.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.logging.Loggers;
@@ -18,11 +19,13 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshotField;
+import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.util.ArrayList;
 import java.util.Iterator;
@@ -51,10 +54,12 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
     private static final int MODEL_SNAPSHOT_SEARCH_SIZE = 10000;
 
     private final Client client;
+    private final ThreadPool threadPool;
 
-    public ExpiredModelSnapshotsRemover(Client client, ClusterService clusterService) {
+    public ExpiredModelSnapshotsRemover(Client client, ThreadPool threadPool, ClusterService clusterService) {
         super(clusterService);
         this.client = Objects.requireNonNull(client);
+        this.threadPool = Objects.requireNonNull(threadPool);
     }
 
     @Override
@@ -84,7 +89,12 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
 
         searchRequest.source(new SearchSourceBuilder().query(query).size(MODEL_SNAPSHOT_SEARCH_SIZE));
 
-        client.execute(SearchAction.INSTANCE, searchRequest, new ActionListener<SearchResponse>() {
+        client.execute(SearchAction.INSTANCE, searchRequest, new ThreadedActionListener<>(LOGGER, threadPool,
+                MachineLearning.UTILITY_THREAD_POOL_NAME, expiredSnapshotsListener(job.getId(), listener), false));
+    }
+
+    private ActionListener<SearchResponse> expiredSnapshotsListener(String jobId, ActionListener<Boolean> listener) {
+        return new ActionListener<SearchResponse>() {
             @Override
             public void onResponse(SearchResponse searchResponse) {
                 try {
@@ -100,9 +110,9 @@ public class ExpiredModelSnapshotsRemover extends AbstractExpiredJobDataRemover
 
             @Override
             public void onFailure(Exception e) {
-                listener.onFailure(new ElasticsearchException("[" + job.getId() +  "] Search for expired snapshots failed", e));
+                listener.onFailure(new ElasticsearchException("[" + jobId +  "] Search for expired snapshots failed", e));
             }
-        });
+        };
     }
 
     private void deleteModelSnapshots(Iterator<ModelSnapshot> modelSnapshotIterator, ActionListener<Boolean> listener) {

+ 65 - 14
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/retention/ExpiredModelSnapshotsRemoverTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.MetaData;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.json.JsonXContent;
@@ -21,6 +22,8 @@ import org.elasticsearch.mock.orig.Mockito;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.FixedExecutorBuilder;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.MLMetadataField;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
@@ -28,6 +31,8 @@ import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobTests;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
+import org.elasticsearch.xpack.ml.MachineLearning;
+import org.junit.After;
 import org.junit.Before;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -38,24 +43,27 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.same;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
     private Client client;
+    private ThreadPool threadPool;
     private ClusterService clusterService;
     private ClusterState clusterState;
     private List<SearchRequest> capturedSearchRequests;
     private List<DeleteModelSnapshotAction.Request> capturedDeleteModelSnapshotRequests;
     private List<SearchResponse> searchResponsesPerCall;
-    private ActionListener<Boolean> listener;
+    private TestListener listener;
 
     @Before
     public void setUpTests() {
@@ -66,7 +74,19 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         clusterState = mock(ClusterState.class);
         when(clusterService.state()).thenReturn(clusterState);
         client = mock(Client.class);
-        listener = mock(ActionListener.class);
+        listener = new TestListener();
+
+        // Init thread pool
+        Settings settings = Settings.builder()
+                .put("node.name", "expired_model_snapshots_remover_test")
+                .build();
+        threadPool = new ThreadPool(settings,
+                new FixedExecutorBuilder(settings, MachineLearning.UTILITY_THREAD_POOL_NAME, 1, 1000, ""));
+    }
+
+    @After
+    public void shutdownThreadPool() throws InterruptedException {
+        terminate(threadPool);
     }
 
     public void testRemove_GivenJobsWithoutRetentionPolicy() {
@@ -78,7 +98,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         createExpiredModelSnapshotsRemover().remove(listener);
 
-        verify(listener).onResponse(true);
+        listener.waitToCompletion();
+        assertThat(listener.success, is(true));
         Mockito.verifyNoMoreInteractions(client);
     }
 
@@ -88,7 +109,8 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         createExpiredModelSnapshotsRemover().remove(listener);
 
-        verify(listener).onResponse(true);
+        listener.waitToCompletion();
+        assertThat(listener.success, is(true));
         Mockito.verifyNoMoreInteractions(client);
     }
 
@@ -108,6 +130,9 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         createExpiredModelSnapshotsRemover().remove(listener);
 
+        listener.waitToCompletion();
+        assertThat(listener.success, is(true));
+
         assertThat(capturedSearchRequests.size(), equalTo(2));
         SearchRequest searchRequest = capturedSearchRequests.get(0);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
@@ -124,8 +149,6 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         deleteSnapshotRequest = capturedDeleteModelSnapshotRequests.get(2);
         assertThat(deleteSnapshotRequest.getJobId(), equalTo("snapshots-2"));
         assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-2_1"));
-
-        verify(listener).onResponse(true);
     }
 
     public void testRemove_GivenClientSearchRequestsFail() throws IOException {
@@ -144,13 +167,14 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         createExpiredModelSnapshotsRemover().remove(listener);
 
+        listener.waitToCompletion();
+        assertThat(listener.success, is(false));
+
         assertThat(capturedSearchRequests.size(), equalTo(1));
         SearchRequest searchRequest = capturedSearchRequests.get(0);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
 
         assertThat(capturedDeleteModelSnapshotRequests.size(), equalTo(0));
-
-        verify(listener).onFailure(any());
     }
 
     public void testRemove_GivenClientDeleteSnapshotRequestsFail() throws IOException {
@@ -169,6 +193,9 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
 
         createExpiredModelSnapshotsRemover().remove(listener);
 
+        listener.waitToCompletion();
+        assertThat(listener.success, is(false));
+
         assertThat(capturedSearchRequests.size(), equalTo(1));
         SearchRequest searchRequest = capturedSearchRequests.get(0);
         assertThat(searchRequest.indices(), equalTo(new String[] {AnomalyDetectorsIndex.jobResultsAliasedName("snapshots-1")}));
@@ -177,8 +204,6 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         DeleteModelSnapshotAction.Request deleteSnapshotRequest = capturedDeleteModelSnapshotRequests.get(0);
         assertThat(deleteSnapshotRequest.getJobId(), equalTo("snapshots-1"));
         assertThat(deleteSnapshotRequest.getSnapshotId(), equalTo("snapshots-1_1"));
-
-        verify(listener).onFailure(any());
     }
 
     private void givenJobs(List<Job> jobs) {
@@ -192,7 +217,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
     }
 
     private ExpiredModelSnapshotsRemover createExpiredModelSnapshotsRemover() {
-        return new ExpiredModelSnapshotsRemover(client, clusterService);
+        return new ExpiredModelSnapshotsRemover(client, threadPool, clusterService);
     }
 
     private static ModelSnapshot createModelSnapshot(String jobId, String snapshotId) {
@@ -230,7 +255,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
             int callCount = 0;
 
             @Override
-            public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+            public Void answer(InvocationOnMock invocationOnMock) {
                 SearchRequest searchRequest = (SearchRequest) invocationOnMock.getArguments()[1];
                 capturedSearchRequests.add(searchRequest);
                 ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocationOnMock.getArguments()[2];
@@ -244,7 +269,7 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
         }).when(client).execute(same(SearchAction.INSTANCE), any(), any());
         doAnswer(new Answer<Void>() {
             @Override
-            public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+            public Void answer(InvocationOnMock invocationOnMock) {
                 capturedDeleteModelSnapshotRequests.add((DeleteModelSnapshotAction.Request) invocationOnMock.getArguments()[1]);
                 ActionListener<DeleteModelSnapshotAction.Response> listener =
                         (ActionListener<DeleteModelSnapshotAction.Response>) invocationOnMock.getArguments()[2];
@@ -257,4 +282,30 @@ public class ExpiredModelSnapshotsRemoverTests extends ESTestCase {
             }
         }).when(client).execute(same(DeleteModelSnapshotAction.INSTANCE), any(), any());
     }
+
+    private class TestListener implements ActionListener<Boolean> {
+
+        private boolean success;
+        private final CountDownLatch latch = new CountDownLatch(1);
+
+        @Override
+        public void onResponse(Boolean aBoolean) {
+            success = aBoolean;
+            latch.countDown();
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            latch.countDown();
+        }
+
+        public void waitToCompletion() {
+            try {
+                latch.await(10, TimeUnit.SECONDS);
+            } catch (InterruptedException e) {
+                fail("listener timed out before completing");
+            }
+        }
+    }
+
 }