Browse Source

[ML] Refactor autodetect service into its own class (#41378)

This also improves aims to improve the corresponding unit tests
with regard to readability and maintainability.
Dimitris Athanasiou 6 years ago
parent
commit
768ff2e331

+ 0 - 104
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java

@@ -16,14 +16,12 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateListener;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.CheckedConsumer;
-import org.elasticsearch.common.SuppressForbidden;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeUnit;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
-import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -77,20 +75,13 @@ import java.io.InputStream;
 import java.nio.file.Path;
 import java.time.Duration;
 import java.time.ZonedDateTime;
-import java.util.ArrayList;
 import java.util.Date;
 import java.util.Iterator;
-import java.util.List;
 import java.util.Locale;
 import java.util.Optional;
-import java.util.concurrent.AbstractExecutorService;
-import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.TimeUnit;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
@@ -791,99 +782,4 @@ public class AutodetectProcessManager implements ClusterStateListener {
         upgradeInProgress = MlMetadata.getMlMetadata(event.state()).isUpgradeMode();
     }
 
-    /*
-     * The autodetect native process can only handle a single operation at a time. In order to guarantee that, all
-     * operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each
-     * operation at a time.
-     */
-    static class AutodetectWorkerExecutorService extends AbstractExecutorService {
-
-        private final ThreadContext contextHolder;
-        private final CountDownLatch awaitTermination = new CountDownLatch(1);
-        private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);
-
-        private volatile boolean running = true;
-
-        @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
-        AutodetectWorkerExecutorService(ThreadContext contextHolder) {
-            this.contextHolder = contextHolder;
-        }
-
-        @Override
-        public void shutdown() {
-            running = false;
-        }
-
-        @Override
-        public List<Runnable> shutdownNow() {
-            throw new UnsupportedOperationException("not supported");
-        }
-
-        @Override
-        public boolean isShutdown() {
-            return running == false;
-        }
-
-        @Override
-        public boolean isTerminated() {
-            return awaitTermination.getCount() == 0;
-        }
-
-        @Override
-        public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
-            return awaitTermination.await(timeout, unit);
-        }
-
-        @Override
-        public synchronized void execute(Runnable command) {
-            if (isShutdown()) {
-                EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true);
-                if (command instanceof AbstractRunnable) {
-                    ((AbstractRunnable) command).onRejection(rejected);
-                } else {
-                    throw rejected;
-                }
-            }
-
-            boolean added = queue.offer(contextHolder.preserveContext(command));
-            if (added == false) {
-                throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS);
-            }
-        }
-
-        void start() {
-            try {
-                while (running) {
-                    Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
-                    if (runnable != null) {
-                        try {
-                            runnable.run();
-                        } catch (Exception e) {
-                            logger.error("error handling job operation", e);
-                        }
-                        EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
-                    }
-                }
-
-                synchronized (this) {
-                    // if shutdown with tasks pending notify the handlers
-                    if (queue.isEmpty() == false) {
-                        List<Runnable> notExecuted = new ArrayList<>();
-                        queue.drainTo(notExecuted);
-
-                        for (Runnable runnable : notExecuted) {
-                            if (runnable instanceof AbstractRunnable) {
-                                ((AbstractRunnable) runnable).onRejection(
-                                    new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true));
-                            }
-                        }
-                    }
-                }
-            } catch (InterruptedException e) {
-                Thread.currentThread().interrupt();
-            } finally {
-                awaitTermination.countDown();
-            }
-        }
-    }
 }

+ 122 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java

