|  | @@ -6,20 +6,102 @@
 | 
											
												
													
														|  |   */
 |  |   */
 | 
											
												
													
														|  |  package org.elasticsearch.xpack.ml;
 |  |  package org.elasticsearch.xpack.ml;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +import org.apache.lucene.util.SetOnce;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.action.ActionListener;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.action.support.master.AcknowledgedResponse;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.client.Client;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.cluster.ClusterName;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.cluster.ClusterState;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.cluster.metadata.Metadata;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.cluster.service.ClusterService;
 | 
											
												
													
														|  |  import org.elasticsearch.common.settings.Settings;
 |  |  import org.elasticsearch.common.settings.Settings;
 | 
											
												
													
														|  |  import org.elasticsearch.license.XPackLicenseState;
 |  |  import org.elasticsearch.license.XPackLicenseState;
 | 
											
												
													
														|  |  import org.elasticsearch.monitor.os.OsStats;
 |  |  import org.elasticsearch.monitor.os.OsStats;
 | 
											
												
													
														|  |  import org.elasticsearch.test.ESTestCase;
 |  |  import org.elasticsearch.test.ESTestCase;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.threadpool.TestThreadPool;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.threadpool.ThreadPool;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.xpack.core.ml.MlMetadata;
 | 
											
												
													
														|  | 
 |  | +import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import java.io.IOException;
 |  |  import java.io.IOException;
 | 
											
												
													
														|  | 
 |  | +import java.util.Collections;
 | 
											
												
													
														|  | 
 |  | +import java.util.Map;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import static org.hamcrest.Matchers.containsString;
 |  |  import static org.hamcrest.Matchers.containsString;
 | 
											
												
													
														|  | 
 |  | +import static org.hamcrest.Matchers.equalTo;
 | 
											
												
													
														|  |  import static org.hamcrest.Matchers.startsWith;
 |  |  import static org.hamcrest.Matchers.startsWith;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Matchers.any;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Matchers.eq;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Matchers.same;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Mockito.doAnswer;
 | 
											
												
													
														|  |  import static org.mockito.Mockito.mock;
 |  |  import static org.mockito.Mockito.mock;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Mockito.verify;
 | 
											
												
													
														|  | 
 |  | +import static org.mockito.Mockito.verifyZeroInteractions;
 | 
											
												
													
														|  |  import static org.mockito.Mockito.when;
 |  |  import static org.mockito.Mockito.when;
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  public class MachineLearningTests extends ESTestCase {
 |  |  public class MachineLearningTests extends ESTestCase {
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +    @SuppressWarnings("unchecked")
 | 
											
												
													
														|  | 
 |  | +    public void testPrePostSystemIndexUpgrade_givenNotInUpgradeMode() {
 | 
											
												
													
														|  | 
 |  | +        ThreadPool threadpool = new TestThreadPool("test");
 | 
											
												
													
														|  | 
 |  | +        ClusterService clusterService = mock(ClusterService.class);
 | 
											
												
													
														|  | 
 |  | +        when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE);
 | 
											
												
													
														|  | 
 |  | +        Client client = mock(Client.class);
 | 
											
												
													
														|  | 
 |  | +        when(client.threadPool()).thenReturn(threadpool);
 | 
											
												
													
														|  | 
 |  | +        doAnswer(invocationOnMock -> {
 | 
											
												
													
														|  | 
 |  | +            ActionListener<AcknowledgedResponse> listener = (ActionListener<AcknowledgedResponse>) invocationOnMock.getArguments()[2];
 | 
											
												
													
														|  | 
 |  | +            listener.onResponse(AcknowledgedResponse.TRUE);
 | 
											
												
													
														|  | 
 |  | +            return null;
 | 
											
												
													
														|  | 
 |  | +        }).when(client).execute(same(SetUpgradeModeAction.INSTANCE), any(SetUpgradeModeAction.Request.class), any(ActionListener.class));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        MachineLearning machineLearning = createMachineLearning(Settings.EMPTY);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        SetOnce<Map<String, Object>> response = new SetOnce<>();
 | 
											
												
													
														|  | 
 |  | +        machineLearning.prepareForIndicesMigration(clusterService, client, ActionListener.wrap(
 | 
											
												
													
														|  | 
 |  | +            response::set,
 | 
											
												
													
														|  | 
 |  | +            e -> fail(e.getMessage())
 | 
											
												
													
														|  | 
 |  | +        ));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", false)));
 | 
											
												
													
														|  | 
 |  | +        verify(client).execute(same(SetUpgradeModeAction.INSTANCE), eq(new SetUpgradeModeAction.Request(true)), any(ActionListener.class));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        machineLearning.indicesMigrationComplete(response.get(), clusterService, client, ActionListener.wrap(
 | 
											
												
													
														|  | 
 |  | +            ESTestCase::assertTrue,
 | 
											
												
													
														|  | 
 |  | +            e -> fail(e.getMessage())
 | 
											
												
													
														|  | 
 |  | +        ));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        verify(client).execute(same(SetUpgradeModeAction.INSTANCE), eq(new SetUpgradeModeAction.Request(false)), any(ActionListener.class));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        threadpool.shutdown();
 | 
											
												
													
														|  | 
 |  | +    }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    public void testPrePostSystemIndexUpgrade_givenAlreadyInUpgradeMode() {
 | 
											
												
													
														|  | 
 |  | +        ClusterService clusterService = mock(ClusterService.class);
 | 
											
												
													
														|  | 
 |  | +        when(clusterService.state()).thenReturn(
 | 
											
												
													
														|  | 
 |  | +            ClusterState.builder(ClusterName.DEFAULT)
 | 
											
												
													
														|  | 
 |  | +                .metadata(Metadata.builder().putCustom(MlMetadata.TYPE, new MlMetadata.Builder().isUpgradeMode(true).build())).build());
 | 
											
												
													
														|  | 
 |  | +        Client client = mock(Client.class);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        MachineLearning machineLearning = createMachineLearning(Settings.EMPTY);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        SetOnce<Map<String, Object>> response = new SetOnce<>();
 | 
											
												
													
														|  | 
 |  | +        machineLearning.prepareForIndicesMigration(clusterService, client, ActionListener.wrap(
 | 
											
												
													
														|  | 
 |  | +            response::set,
 | 
											
												
													
														|  | 
 |  | +            e -> fail(e.getMessage())
 | 
											
												
													
														|  | 
 |  | +        ));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        assertThat(response.get(), equalTo(Collections.singletonMap("already_in_upgrade_mode", true)));
 | 
											
												
													
														|  | 
 |  | +        verifyZeroInteractions(client);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        machineLearning.indicesMigrationComplete(response.get(), clusterService, client, ActionListener.wrap(
 | 
											
												
													
														|  | 
 |  | +            ESTestCase::assertTrue,
 | 
											
												
													
														|  | 
 |  | +            e -> fail(e.getMessage())
 | 
											
												
													
														|  | 
 |  | +        ));
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        // Neither pre nor post should have called any action
 | 
											
												
													
														|  | 
 |  | +        verifyZeroInteractions(client);
 | 
											
												
													
														|  | 
 |  | +    }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      public void testMaxOpenWorkersSetting_givenDefault() {
 |  |      public void testMaxOpenWorkersSetting_givenDefault() {
 | 
											
												
													
														|  |          int maxOpenWorkers = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(Settings.EMPTY);
 |  |          int maxOpenWorkers = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(Settings.EMPTY);
 | 
											
												
													
														|  |          assertEquals(512, maxOpenWorkers);
 |  |          assertEquals(512, maxOpenWorkers);
 |