Selaa lähdekoodia

Generalize remote license checker (#32971)

Machine learning has baked a remote license checker for use in checking
license compatibility of a remote license. This remote license checker
has general usage for any feature that relies on a remote cluster. For
example, cross-cluster replication will pull changes from a remote
cluster and require that the local and remote clusters have platinum
licenses. This commit generalizes the remote cluster license check for
use in cross-cluster replication.
Jason Tedor 7 vuotta sitten
vanhempi
commit
9050c7e846

+ 281 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java

@@ -0,0 +1,281 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.license;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.protocol.xpack.XPackInfoRequest;
+import org.elasticsearch.protocol.xpack.XPackInfoResponse;
+import org.elasticsearch.protocol.xpack.license.LicenseStatus;
+import org.elasticsearch.transport.RemoteClusterAware;
+import org.elasticsearch.xpack.core.action.XPackInfoAction;
+
+import java.util.EnumSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+
+/**
+ * Checks remote clusters for license compatibility with a specified license predicate.
+ */
+public final class RemoteClusterLicenseChecker {
+
+    /**
+     * Encapsulates the license info of a remote cluster.
+     */
+    public static final class RemoteClusterLicenseInfo {
+
+        private final String clusterAlias;
+
+        /**
+         * The alias of the remote cluster.
+         *
+         * @return the cluster alias
+         */
+        public String clusterAlias() {
+            return clusterAlias;
+        }
+
+        private final XPackInfoResponse.LicenseInfo licenseInfo;
+
+        /**
+         * The license info of the remote cluster.
+         *
+         * @return the license info
+         */
+        public XPackInfoResponse.LicenseInfo licenseInfo() {
+            return licenseInfo;
+        }
+
+        RemoteClusterLicenseInfo(final String clusterAlias, final XPackInfoResponse.LicenseInfo licenseInfo) {
+            this.clusterAlias = clusterAlias;
+            this.licenseInfo = licenseInfo;
+        }
+
+    }
+
+    /**
+     * Encapsulates a remote cluster license check. The check is either successful if the license of the remote cluster is compatible with
+     * the predicate used to check license compatibility, or the check is a failure.
+     */
+    public static final class LicenseCheck {
+
+        private final RemoteClusterLicenseInfo remoteClusterLicenseInfo;
+
+        /**
+         * The remote cluster license info. This method should only be invoked if this instance represents a failing license check.
+         *
+         * @return the remote cluster license info
+         */
+        public RemoteClusterLicenseInfo remoteClusterLicenseInfo() {
+            assert isSuccess() == false;
+            return remoteClusterLicenseInfo;
+        }
+
+        private static final LicenseCheck SUCCESS = new LicenseCheck(null);
+
+        /**
+         * A successful license check.
+         *
+         * @return a successful license check instance
+         */
+        public static LicenseCheck success() {
+            return SUCCESS;
+        }
+
+        /**
+         * Test if this instance represents a successful license check.
+         *
+         * @return true if this instance represents a successful license check, otherwise false
+         */
+        public boolean isSuccess() {
+            return this == SUCCESS;
+        }
+
+        /**
+         * Creates a failing license check encapsulating the specified remote cluster license info.
+         *
+         * @param remoteClusterLicenseInfo the remote cluster license info
+         * @return a failing license check
+         */
+        public static LicenseCheck failure(final RemoteClusterLicenseInfo remoteClusterLicenseInfo) {
+            return new LicenseCheck(remoteClusterLicenseInfo);
+        }
+
+        private LicenseCheck(final RemoteClusterLicenseInfo remoteClusterLicenseInfo) {
+            this.remoteClusterLicenseInfo = remoteClusterLicenseInfo;
+        }
+
+    }
+
+    private final Client client;
+    private final Predicate<XPackInfoResponse.LicenseInfo> predicate;
+
+    /**
+     * Constructs a remote cluster license checker with the specified license predicate for checking license compatibility. The predicate
+     * does not need to check for the active license state as this is handled by the remote cluster license checker.
+     *
+     * @param client    the client
+     * @param predicate the license predicate
+     */
+    public RemoteClusterLicenseChecker(final Client client, final Predicate<XPackInfoResponse.LicenseInfo> predicate) {
+        this.client = client;
+        this.predicate = predicate;
+    }
+
+    public static boolean isLicensePlatinumOrTrial(final XPackInfoResponse.LicenseInfo licenseInfo) {
+        final License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode());
+        return mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL;
+    }
+
+    /**
+     * Checks the specified clusters for license compatibility. The specified callback will be invoked once if all clusters are
+     * license-compatible, otherwise the specified callback will be invoked once on the first cluster that is not license-compatible.
+     *
+     * @param clusterAliases the cluster aliases to check
+     * @param listener       a callback
+     */
+    public void checkRemoteClusterLicenses(final List<String> clusterAliases, final ActionListener<LicenseCheck> listener) {
+        final Iterator<String> clusterAliasesIterator = clusterAliases.iterator();
+        if (clusterAliasesIterator.hasNext() == false) {
+            listener.onResponse(LicenseCheck.success());
+            return;
+        }
+
+        final AtomicReference<String> clusterAlias = new AtomicReference<>();
+
+        final ActionListener<XPackInfoResponse> infoListener = new ActionListener<XPackInfoResponse>() {
+
+            @Override
+            public void onResponse(final XPackInfoResponse xPackInfoResponse) {
+                final XPackInfoResponse.LicenseInfo licenseInfo = xPackInfoResponse.getLicenseInfo();
+                if ((licenseInfo.getStatus() == LicenseStatus.ACTIVE) == false || predicate.test(licenseInfo) == false) {
+                    listener.onResponse(LicenseCheck.failure(new RemoteClusterLicenseInfo(clusterAlias.get(), licenseInfo)));
+                    return;
+                }
+
+                if (clusterAliasesIterator.hasNext()) {
+                    clusterAlias.set(clusterAliasesIterator.next());
+                    // recurse to the next cluster
+                    remoteClusterLicense(clusterAlias.get(), this);
+                } else {
+                    listener.onResponse(LicenseCheck.success());
+                }
+            }
+
+            @Override
+            public void onFailure(final Exception e) {
+                final String message = "could not determine the license type for cluster [" + clusterAlias.get() + "]";
+                listener.onFailure(new ElasticsearchException(message, e));
+            }
+
+        };
+
+        // check the license on the first cluster, and then we recursively check licenses on the remaining clusters
+        clusterAlias.set(clusterAliasesIterator.next());
+        remoteClusterLicense(clusterAlias.get(), infoListener);
+    }
+
+    private void remoteClusterLicense(final String clusterAlias, final ActionListener<XPackInfoResponse> listener) {
+        final ThreadContext threadContext = client.threadPool().getThreadContext();
+        final ContextPreservingActionListener<XPackInfoResponse> contextPreservingActionListener =
+                new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
+        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
+            // we stash any context here since this is an internal execution and should not leak any existing context information
+            threadContext.markAsSystemContext();
+
+            final XPackInfoRequest request = new XPackInfoRequest();
+            request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE));
+            try {
+                client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, contextPreservingActionListener);
+            } catch (final Exception e) {
+                contextPreservingActionListener.onFailure(e);
+            }
+        }
+    }
+
+    /**
+     * Predicate to test if the index name represents the name of a remote index.
+     *
+     * @param index the index name
+     * @return true if the collection of indices contains a remote index, otherwise false
+     */
+    public static boolean isRemoteIndex(final String index) {
+        return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1;
+    }
+
+    /**
+     * Predicate to test if the collection of index names contains any that represent the name of a remote index.
+     *
+     * @param indices the collection of index names
+     * @return true if the collection of index names contains a name that represents a remote index, otherwise false
+     */
+    public static boolean containsRemoteIndex(final List<String> indices) {
+        return indices.stream().anyMatch(RemoteClusterLicenseChecker::isRemoteIndex);
+    }
+
+    /**
+     * Filters the collection of index names for names that represent a remote index. Remote index names are of the form
+     * {@code cluster_name:index_name}.
+     *
+     * @param indices the collection of index names
+     * @return list of index names that represent remote index names
+     */
+    public static List<String> remoteIndices(final List<String> indices) {
+        return indices.stream().filter(RemoteClusterLicenseChecker::isRemoteIndex).collect(Collectors.toList());
+    }
+
+    /**
+     * Extract the list of remote cluster aliases from the list of index names. Remote index names are of the form
+     * {@code cluster_alias:index_name} and the cluster_alias is extracted for each index name that represents a remote index.
+     *
+     * @param indices the collection of index names
+     * @return the remote cluster names
+     */
+    public static List<String> remoteClusterAliases(final List<String> indices) {
+        return indices.stream()
+                .filter(RemoteClusterLicenseChecker::isRemoteIndex)
+                .map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR)))
+                .distinct()
+                .collect(Collectors.toList());
+    }
+
+    /**
+     * Constructs an error message for license incompatibility.
+     *
+     * @param feature                  the name of the feature that initiated the remote cluster license check.
+     * @param remoteClusterLicenseInfo the remote cluster license info of the cluster that failed the license check
+     * @return an error message representing license incompatibility
+     */
+    public static String buildErrorMessage(
+            final String feature,
+            final RemoteClusterLicenseInfo remoteClusterLicenseInfo,
+            final Predicate<XPackInfoResponse.LicenseInfo> predicate) {
+        final StringBuilder error = new StringBuilder();
+        if (remoteClusterLicenseInfo.licenseInfo().getStatus() != LicenseStatus.ACTIVE) {
+            error.append(String.format(Locale.ROOT, "the license on cluster [%s] is not active", remoteClusterLicenseInfo.clusterAlias()));
+        } else {
+            assert predicate.test(remoteClusterLicenseInfo.licenseInfo()) == false : "license must be incompatible to build error message";
+            final String message = String.format(
+                    Locale.ROOT,
+                    "the license mode [%s] on cluster [%s] does not enable [%s]",
+                    License.OperationMode.resolve(remoteClusterLicenseInfo.licenseInfo().getMode()),
+                    remoteClusterLicenseInfo.clusterAlias(),
+                    feature);
+            error.append(message);
+        }
+
+        return error.toString();
+    }
+
+}