@@ -0,0 +1,122 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.job.process.autodetect;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.SuppressForbidden;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.rest.RestStatus;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.AbstractExecutorService;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/*
+ * The autodetect native process can only handle a single operation at a time. In order to guarantee that, all
+ * operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each
+ * operation at a time.
+ */
+class AutodetectWorkerExecutorService extends AbstractExecutorService {
+
+    private static final Logger logger = LogManager.getLogger(AutodetectWorkerExecutorService.class);
+
+    private final ThreadContext contextHolder;
+    private final CountDownLatch awaitTermination = new CountDownLatch(1);
+    private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);
+
+    private volatile boolean running = true;
+
+    @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
+    AutodetectWorkerExecutorService(ThreadContext contextHolder) {
+        this.contextHolder = contextHolder;
+    }
+
+    @Override
+    public void shutdown() {
+        running = false;
+    }
+
+    @Override
+    public List<Runnable> shutdownNow() {
+        throw new UnsupportedOperationException("not supported");
+    }
+
+    @Override
+    public boolean isShutdown() {
+        return running == false;
+    }
+
+    @Override
+    public boolean isTerminated() {
+        return awaitTermination.getCount() == 0;
+    }
+
+    @Override
+    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
+        return awaitTermination.await(timeout, unit);
+    }
+
+    @Override
+    public synchronized void execute(Runnable command) {
+        if (isShutdown()) {
+            EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true);
+            if (command instanceof AbstractRunnable) {
+                ((AbstractRunnable) command).onRejection(rejected);
+            } else {
+                throw rejected;
+            }
+        }
+
+        boolean added = queue.offer(contextHolder.preserveContext(command));
+        if (added == false) {
+            throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS);
+        }
+    }
+
+    void start() {
+        try {
+            while (running) {
+                Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
+                if (runnable != null) {
+                    try {
+                        runnable.run();
+                    } catch (Exception e) {
+                        logger.error("error handling job operation", e);
+                    }
+                    EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
+                }
+            }
+
+            synchronized (this) {
+                // if shutdown with tasks pending notify the handlers
+                if (queue.isEmpty() == false) {
+                    List<Runnable> notExecuted = new ArrayList<>();
+                    queue.drainTo(notExecuted);
+
+                    for (Runnable runnable : notExecuted) {
+                        if (runnable instanceof AbstractRunnable) {
+                            ((AbstractRunnable) runnable).onRejection(
+                                new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true));
+                        }
+                    }
+                }
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        } finally {
+            awaitTermination.countDown();
+        }
+    }
+}

+ 67 - 180
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java

@@ -15,7 +15,6 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.CheckedConsumer;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -26,7 +25,6 @@ import org.elasticsearch.env.TestEnvironment;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.junit.annotations.TestLogging;
-import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
@@ -50,17 +48,14 @@ import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzerTests
 import org.elasticsearch.xpack.ml.job.persistence.JobDataCountsPersister;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
-import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager.AutodetectWorkerExecutorService;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.TimeRange;
 import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
 import org.elasticsearch.xpack.ml.notifications.Auditor;
-import org.junit.After;
 import org.junit.Before;
 import org.mockito.ArgumentCaptor;
-import org.mockito.Mockito;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
@@ -76,11 +71,9 @@ import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.BiConsumer;
@@ -93,7 +86,6 @@ import static org.elasticsearch.mock.orig.Mockito.times;
 import static org.elasticsearch.mock.orig.Mockito.verify;
 import static org.elasticsearch.mock.orig.Mockito.verifyNoMoreInteractions;
 import static org.elasticsearch.mock.orig.Mockito.when;
-import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
@@ -103,6 +95,7 @@ import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.anyString;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 
@@ -115,11 +108,15 @@ import static org.mockito.Mockito.spy;
 public class AutodetectProcessManagerTests extends ESTestCase {
 
     private Environment environment;
+    private Client client;
+    private ThreadPool threadPool;
     private AnalysisRegistry analysisRegistry;
     private JobManager jobManager;
     private JobResultsProvider jobResultsProvider;
     private JobResultsPersister jobResultsPersister;
     private JobDataCountsPersister jobDataCountsPersister;
+    private AutodetectCommunicator autodetectCommunicator;
+    private AutodetectProcessFactory autodetectFactory;
     private NormalizerFactory normalizerFactory;
     private Auditor auditor;
     private ClusterState clusterState;
@@ -131,18 +128,24 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     private Quantiles quantiles = new Quantiles("foo", new Date(), "state");
     private Set<MlFilter> filters = new HashSet<>();
 
-    private ThreadPool threadPool;
-
     @Before
     public void setup() throws Exception {
         Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir()).build();
         environment = TestEnvironment.newEnvironment(settings);
+        client = mock(Client.class);
+
+        threadPool = mock(ThreadPool.class);
+        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
+        when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
+
         analysisRegistry = CategorizationAnalyzerTests.buildTestAnalysisRegistry(environment);
         jobManager = mock(JobManager.class);
         jobResultsProvider = mock(JobResultsProvider.class);
         jobResultsPersister = mock(JobResultsPersister.class);
         when(jobResultsPersister.bulkPersisterBuilder(any())).thenReturn(mock(JobResultsPersister.Builder.class));
         jobDataCountsPersister = mock(JobDataCountsPersister.class);
+        autodetectCommunicator = mock(AutodetectCommunicator.class);
+        autodetectFactory = mock(AutodetectProcessFactory.class);
         normalizerFactory = mock(NormalizerFactory.class);
         auditor = mock(Auditor.class);
         clusterService = mock(ClusterService.class);
@@ -170,25 +173,16 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             handler.accept(buildAutodetectParams());
             return null;
         }).when(jobResultsProvider).getAutodetectParams(any(), any(), any());
-
-        threadPool = new TestThreadPool("AutodetectProcessManagerTests");
-    }
-
-    @After
-    public void stopThreadPool() {
-        terminate(threadPool);
     }
 
     public void testOpenJob() {
-        Client client = mock(Client.class);
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("unchecked")
             ActionListener<Job> listener = (ActionListener<Job>) invocationOnMock.getArguments()[1];
             listener.onResponse(createJobDetails("foo"));
             return null;
         }).when(jobManager).getJob(eq("foo"), any());
-        AutodetectProcessManager manager = createManager(communicator, client);
+        AutodetectProcessManager manager = createSpyManager();
 
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
@@ -200,8 +194,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testOpenJob_withoutVersion() {
-        Client client = mock(Client.class);
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
         Job.Builder jobBuilder = new Job.Builder(createJobDetails("no_version"));
         jobBuilder.setJobVersion(null);
         Job job = jobBuilder.build();
@@ -214,7 +206,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             return null;
         }).when(jobManager).getJob(eq(job.getId()), any());
 
-        AutodetectProcessManager manager = createManager(communicator, client);
+        AutodetectProcessManager manager = createSpyManager();
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn(job.getId());
         AtomicReference<Exception> errorHolder = new AtomicReference<>();
@@ -235,25 +227,22 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             }).when(jobManager).getJob(eq(jobId), any());
         }
 
-        Client client = mock(Client.class);
-        ThreadPool threadPool = mock(ThreadPool.class);
-        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
         ThreadPool.Cancellable cancellable = mock(ThreadPool.Cancellable.class);
         when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(cancellable);
-        ExecutorService executorService = mock(ExecutorService.class);
-        Future<?> future = mock(Future.class);
-        when(executorService.submit(any(Callable.class))).thenReturn(future);
-        when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
+
         AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
         when(autodetectProcess.isProcessAlive()).thenReturn(true);
         when(autodetectProcess.readAutodetectResults()).thenReturn(Collections.emptyIterator());