+ 414 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java

@@ -0,0 +1,414 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.license;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.protocol.xpack.XPackInfoResponse;
+import org.elasticsearch.protocol.xpack.license.LicenseStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.action.XPackInfoAction;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasToString;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.argThat;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public final class RemoteClusterLicenseCheckerTests extends ESTestCase {
+
+    public void testIsNotRemoteIndex() {
+        assertFalse(RemoteClusterLicenseChecker.isRemoteIndex("local-index"));
+    }
+
+    public void testIsRemoteIndex() {
+        assertTrue(RemoteClusterLicenseChecker.isRemoteIndex("remote-cluster:remote-index"));
+    }
+
+    public void testNoRemoteIndex() {
+        final List<String> indices = Arrays.asList("local-index1", "local-index2");
+        assertFalse(RemoteClusterLicenseChecker.containsRemoteIndex(indices));
+    }
+
+    public void testRemoteIndex() {
+        final List<String> indices = Arrays.asList("local-index", "remote-cluster:remote-index");
+        assertTrue(RemoteClusterLicenseChecker.containsRemoteIndex(indices));
+    }
+
+    public void testNoRemoteIndices() {
+        final List<String> indices = Collections.singletonList("local-index");
+        assertThat(RemoteClusterLicenseChecker.remoteIndices(indices), is(empty()));
+    }
+
+    public void testRemoteIndices() {
+        final List<String> indices = Arrays.asList("local-index1", "remote-cluster1:index1", "local-index2", "remote-cluster2:index1");
+        assertThat(
+                RemoteClusterLicenseChecker.remoteIndices(indices),
+                containsInAnyOrder("remote-cluster1:index1", "remote-cluster2:index1"));
+    }
+
+    public void testNoRemoteClusterAliases() {
+        final List<String> indices = Arrays.asList("local-index1", "local-index2");
+        assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), empty());
+    }
+
+    public void testOneRemoteClusterAlias() {
+        final List<String> indices = Arrays.asList("local-index1", "remote-cluster1:remote-index1");
+        assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1"));
+    }
+
+    public void testMoreThanOneRemoteClusterAlias() {
+        final List<String> indices = Arrays.asList("remote-cluster1:remote-index1", "local-index1", "remote-cluster2:remote-index1");
+        assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1", "remote-cluster2"));
+    }
+
+    public void testDuplicateRemoteClusterAlias() {
+        final List<String> indices = Arrays.asList(
+                "remote-cluster1:remote-index1", "local-index1", "remote-cluster2:index1", "remote-cluster2:remote-index2");
+        assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1", "remote-cluster2"));
+    }
+
+    public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() {
+        final AtomicInteger index = new AtomicInteger();
+        final List<XPackInfoResponse> responses = new ArrayList<>();
+
+        final ThreadPool threadPool = createMockThreadPool();
+        final Client client = createMockClient(threadPool);
+        doAnswer(invocationMock -> {
+            @SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
+                    (ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
+            listener.onResponse(responses.get(index.getAndIncrement()));
+            return null;
+        }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
+
+        final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+
+        final RemoteClusterLicenseChecker licenseChecker =
+                new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+        final AtomicReference<RemoteClusterLicenseChecker.LicenseCheck> licenseCheck = new AtomicReference<>();
+
+        licenseChecker.checkRemoteClusterLicenses(
+                remoteClusterAliases,
+                doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
+
+                    @Override
+                    public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
+                        licenseCheck.set(response);
+                    }
+
+                    @Override
+                    public void onFailure(final Exception e) {
+                        fail(e.getMessage());
+                    }
+
+                }));
+
+        verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any());
+        assertNotNull(licenseCheck.get());
+        assertTrue(licenseCheck.get().isSuccess());
+    }
+
+    public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() {
+        final AtomicInteger index = new AtomicInteger();
+        final List<String> remoteClusterAliases = Arrays.asList("good", "cluster-with-basic-license", "good2");
+        final List<XPackInfoResponse> responses = new ArrayList<>();
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+
+        final ThreadPool threadPool = createMockThreadPool();
+        final Client client = createMockClient(threadPool);
+        doAnswer(invocationMock -> {
+            @SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
+                    (ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
+            listener.onResponse(responses.get(index.getAndIncrement()));
+            return null;
+        }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
+
+        final RemoteClusterLicenseChecker licenseChecker =
+                new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+        final AtomicReference<RemoteClusterLicenseChecker.LicenseCheck> licenseCheck = new AtomicReference<>();
+
+        licenseChecker.checkRemoteClusterLicenses(
+                remoteClusterAliases,
+                doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
+
+                    @Override
+                    public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
+                        licenseCheck.set(response);
+                    }
+
+                    @Override
+                    public void onFailure(final Exception e) {
+                        fail(e.getMessage());
+                    }
+
+                }));
+
+        verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any());
+        assertNotNull(licenseCheck.get());
+        assertFalse(licenseCheck.get().isSuccess());
+        assertThat(licenseCheck.get().remoteClusterLicenseInfo().clusterAlias(), equalTo("cluster-with-basic-license"));
+        assertThat(licenseCheck.get().remoteClusterLicenseInfo().licenseInfo().getType(), equalTo("BASIC"));
+    }
+
+    public void testCheckRemoteClusterLicencesGivenNonExistentCluster() {
+        final AtomicInteger index = new AtomicInteger();
+        final List<XPackInfoResponse> responses = new ArrayList<>();
+
+        final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
+        final String failingClusterAlias = randomFrom(remoteClusterAliases);
+        final ThreadPool threadPool = createMockThreadPool();
+        final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, failingClusterAlias);
+        doAnswer(invocationMock -> {
+            @SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
+                    (ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
+            listener.onResponse(responses.get(index.getAndIncrement()));
+            return null;
+        }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
+
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+
+        final RemoteClusterLicenseChecker licenseChecker =
+                new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+        final AtomicReference<Exception> exception = new AtomicReference<>();
+
+        licenseChecker.checkRemoteClusterLicenses(
+                remoteClusterAliases,
+                doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
+
+                    @Override
+                    public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
+                        fail();
+                    }
+
+                    @Override
+                    public void onFailure(final Exception e) {
+                        exception.set(e);
+                    }
+
+                }));
+
+        assertNotNull(exception.get());
+        assertThat(exception.get(), instanceOf(ElasticsearchException.class));
+        assertThat(exception.get().getMessage(), equalTo("could not determine the license type for cluster [" + failingClusterAlias + "]"));
+        assertNotNull(exception.get().getCause());
+        assertThat(exception.get().getCause(), instanceOf(IllegalArgumentException.class));
+    }
+
+    public void testRemoteClusterLicenseCallUsesSystemContext() throws InterruptedException {
+        final ThreadPool threadPool = new TestThreadPool(getTestName());
+
+        try {
+            final Client client = createMockClient(threadPool);
+            doAnswer(invocationMock -> {
+                assertTrue(threadPool.getThreadContext().isSystemContext());
+                @SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
+                        (ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
+                listener.onResponse(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+                return null;
+            }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
+
+            final RemoteClusterLicenseChecker licenseChecker =
+                    new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+
+            final List<String> remoteClusterAliases = Collections.singletonList("valid");
+            licenseChecker.checkRemoteClusterLicenses(
+                    remoteClusterAliases, doubleInvocationProtectingListener(ActionListener.wrap(() -> {})));
+
+            verify(client, times(1)).execute(same(XPackInfoAction.INSTANCE), any(), any());
+        } finally {
+            terminate(threadPool);
+        }
+    }
+
+    public void testListenerIsExecutedWithCallingContext() throws InterruptedException {
+        final AtomicInteger index = new AtomicInteger();
+        final List<XPackInfoResponse> responses = new ArrayList<>();
+
+        final ThreadPool threadPool = new TestThreadPool(getTestName());
+
+        try {
+            final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
+            final Client client;
+            final boolean failure = randomBoolean();
+            if (failure) {
+                client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, randomFrom(remoteClusterAliases));
+            } else {
+                client = createMockClient(threadPool);
+            }
+            doAnswer(invocationMock -> {
+                @SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
+                        (ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
+                listener.onResponse(responses.get(index.getAndIncrement()));
+                return null;
+            }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
+
+            responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+            responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+            responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
+
+            final RemoteClusterLicenseChecker licenseChecker =
+                    new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+
+            final AtomicBoolean listenerInvoked = new AtomicBoolean();
+            threadPool.getThreadContext().putHeader("key", "value");
+            licenseChecker.checkRemoteClusterLicenses(
+                    remoteClusterAliases,
+                    doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
+
+                        @Override
+                        public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
+                            if (failure) {
+                                fail();
+                            }
+                            assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
+                            assertFalse(threadPool.getThreadContext().isSystemContext());
+                            listenerInvoked.set(true);
+                        }
+
+                        @Override
+                        public void onFailure(final Exception e) {
+                            if (failure == false) {
+                                fail();
+                            }
+                            assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
+                            assertFalse(threadPool.getThreadContext().isSystemContext());
+                            listenerInvoked.set(true);
+                        }
+
+                    }));
+
+            assertTrue(listenerInvoked.get());
+        } finally {
+            terminate(threadPool);
+        }
+    }
+
+    public void testBuildErrorMessageForActiveCompatibleLicense() {
+        final XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse();
+        final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info =
+                new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence);
+        final AssertionError e = expectThrows(
+                AssertionError.class,
+                () -> RemoteClusterLicenseChecker.buildErrorMessage("", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial));
+        assertThat(e, hasToString(containsString("license must be incompatible to build error message")));
+    }
+
+    public void testBuildErrorMessageForIncompatibleLicense() {
+        final XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse();
+        final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info =
+                new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense);
+        assertThat(
+                RemoteClusterLicenseChecker.buildErrorMessage("Feature", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial),
+                equalTo("the license mode [BASIC] on cluster [basic-cluster] does not enable [Feature]"));
+    }
+
+    public void testBuildErrorMessageForInactiveLicense() {
+        final XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse();
+        final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info =
+                new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense);
+        assertThat(
+                RemoteClusterLicenseChecker.buildErrorMessage("Feature", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial),
+                equalTo("the license on cluster [expired-cluster] is not active"));
+    }
+
+    private ActionListener<RemoteClusterLicenseChecker.LicenseCheck> doubleInvocationProtectingListener(
+            final ActionListener<RemoteClusterLicenseChecker.LicenseCheck> listener) {
+        final AtomicBoolean listenerInvoked = new AtomicBoolean();
+        return new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
+
+            @Override
+            public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
+                if (listenerInvoked.compareAndSet(false, true) == false) {
+                    fail("listener invoked twice");
+                }
+                listener.onResponse(response);
+            }
+
+            @Override
+            public void onFailure(final Exception e) {
+                if (listenerInvoked.compareAndSet(false, true) == false) {
+                    fail("listener invoked twice");
+                }
+                listener.onFailure(e);
+            }
+
+        };
+    }
+
+    private ThreadPool createMockThreadPool() {
+        final ThreadPool threadPool = mock(ThreadPool.class);
+        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
+        return threadPool;
+    }
+
+    private Client createMockClient(final ThreadPool threadPool) {
+        return createMockClient(threadPool, client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client));
+    }
+
+    private Client createMockClientThatThrowsOnGetRemoteClusterClient(final ThreadPool threadPool, final String clusterAlias) {
+        return createMockClient(
+                threadPool,
+                client -> {
+                    when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException());
+                    when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client);
+                });
+    }
+
+    private Client createMockClient(final ThreadPool threadPool, final Consumer<Client> finish) {
+        final Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+        finish.accept(client);
+        return client;
+    }
+
+    private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() {
+        return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong());
+    }
+
+    private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() {
+        return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", LicenseStatus.ACTIVE, randomNonNegativeLong());
+    }
+
+    private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() {
+        return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.EXPIRED, randomNonNegativeLong());
+    }
+
+}