-        AutodetectProcessFactory autodetectProcessFactory =
-                (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
+
+        autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
         Settings.Builder settings = Settings.builder();
         settings.put(MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(), 3);
-        AutodetectProcessManager manager = spy(new AutodetectProcessManager(environment, settings.build(), client, threadPool,
-                jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
-                normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService));
+        AutodetectProcessManager manager = createSpyManager(settings.build());
+        doCallRealMethod().when(manager).create(any(), any(), any(), any());
+
+        ExecutorService executorService = mock(ExecutorService.class);
+        Future<?> future = mock(Future.class);
+        when(executorService.submit(any(Callable.class))).thenReturn(future);
         doReturn(executorService).when(manager).createAutodetectExecutorService(any());
 
         doAnswer(invocationOnMock -> {
@@ -293,8 +282,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testProcessData()  {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         assertEquals(0, manager.numberOfOpenJobs());
 
         JobTask jobTask = mock(JobTask.class);
@@ -307,8 +295,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testProcessDataThrowsElasticsearchStatusException_onIoException() {
-        AutodetectCommunicator communicator = Mockito.mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
 
         DataLoadParams params = mock(DataLoadParams.class);
         InputStream inputStream = createInputStream("");
@@ -318,7 +305,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             BiConsumer<DataCounts, Exception> handler = (BiConsumer<DataCounts, Exception>) invocationOnMock.getArguments()[4];
             handler.accept(null, new IOException("blah"));
             return null;
-        }).when(communicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any());
+        }).when(autodetectCommunicator).writeToJob(eq(inputStream), same(analysisRegistry), same(xContentType), eq(params), any());
 
 
         JobTask jobTask = mock(JobTask.class);
@@ -330,8 +317,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testCloseJob() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         assertEquals(0, manager.numberOfOpenJobs());
 
         JobTask jobTask = mock(JobTask.class);
@@ -350,7 +336,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     // interleaved in the AutodetectProcessManager.close() call
     @TestLogging("org.elasticsearch.xpack.ml.job.process.autodetect:DEBUG")
     public void testCanCloseClosingJob() throws Exception {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
         AtomicInteger numberOfCommunicatorCloses = new AtomicInteger(0);
         doAnswer(invocationOnMock -> {
             numberOfCommunicatorCloses.incrementAndGet();
@@ -358,8 +343,8 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             // the middle of the AutodetectProcessManager.close() method
             Thread.yield();
             return null;
-        }).when(communicator).close(anyBoolean(), anyString());
-        AutodetectProcessManager manager = createManager(communicator);
+        }).when(autodetectCommunicator).close(anyBoolean(), anyString());
+        AutodetectProcessManager manager = createSpyManager();
         assertEquals(0, manager.numberOfOpenJobs());
 
         JobTask jobTask = mock(JobTask.class);
@@ -395,19 +380,18 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         CountDownLatch closeStartedLatch = new CountDownLatch(1);
         CountDownLatch killLatch = new CountDownLatch(1);
         CountDownLatch closeInterruptedLatch = new CountDownLatch(1);
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
         doAnswer(invocationOnMock -> {
             closeStartedLatch.countDown();
             if (killLatch.await(3, TimeUnit.SECONDS)) {
                 closeInterruptedLatch.countDown();
             }
             return null;
-        }).when(communicator).close(anyBoolean(), anyString());
+        }).when(autodetectCommunicator).close(anyBoolean(), anyString());
         doAnswer(invocationOnMock -> {
             killLatch.countDown();
             return null;
-        }).when(communicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean());
-        AutodetectProcessManager manager = createManager(communicator);
+        }).when(autodetectCommunicator).killProcess(anyBoolean(), anyBoolean(), anyBoolean());
+        AutodetectProcessManager manager = createSpyManager();
         assertEquals(0, manager.numberOfOpenJobs());
 
         JobTask jobTask = mock(JobTask.class);
@@ -433,8 +417,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testBucketResetMessageIsSent() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         XContentType xContentType = randomFrom(XContentType.values());
 
         DataLoadParams params = new DataLoadParams(TimeRange.builder().startTime("1000").endTime("2000").build(), Optional.empty());
@@ -443,12 +426,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         when(jobTask.getJobId()).thenReturn("foo");
         manager.openJob(jobTask, clusterState, (e, b) -> {});
         manager.processData(jobTask, analysisRegistry, inputStream, xContentType, params, (dataCounts1, e) -> {});
-        verify(communicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any());
+        verify(autodetectCommunicator).writeToJob(same(inputStream), same(analysisRegistry), same(xContentType), same(params), any());
     }
 
     public void testFlush() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
 
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
@@ -460,12 +442,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         FlushJobParams params = FlushJobParams.builder().build();
         manager.flushJob(jobTask, params, ActionListener.wrap(flushAcknowledgement -> {}, e -> fail(e.getMessage())));
 
-        verify(communicator).flushJob(same(params), any());
+        verify(autodetectCommunicator).flushJob(same(params), any());
     }
 
     public void testFlushThrows() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo");
+        AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo");
 
         FlushJobParams params = FlushJobParams.builder().build();
         doAnswer(invocationOnMock -> {
@@ -473,7 +454,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
             BiConsumer<Void, Exception> handler = (BiConsumer<Void, Exception>) invocationOnMock.getArguments()[1];
             handler.accept(null, new IOException("blah"));
             return null;
-        }).when(communicator).flushJob(same(params), any());
+        }).when(autodetectCommunicator).flushJob(same(params), any());
 
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
@@ -483,12 +464,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testCloseThrows() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
 
         // let the communicator throw, simulating a problem with the underlying
         // autodetect, e.g. a crash
-        doThrow(Exception.class).when(communicator).close(anyBoolean(), anyString());
+        doThrow(Exception.class).when(autodetectCommunicator).close(anyBoolean(), anyString());
 
         // create a jobtask
         JobTask jobTask = mock(JobTask.class);
@@ -507,8 +487,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testWriteUpdateProcessMessage() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManagerAndCallProcessData(communicator, "foo");
+        AutodetectProcessManager manager = createSpyManagerAndCallProcessData("foo");
         ModelPlotConfig modelConfig = mock(ModelPlotConfig.class);
         List<DetectionRule> rules = Collections.singletonList(mock(DetectionRule.class));
         List<JobUpdate.DetectorUpdate> detectorUpdates = Collections.singletonList(new JobUpdate.DetectorUpdate(2, null, rules));
@@ -519,7 +498,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         manager.writeUpdateProcessMessage(jobTask, updateParams, e -> {});
 
         ArgumentCaptor<UpdateProcessMessage> captor = ArgumentCaptor.forClass(UpdateProcessMessage.class);
-        verify(communicator).writeUpdateProcessMessage(captor.capture(), any());
+        verify(autodetectCommunicator).writeUpdateProcessMessage(captor.capture(), any());
 
         UpdateProcessMessage updateProcessMessage = captor.getValue();
         assertThat(updateProcessMessage.getModelPlotConfig(), equalTo(modelConfig));
@@ -527,8 +506,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testJobHasActiveAutodetectProcess() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
         assertFalse(manager.jobHasActiveAutodetectProcess(jobTask));
@@ -545,8 +523,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testKillKillsAutodetectProcess() throws IOException {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
         assertFalse(manager.jobHasActiveAutodetectProcess(jobTask));
@@ -559,12 +536,11 @@ public class AutodetectProcessManagerTests extends ESTestCase {
 
         manager.killAllProcessesOnThisNode();
 
-        verify(communicator).killProcess(false, false, true);
+        verify(autodetectCommunicator).killProcess(false, false, true);
     }
 
     public void testKillingAMissingJobFinishesTheTask() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
-        AutodetectProcessManager manager = createManager(communicator);
+        AutodetectProcessManager manager = createSpyManager();
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
 
@@ -574,14 +550,13 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testProcessData_GivenStateNotOpened() {
-        AutodetectCommunicator communicator = mock(AutodetectCommunicator.class);
         doAnswer(invocationOnMock -> {
             @SuppressWarnings("unchecked")
             BiConsumer<DataCounts, Exception> handler = (BiConsumer<DataCounts, Exception>) invocationOnMock.getArguments()[4];
             handler.accept(new DataCounts("foo"), null);
             return null;
-        }).when(communicator).writeToJob(any(), any(), any(), any(), any());
-        AutodetectProcessManager manager = createManager(communicator);
+        }).when(autodetectCommunicator).writeToJob(any(), any(), any(), any(), any());
+        AutodetectProcessManager manager = createSpyManager();
 
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("foo");
@@ -595,8 +570,6 @@ public class AutodetectProcessManagerTests extends ESTestCase {
     }
 
     public void testCreate_notEnoughThreads() throws IOException {
-        Client client = mock(Client.class);
-        ThreadPool threadPool = mock(ThreadPool.class);
         when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
         ExecutorService executorService = mock(ExecutorService.class);
         doThrow(new EsRejectedExecutionException("")).when(executorService).submit(any(Runnable.class));
@@ -611,11 +584,9 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         }).when(jobManager).getJob(eq("my_id"), any());
 
         AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