+ 37 - 21
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java

@@ -23,6 +23,7 @@ import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.license.RemoteClusterLicenseChecker;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
 import org.elasticsearch.persistent.PersistentTaskState;
@@ -46,10 +47,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedNodeSelector;
-import org.elasticsearch.xpack.ml.datafeed.MlRemoteLicenseChecker;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory;
 
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.function.Predicate;
 
@@ -141,19 +142,22 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             DatafeedConfig datafeed = mlMetadata.getDatafeed(params.getDatafeedId());
             Job job = mlMetadata.getJobs().get(datafeed.getJobId());
 
-            if (MlRemoteLicenseChecker.containsRemoteIndex(datafeed.getIndices())) {
-                MlRemoteLicenseChecker remoteLicenseChecker = new MlRemoteLicenseChecker(client);
-                remoteLicenseChecker.checkRemoteClusterLicenses(MlRemoteLicenseChecker.remoteClusterNames(datafeed.getIndices()),
+            if (RemoteClusterLicenseChecker.containsRemoteIndex(datafeed.getIndices())) {
+                final RemoteClusterLicenseChecker remoteClusterLicenseChecker =
+                        new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
+                remoteClusterLicenseChecker.checkRemoteClusterLicenses(
+                        RemoteClusterLicenseChecker.remoteClusterAliases(datafeed.getIndices()),
                         ActionListener.wrap(
                                 response -> {
-                                    if (response.isViolated()) {
+                                    if (response.isSuccess() == false) {
                                         listener.onFailure(createUnlicensedError(datafeed.getId(), response));
                                     } else {
                                         createDataExtractor(job, datafeed, params, waitForTaskListener);
                                     }
                                 },
-                                e -> listener.onFailure(createUnknownLicenseError(datafeed.getId(),
-                                        MlRemoteLicenseChecker.remoteIndices(datafeed.getIndices()), e))
+                                e -> listener.onFailure(
+                                        createUnknownLicenseError(
+                                                datafeed.getId(), RemoteClusterLicenseChecker.remoteIndices(datafeed.getIndices()), e))
                         ));
             } else {
                 createDataExtractor(job, datafeed, params, waitForTaskListener);
@@ -232,23 +236,35 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
         );
     }
 
-    private ElasticsearchStatusException createUnlicensedError(String datafeedId,
-                                                               MlRemoteLicenseChecker.LicenseViolation licenseViolation) {
-        String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use "
-                + "indices on a remote cluster [" + licenseViolation.get().getClusterName()
-                + "] that is not licensed for Machine Learning. "
-                + MlRemoteLicenseChecker.buildErrorMessage(licenseViolation.get());
-
+    private ElasticsearchStatusException createUnlicensedError(
+            final String datafeedId, final RemoteClusterLicenseChecker.LicenseCheck licenseCheck) {
+        final String message = String.format(
+                Locale.ROOT,
+                "cannot start datafeed [%s] as it is configured to use indices on remote cluster [%s] that is not licensed for ml; %s",
+                datafeedId,
+                licenseCheck.remoteClusterLicenseInfo().clusterAlias(),
+                RemoteClusterLicenseChecker.buildErrorMessage(
+                        "ml",
+                        licenseCheck.remoteClusterLicenseInfo(),
+                        RemoteClusterLicenseChecker::isLicensePlatinumOrTrial));
         return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
     }
 
-    private ElasticsearchStatusException createUnknownLicenseError(String datafeedId, List<String> remoteIndices,
-                                                                   Exception cause) {
-        String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use"
-                + " indices on a remote cluster " + remoteIndices
-                + " but the license type could not be verified";
-
-        return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, new Exception(cause.getMessage()));
+    private ElasticsearchStatusException createUnknownLicenseError(
+            final String datafeedId, final List<String> remoteIndices, final Exception cause) {
+        final int numberOfRemoteClusters = RemoteClusterLicenseChecker.remoteClusterAliases(remoteIndices).size();
+        assert numberOfRemoteClusters > 0;
+        final String remoteClusterQualifier = numberOfRemoteClusters == 1 ? "a remote cluster" : "remote clusters";
+        final String licenseTypeQualifier = numberOfRemoteClusters == 1 ? "" : "s";
+        final String message = String.format(
+                Locale.ROOT,
+                "cannot start datafeed [%s] as it uses indices on %s %s but the license type%s could not be verified",
+                datafeedId,
+                remoteClusterQualifier,
+                remoteIndices,
+                licenseTypeQualifier);
+
+        return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, cause);
     }
 
     public static class StartDatafeedPersistentTasksExecutor extends PersistentTasksExecutor<StartDatafeedAction.DatafeedParams> {

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java

@@ -12,6 +12,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.routing.IndexRoutingTable;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.logging.Loggers;
+import org.elasticsearch.license.RemoteClusterLicenseChecker;
 import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlTasks;
@@ -92,7 +93,7 @@ public class DatafeedNodeSelector {
         List<String> indices = datafeed.getIndices();
         for (String index : indices) {
 
-            if (MlRemoteLicenseChecker.isRemoteIndex(index)) {
+            if (RemoteClusterLicenseChecker.isRemoteIndex(index)) {
                 // We cannot verify remote indices
                 continue;
             }

+ 0 - 193
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java

@@ -1,193 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-
-package org.elasticsearch.xpack.ml.datafeed;
-
-import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.client.Client;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.license.License;
-import org.elasticsearch.protocol.xpack.XPackInfoRequest;
-import org.elasticsearch.protocol.xpack.XPackInfoResponse;
-import org.elasticsearch.protocol.xpack.license.LicenseStatus;
-import org.elasticsearch.transport.ActionNotFoundTransportException;
-import org.elasticsearch.transport.RemoteClusterAware;
-import org.elasticsearch.xpack.core.action.XPackInfoAction;
-
-import java.util.EnumSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.stream.Collectors;
-
-/**
- * ML datafeeds can use cross cluster search to access data in a remote cluster.
- * The remote cluster should be licenced for ML this class performs that check
- * using the _xpack (info) endpoint.
- */
-public class MlRemoteLicenseChecker {
-
-    private final Client client;
-
-    public static class RemoteClusterLicenseInfo {
-        private final String clusterName;
-        private final XPackInfoResponse.LicenseInfo licenseInfo;
-
-        RemoteClusterLicenseInfo(String clusterName, XPackInfoResponse.LicenseInfo licenseInfo) {
-            this.clusterName = clusterName;
-            this.licenseInfo = licenseInfo;
-        }
-
-        public String getClusterName() {
-            return clusterName;
-        }
-
-        public XPackInfoResponse.LicenseInfo getLicenseInfo() {
-            return licenseInfo;
-        }
-    }
-
-    public class LicenseViolation {
-        private final RemoteClusterLicenseInfo licenseInfo;
-
-        private LicenseViolation(@Nullable RemoteClusterLicenseInfo licenseInfo) {
-            this.licenseInfo = licenseInfo;
-        }
-
-        public boolean isViolated() {
-            return licenseInfo != null;
-        }
-
-        public RemoteClusterLicenseInfo get() {
-            return licenseInfo;
-        }
-    }
-
-    public MlRemoteLicenseChecker(Client client) {
-        this.client = client;
-    }
-
-    /**
-     * Check each cluster is licensed for ML.
-     * This function evaluates lazily and will terminate when the first cluster
-     * that is not licensed is found or an error occurs.
-     *
-     * @param clusterNames List of remote cluster names
-     * @param listener Response listener
-     */
-    public void checkRemoteClusterLicenses(List<String> clusterNames, ActionListener<LicenseViolation> listener) {
-        final Iterator<String> itr = clusterNames.iterator();
-        if (itr.hasNext() == false) {
-            listener.onResponse(new LicenseViolation(null));
-            return;
-        }
-
-        final AtomicReference<String> clusterName = new AtomicReference<>(itr.next());
-
-        ActionListener<XPackInfoResponse> infoListener = new ActionListener<XPackInfoResponse>() {
-            @Override
-            public void onResponse(XPackInfoResponse xPackInfoResponse) {
-                if (licenseSupportsML(xPackInfoResponse.getLicenseInfo()) == false) {
-                    listener.onResponse(new LicenseViolation(
-                            new RemoteClusterLicenseInfo(clusterName.get(), xPackInfoResponse.getLicenseInfo())));
-                    return;
-                }
-
-                if (itr.hasNext()) {
-                    clusterName.set(itr.next());
-                    remoteClusterLicense(clusterName.get(), this);
-                } else {
-                    listener.onResponse(new LicenseViolation(null));
-                }
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                String message = "Could not determine the X-Pack licence type for cluster [" + clusterName.get() + "]";
-                if (e instanceof ActionNotFoundTransportException) {
-                    // This is likely to be because x-pack is not installed in the target cluster
-                    message += ". Is X-Pack installed on the target cluster?";
-                }
-                listener.onFailure(new ElasticsearchException(message, e));
-            }
-        };
-
-        remoteClusterLicense(clusterName.get(), infoListener);
-    }
-
-    private void remoteClusterLicense(String clusterName, ActionListener<XPackInfoResponse> listener) {
-        Client remoteClusterClient = client.getRemoteClusterClient(clusterName);
-        ThreadContext threadContext = remoteClusterClient.threadPool().getThreadContext();
-        try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
-            // we stash any context here since this is an internal execution and should not leak any
-            // existing context information.
-            threadContext.markAsSystemContext();
-
-            XPackInfoRequest request = new XPackInfoRequest();
-            request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE));
-            remoteClusterClient.execute(XPackInfoAction.INSTANCE, request, listener);
-        }
-    }
-
-    static boolean licenseSupportsML(XPackInfoResponse.LicenseInfo licenseInfo) {
-        License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode());
-        return licenseInfo.getStatus() == LicenseStatus.ACTIVE &&
-                (mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL);
-    }
-
-    public static boolean isRemoteIndex(String index) {
-        return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1;
-    }
-
-    public static boolean containsRemoteIndex(List<String> indices) {
-        return indices.stream().anyMatch(MlRemoteLicenseChecker::isRemoteIndex);
-    }
-
-    /**
-     * Get any remote indices used in cross cluster search.
-     * Remote indices are of the form {@code cluster_name:index_name}
-     * @return List of remote cluster indices
-     */
-    public static List<String> remoteIndices(List<String> indices) {
-        return indices.stream().filter(MlRemoteLicenseChecker::isRemoteIndex).collect(Collectors.toList());
-    }
-
-    /**
-     * Extract the list of remote cluster names from the list of indices.
-     * @param indices List of indices. Remote cluster indices are prefixed
-     *                with {@code cluster-name:}
-     * @return Every cluster name found in {@code indices}
-     */
-    public static List<String> remoteClusterNames(List<String> indices) {
-        return indices.stream()
-                .filter(MlRemoteLicenseChecker::isRemoteIndex)
-                .map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR)))
-                .distinct()
-                .collect(Collectors.toList());
-    }
-
-    public static String buildErrorMessage(RemoteClusterLicenseInfo clusterLicenseInfo) {
-        StringBuilder error = new StringBuilder();
-        if (clusterLicenseInfo.licenseInfo.getStatus() != LicenseStatus.ACTIVE) {
-            error.append("The license on cluster [").append(clusterLicenseInfo.clusterName)
-                    .append("] is not active. ");
-        } else {
-            License.OperationMode mode = License.OperationMode.resolve(clusterLicenseInfo.licenseInfo.getMode());
-            if (mode != License.OperationMode.PLATINUM && mode != License.OperationMode.TRIAL) {
-                error.append("The license mode [").append(mode)
-                        .append("] on cluster [")
-                        .append(clusterLicenseInfo.clusterName)
-                        .append("] does not enable Machine Learning. ");
-            }
-        }
-
-        error.append(Strings.toString(clusterLicenseInfo.licenseInfo));
-        return error.toString();
-    }
-}