-        AutodetectProcessFactory autodetectProcessFactory =
-                (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
-        AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY,
-                client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
-                normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
+        autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
+        AutodetectProcessManager manager = createSpyManager();
+        doCallRealMethod().when(manager).create(any(), any(), any(), any());
 
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn("my_id");
@@ -675,86 +646,7 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         verifyNoMoreInteractions(auditor);
     }
 
-    public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
-        final ThreadPool threadPool = new TestThreadPool("testAutodetectWorkerExecutorServiceDoesNotSwallowErrors");
-        try {
-            final AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
-            if (randomBoolean()) {
-                executor.submit(() -> {
-                    throw new Error("future error");
-                });
-            } else {
-                executor.execute(() -> {
-                    throw new Error("future error");
-                });
-            }
-            final Error e = expectThrows(Error.class, () -> executor.start());
-            assertThat(e.getMessage(), containsString("future error"));
-        } finally {
-            ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
-        }
-    }
-
-    public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
-        AutodetectProcessManager.AutodetectWorkerExecutorService executor =
-                new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
-
-        threadPool.generic().execute(() -> executor.start());
-        executor.shutdown();
-        expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {}));
-    }
-
-    public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown()
-            throws InterruptedException, ExecutionException {
-        AutodetectProcessManager.AutodetectWorkerExecutorService executor =
-                new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
-
-        CountDownLatch latch = new CountDownLatch(1);
-
-        Future<?> executorFinished = threadPool.generic().submit(() -> executor.start());
-
-        // run a task that will block while the others are queued up
-        executor.execute(() -> {
-            try {
-                latch.await();
-            } catch (InterruptedException e) {
-            }
-        });
-
-        AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false);
-        executor.execute(() -> runnableShouldNotBeCalled.set(true));
-
-        AtomicInteger onFailureCallCount = new AtomicInteger();
-        AtomicInteger doRunCallCount = new AtomicInteger();
-        for (int i=0; i<2; i++) {
-            executor.execute(new AbstractRunnable() {
-                @Override
-                public void onFailure(Exception e) {
-                    onFailureCallCount.incrementAndGet();
-                }
-
-                @Override
-                protected void doRun() {
-                    doRunCallCount.incrementAndGet();
-                }
-            });
-        }
-
-        // now shutdown
-        executor.shutdown();
-        latch.countDown();
-        executorFinished.get();
-
-        assertFalse(runnableShouldNotBeCalled.get());
-        // the AbstractRunnables should have had their callbacks called
-        assertEquals(2, onFailureCallCount.get());
-        assertEquals(0, doRunCallCount.get());
-    }
-
     private AutodetectProcessManager createNonSpyManager(String jobId) {
-        Client client = mock(Client.class);
-        ThreadPool threadPool = mock(ThreadPool.class);
-        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
         ExecutorService executorService = mock(ExecutorService.class);
         when(threadPool.executor(anyString())).thenReturn(executorService);
         when(threadPool.scheduleWithFixedDelay(any(), any(), any())).thenReturn(mock(ThreadPool.Cancellable.class));
@@ -766,11 +658,8 @@ public class AutodetectProcessManagerTests extends ESTestCase {
         }).when(jobManager).getJob(eq(jobId), any());
 
         AutodetectProcess autodetectProcess = mock(AutodetectProcess.class);
-        AutodetectProcessFactory autodetectProcessFactory =
-                (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
-        return new AutodetectProcessManager(environment, Settings.EMPTY, client, threadPool, jobManager,
-                jobResultsProvider, jobResultsPersister, jobDataCountsPersister, autodetectProcessFactory,
-                normalizerFactory, new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
+        autodetectFactory = (j, autodetectParams, e, onProcessCrash) -> autodetectProcess;
+        return createManager(Settings.EMPTY);
     }
 
     private AutodetectParams buildAutodetectParams() {
@@ -783,27 +672,25 @@ public class AutodetectProcessManagerTests extends ESTestCase {
                 .build();
     }
 
-    private AutodetectProcessManager createManager(AutodetectCommunicator communicator) {
-        Client client = mock(Client.class);
-        return createManager(communicator, client);
+    private AutodetectProcessManager createSpyManager() {
+        return createSpyManager(Settings.EMPTY);
     }
 
-    private AutodetectProcessManager createManager(AutodetectCommunicator communicator, Client client) {
-        ThreadPool threadPool = mock(ThreadPool.class);
-        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
-        when(threadPool.executor(anyString())).thenReturn(EsExecutors.newDirectExecutorService());
-        AutodetectProcessFactory autodetectProcessFactory = mock(AutodetectProcessFactory.class);
-        AutodetectProcessManager manager = new AutodetectProcessManager(environment, Settings.EMPTY,
-                client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister,
-                autodetectProcessFactory, normalizerFactory,
-                new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
+    private AutodetectProcessManager createSpyManager(Settings settings) {
+        AutodetectProcessManager manager = createManager(settings);
         manager = spy(manager);
-        doReturn(communicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any());
+        doReturn(autodetectCommunicator).when(manager).create(any(), any(), eq(buildAutodetectParams()), any());
         return manager;
     }
 
-    private AutodetectProcessManager createManagerAndCallProcessData(AutodetectCommunicator communicator, String jobId) {
-        AutodetectProcessManager manager = createManager(communicator);
+    private AutodetectProcessManager createManager(Settings settings) {
+        return new AutodetectProcessManager(environment, settings,
+            client, threadPool, jobManager, jobResultsProvider, jobResultsPersister, jobDataCountsPersister,
+            autodetectFactory, normalizerFactory,
+            new NamedXContentRegistry(Collections.emptyList()), auditor, clusterService);
+    }
+    private AutodetectProcessManager createSpyManagerAndCallProcessData(String jobId) {
+        AutodetectProcessManager manager = createSpyManager();
         JobTask jobTask = mock(JobTask.class);
         when(jobTask.getJobId()).thenReturn(jobId);
         manager.openJob(jobTask, clusterState, (e, b) -> {});

+ 100 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorServiceTests.java

@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.job.process.autodetect;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.junit.After;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
+
+    private ThreadPool threadPool = new TestThreadPool("AutodetectWorkerExecutorServiceTests");
+
+    @After
+    public void stopThreadPool() {
+        terminate(threadPool);
+    }
+
+    public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
+        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
+
+        threadPool.generic().execute(() -> executor.start());
+        executor.shutdown();
+        expectThrows(EsRejectedExecutionException.class, () -> executor.execute(() -> {}));
+    }
+
+    public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception {
+        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
+
+        CountDownLatch latch = new CountDownLatch(1);
+
+        Future<?> executorFinished = threadPool.generic().submit(() -> executor.start());
+
+        // run a task that will block while the others are queued up
+        executor.execute(() -> {
+            try {
+                latch.await();
+            } catch (InterruptedException e) {
+            }
+        });
+
+        AtomicBoolean runnableShouldNotBeCalled = new AtomicBoolean(false);
+        executor.execute(() -> runnableShouldNotBeCalled.set(true));
+
+        AtomicInteger onFailureCallCount = new AtomicInteger();
+        AtomicInteger doRunCallCount = new AtomicInteger();
+        for (int i=0; i<2; i++) {
+            executor.execute(new AbstractRunnable() {
+                @Override
+                public void onFailure(Exception e) {
+                    onFailureCallCount.incrementAndGet();
+                }
+
+                @Override
+                protected void doRun() {
+                    doRunCallCount.incrementAndGet();
+                }
+            });
+        }
+
+        // now shutdown
+        executor.shutdown();
+        latch.countDown();
+        executorFinished.get();
+
+        assertFalse(runnableShouldNotBeCalled.get());
+        // the AbstractRunnables should have had their callbacks called
+        assertEquals(2, onFailureCallCount.get());
+        assertEquals(0, doRunCallCount.get());
+    }
+
+    public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
+        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
+        if (randomBoolean()) {
+            executor.submit(() -> {
+                throw new Error("future error");
+            });
+        } else {
+            executor.execute(() -> {
+                throw new Error("future error");
+            });
+        }
+        Error e = expectThrows(Error.class, () -> executor.start());
+        assertThat(e.getMessage(), containsString("future error"));
+    }
+}