+ 2 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java

@@ -3,10 +3,12 @@
  * or more contributor license agreements. Licensed under the Elastic License;
  * you may not use this file except in compliance with the Elastic License.
  */
+
 package org.elasticsearch.xpack.ml.action;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
@@ -14,7 +16,6 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
-import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedManagerTests;
 

+ 0 - 199
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java

@@ -1,199 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License;
- * you may not use this file except in compliance with the Elastic License.
- */
-
-package org.elasticsearch.xpack.ml.datafeed;
-
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.client.Client;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.protocol.xpack.XPackInfoResponse;
-import org.elasticsearch.protocol.xpack.license.LicenseStatus;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.action.XPackInfoAction;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
-
-import static org.hamcrest.Matchers.contains;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.is;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Matchers.same;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class MlRemoteLicenseCheckerTests extends ESTestCase {
-
-    public void testIsRemoteIndex() {
-        List<String> indices = Arrays.asList("local-index1", "local-index2");
-        assertFalse(MlRemoteLicenseChecker.containsRemoteIndex(indices));
-        indices = Arrays.asList("local-index1", "remote-cluster:remote-index2");
-        assertTrue(MlRemoteLicenseChecker.containsRemoteIndex(indices));
-    }
-
-    public void testRemoteIndices() {
-        List<String> indices = Collections.singletonList("local-index");
-        assertThat(MlRemoteLicenseChecker.remoteIndices(indices), is(empty()));
-        indices = Arrays.asList("local-index", "remote-cluster:index1", "local-index2", "remote-cluster2:index1");
-        assertThat(MlRemoteLicenseChecker.remoteIndices(indices), containsInAnyOrder("remote-cluster:index1", "remote-cluster2:index1"));
-    }
-
-    public void testRemoteClusterNames() {
-        List<String> indices = Arrays.asList("local-index1", "local-index2");
-        assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), empty());
-        indices = Arrays.asList("local-index1", "remote-cluster1:remote-index2");
-        assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1"));
-        indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1");
-        assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2"));
-        indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1", "remote-cluster2:index2");
-        assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2"));
-    }
-
-    public void testLicenseSupportsML() {
-        XPackInfoResponse.LicenseInfo licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial",
-                LicenseStatus.ACTIVE, randomNonNegativeLong());
-        assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
-
-        licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", LicenseStatus.EXPIRED, randomNonNegativeLong());
-        assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
-
-        licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "GOLD", "GOLD", LicenseStatus.ACTIVE, randomNonNegativeLong());
-        assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
-
-        licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong());
-        assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
-    }
-
-    public void testCheckRemoteClusterLicenses_givenValidLicenses() {
-        final AtomicInteger index = new AtomicInteger(0);
-        final List<XPackInfoResponse> responses = new ArrayList<>();
-
-        Client client = createMockClient();
-        doAnswer(invocationMock -> {
-            @SuppressWarnings("raw_types")
-            ActionListener listener = (ActionListener) invocationMock.getArguments()[2];
-            listener.onResponse(responses.get(index.getAndIncrement()));
-            return null;
-        }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
-
-
-        List<String> remoteClusterNames = Arrays.asList("valid1", "valid2", "valid3");
-        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
-        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
-        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
-
-        MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client);
-        AtomicReference<MlRemoteLicenseChecker.LicenseViolation> licCheckResponse = new AtomicReference<>();
-
-        licenseChecker.checkRemoteClusterLicenses(remoteClusterNames,
-                new ActionListener<MlRemoteLicenseChecker.LicenseViolation>() {
-            @Override
-            public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) {
-                licCheckResponse.set(response);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                fail(e.getMessage());
-            }
-        });
-
-        verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any());
-        assertNotNull(licCheckResponse.get());
-        assertFalse(licCheckResponse.get().isViolated());
-        assertNull(licCheckResponse.get().get());
-    }
-
-    public void testCheckRemoteClusterLicenses_givenInvalidLicense() {
-        final AtomicInteger index = new AtomicInteger(0);
-        List<String> remoteClusterNames = Arrays.asList("good", "cluster-with-basic-license", "good2");
-        final List<XPackInfoResponse> responses = new ArrayList<>();
-        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
-        responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null));
-        responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
-
-        Client client = createMockClient();
-        doAnswer(invocationMock -> {
-            @SuppressWarnings("raw_types")
-            ActionListener listener = (ActionListener) invocationMock.getArguments()[2];
-            listener.onResponse(responses.get(index.getAndIncrement()));
-            return null;
-        }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
-
-        MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client);
-        AtomicReference<MlRemoteLicenseChecker.LicenseViolation> licCheckResponse = new AtomicReference<>();
-
-        licenseChecker.checkRemoteClusterLicenses(remoteClusterNames,
-                new ActionListener<MlRemoteLicenseChecker.LicenseViolation>() {
-            @Override
-            public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) {
-                licCheckResponse.set(response);
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                fail(e.getMessage());
-            }
-        });
-
-        verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any());
-        assertNotNull(licCheckResponse.get());
-        assertTrue(licCheckResponse.get().isViolated());
-        assertEquals("cluster-with-basic-license", licCheckResponse.get().get().getClusterName());
-        assertEquals("BASIC", licCheckResponse.get().get().getLicenseInfo().getType());
-    }
-
-    public void testBuildErrorMessage() {
-        XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse();
-        MlRemoteLicenseChecker.RemoteClusterLicenseInfo info =
-                new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence);
-        assertEquals(Strings.toString(platinumLicence), MlRemoteLicenseChecker.buildErrorMessage(info));
-
-        XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse();
-        info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense);
-        String expected = "The license mode [BASIC] on cluster [basic-cluster] does not enable Machine Learning. "
-                + Strings.toString(basicLicense);
-        assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info));
-
-        XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse();
-        info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense);
-        expected = "The license on cluster [expired-cluster] is not active. " + Strings.toString(expiredLicense);
-        assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info));
-    }
-
-    private Client createMockClient() {
-        Client client = mock(Client.class);
-        ThreadPool threadPool = mock(ThreadPool.class);
-        when(client.threadPool()).thenReturn(threadPool);
-        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
-        when(client.getRemoteClusterClient(anyString())).thenReturn(client);
-        return client;
-    }
-
-    private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() {
-        return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong());
-    }
-
-    private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() {
-        return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", LicenseStatus.ACTIVE, randomNonNegativeLong());
-    }
-
-    private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() {
-        return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.EXPIRED, randomNonNegativeLong());
-    }
-}