Browse Source

Add TransportHealthNodeAction (#89127)

Mary Gouseti 3 years ago
parent
commit
399a8ac283

+ 8 - 1
server/src/main/java/org/elasticsearch/ElasticsearchException.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.logging.LoggerMessageFormat;
 import org.elasticsearch.core.CheckedFunction;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
+import org.elasticsearch.health.node.action.HealthNodeNotDiscoveredException;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.rest.RestStatus;
@@ -721,7 +722,7 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
     /**
      * This is the list of Exceptions Elasticsearch can throw over the wire or save into a corruption marker. Each value in the enum is a
      * single exception tying the Class to an id for use of the encode side and the id back to a constructor for use on the decode side. As
-     * such its ok if the exceptions to change names so long as their constructor can still read the exception. Each exception is listed
+     * such it's ok for the exceptions to change names so long as their constructor can still read the exception. Each exception is listed
      * in id order below. If you want to remove an exception leave a tombstone comment and mark the id as null in
      * ExceptionSerializationTests.testIds.ids.
      */
@@ -1571,6 +1572,12 @@ public class ElasticsearchException extends RuntimeException implements ToXConte
             org.elasticsearch.snapshots.SnapshotNameAlreadyInUseException::new,
             165,
             Version.V_8_2_0
+        ),
+        HEALTH_NODE_NOT_DISCOVERED_EXCEPTION(
+            HealthNodeNotDiscoveredException.class,
+            HealthNodeNotDiscoveredException::new,
+            166,
+            Version.V_8_5_0
         );
 
         final Class<? extends ElasticsearchException> exceptionClass;

+ 35 - 0
server/src/main/java/org/elasticsearch/health/node/action/HealthNodeNotDiscoveredException.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.health.node.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.rest.RestStatus;
+
+import java.io.IOException;
+
+/**
+ * Signals that the cluster currently has no selected health node, i.e. the health node
+ * persistent task is not assigned. Surfaces to REST callers as
+ * {@link RestStatus#SERVICE_UNAVAILABLE}.
+ */
+public class HealthNodeNotDiscoveredException extends ElasticsearchException {
+
+    /**
+     * Creates the exception with a human-readable explanation.
+     *
+     * @param message description of why no health node could be determined
+     */
+    public HealthNodeNotDiscoveredException(final String message) {
+        super(message);
+    }
+
+    /**
+     * Stream constructor; delegates reading entirely to the superclass.
+     *
+     * @param in stream the exception is read from
+     * @throws IOException if the superclass constructor fails to read from the stream
+     */
+    public HealthNodeNotDiscoveredException(final StreamInput in) throws IOException {
+        super(in);
+    }
+
+    /**
+     * @return always {@link RestStatus#SERVICE_UNAVAILABLE}
+     */
+    @Override
+    public RestStatus status() {
+        return RestStatus.SERVICE_UNAVAILABLE;
+    }
+}

+ 123 - 0
server/src/main/java/org/elasticsearch/health/node/action/TransportHealthNodeAction.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.health.node.action;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.health.node.selection.HealthNode;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportRequestOptions;
+import org.elasticsearch.transport.TransportService;
+
+import static org.elasticsearch.core.Strings.format;
+
+/**
+ * A base class for operations that need to be performed on the health node.
+ * <p>
+ * On execution the action looks up the health node in the current cluster state. If the local node is
+ * the health node, {@link #healthOperation} is run on the configured executor; otherwise the request is
+ * forwarded to the health node through the {@link TransportService}. If no health node can be found the
+ * listener is failed with a {@link HealthNodeNotDiscoveredException}.
+ */
+public abstract class TransportHealthNodeAction<Request extends ActionRequest, Response extends ActionResponse> extends
+    HandledTransportAction<Request, Response> {
+
+    private static final Logger logger = LogManager.getLogger(TransportHealthNodeAction.class);
+
+    protected final TransportService transportService;
+    protected final ClusterService clusterService;
+    protected final ThreadPool threadPool;
+    /** Name of the thread pool executor on which {@link #healthOperation} is run. */
+    protected final String executor;
+
+    /** Reads the response that comes back when the request was forwarded to a remote health node. */
+    private final Writeable.Reader<Response> responseReader;
+
+    /**
+     * @param actionName       name under which the action is registered and under which requests are forwarded
+     * @param transportService used to forward requests to the health node
+     * @param clusterService   source of the cluster state in which the health node is looked up
+     * @param threadPool       provides the executor for the local health operation
+     * @param actionFilters    passed through to {@link HandledTransportAction}
+     * @param request          reader for incoming requests, passed through to {@link HandledTransportAction}
+     * @param response         reader for responses received from a remote health node
+     * @param executor         name of the executor on which {@link #healthOperation} runs
+     */
+    protected TransportHealthNodeAction(
+        String actionName,
+        TransportService transportService,
+        ClusterService clusterService,
+        ThreadPool threadPool,
+        ActionFilters actionFilters,
+        Writeable.Reader<Request> request,
+        Writeable.Reader<Response> response,
+        String executor
+    ) {
+        super(actionName, true, transportService, actionFilters, request);
+        this.transportService = transportService;
+        this.clusterService = clusterService;
+        this.threadPool = threadPool;
+        this.executor = executor;
+        this.responseReader = response;
+    }
+
+    /**
+     * The operation to perform on the health node. Only invoked when the local node is the health node.
+     *
+     * @param task     the task of this action, may be {@code null}
+     * @param request  the request being processed
+     * @param state    the cluster state in which the local node was found to be the health node
+     * @param listener must be notified with the outcome of the operation
+     * @throws Exception any exception thrown is reported to {@code listener} as a failure
+     */
+    protected abstract void healthOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener)
+        throws Exception;
+
+    /**
+     * Routes the request: fails it when the task is cancelled or no health node is known, executes it locally when
+     * the local node is the health node, and forwards it to the health node otherwise.
+     */
+    @Override
+    protected void doExecute(Task task, final Request request, ActionListener<Response> listener) {
+        // Take a single snapshot so that the version reported in the trace log is the one used for routing.
+        final ClusterState clusterState = clusterService.state();
+        logger.trace("starting to process request [{}] with cluster state version [{}]", request, clusterState.version());
+        if (isTaskCancelled(task)) {
+            listener.onFailure(new TaskCancelledException("Task was cancelled"));
+            return;
+        }
+        try {
+            DiscoveryNode healthNode = HealthNode.findHealthNode(clusterState);
+            DiscoveryNode localNode = clusterState.nodes().getLocalNode();
+            if (healthNode == null) {
+                listener.onFailure(new HealthNodeNotDiscoveredException("Health node was null"));
+            } else if (localNode.getId().equals(healthNode.getId())) {
+                executeLocally(task, request, clusterState, listener);
+            } else {
+                forwardToHealthNode(task, request, healthNode, listener);
+            }
+        } catch (Exception e) {
+            logger.trace(() -> format("Failed to route/execute health node action %s", actionName), e);
+            listener.onFailure(e);
+        }
+    }
+
+    /** Runs {@link #healthOperation} on the configured executor, checking for cancellation again right before it starts. */
+    private void executeLocally(Task task, Request request, ClusterState clusterState, ActionListener<Response> listener) {
+        threadPool.executor(executor).execute(() -> {
+            try {
+                if (isTaskCancelled(task)) {
+                    listener.onFailure(new TaskCancelledException("Task was cancelled"));
+                } else {
+                    healthOperation(task, request, clusterState, listener);
+                }
+            } catch (Exception e) {
+                listener.onFailure(e);
+            }
+        });
+    }
+
+    /** Sends the request to the health node, as a child request of {@code task} when a task is present. */
+    private void forwardToHealthNode(Task task, Request request, DiscoveryNode healthNode, ActionListener<Response> listener) {
+        logger.trace("forwarding request [{}] to health node [{}]", actionName, healthNode);
+        ActionListenerResponseHandler<Response> handler = new ActionListenerResponseHandler<>(listener, responseReader) {
+            @Override
+            public void handleException(final TransportException exception) {
+                logger.trace(
+                    () -> format("failure when forwarding request [%s] to health node [%s]", actionName, healthNode),
+                    exception
+                );
+                listener.onFailure(exception);
+            }
+        };
+        if (task != null) {
+            transportService.sendChildRequest(healthNode, actionName, request, task, TransportRequestOptions.EMPTY, handler);
+        } else {
+            transportService.sendRequest(healthNode, actionName, request, handler);
+        }
+    }
+
+    /** @return {@code true} only when {@code task} is a {@link CancellableTask} that has been cancelled */
+    private boolean isTaskCancelled(Task task) {
+        return (task instanceof CancellableTask t) && t.isCancelled();
+    }
+}

+ 10 - 0
server/src/main/java/org/elasticsearch/health/node/selection/HealthNode.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.health.node.selection;
 
 import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.persistent.AllocatedPersistentTask;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
@@ -43,4 +44,13 @@ public class HealthNode extends AllocatedPersistentTask {
         PersistentTasksCustomMetadata taskMetadata = clusterState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
         return taskMetadata == null ? null : taskMetadata.getTask(TASK_NAME);
     }
+
+    @Nullable
+    public static DiscoveryNode findHealthNode(ClusterState clusterState) {
+        PersistentTasksCustomMetadata.PersistentTask<?> task = findTask(clusterState);
+        if (task == null || task.isAssigned() == false) {
+            return null;
+        }
+        return clusterState.nodes().get(task.getAssignment().getExecutorNode());
+    }
 }

+ 3 - 3
server/src/main/java/org/elasticsearch/health/node/selection/HealthNodeTaskParams.java

@@ -23,15 +23,15 @@ import static org.elasticsearch.health.node.selection.HealthNode.TASK_NAME;
 /**
  * Encapsulates the parameters needed to start the health node task, currently no parameters are required.
  */
-class HealthNodeTaskParams implements PersistentTaskParams {
+public class HealthNodeTaskParams implements PersistentTaskParams {
 
-    private static final HealthNodeTaskParams INSTANCE = new HealthNodeTaskParams();
+    public static final HealthNodeTaskParams INSTANCE = new HealthNodeTaskParams();
 
     public static final ObjectParser<HealthNodeTaskParams, Void> PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE);
 
     HealthNodeTaskParams() {}
 
-    HealthNodeTaskParams(StreamInput in) {}
+    HealthNodeTaskParams(StreamInput ignored) {}
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {

+ 2 - 0
server/src/test/java/org/elasticsearch/ExceptionSerializationTests.java

@@ -47,6 +47,7 @@ import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.PathUtils;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.env.ShardLockObtainFailedException;
+import org.elasticsearch.health.node.action.HealthNodeNotDiscoveredException;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.engine.RecoveryEngineException;
 import org.elasticsearch.index.query.QueryShardException;
@@ -829,6 +830,7 @@ public class ExceptionSerializationTests extends ESTestCase {
         ids.put(163, RepositoryConflictException.class);
         ids.put(164, VersionConflictException.class);
         ids.put(165, SnapshotNameAlreadyInUseException.class);
+        ids.put(166, HealthNodeNotDiscoveredException.class);
 
         Map<Class<? extends ElasticsearchException>, Integer> reverse = new HashMap<>();
         for (Map.Entry<Integer, Class<? extends ElasticsearchException>> entry : ids.entrySet()) {

+ 378 - 0
server/src/test/java/org/elasticsearch/health/node/action/TransportHealthNodeActionTests.java

@@ -0,0 +1,378 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.health.node.action;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ActionTestUtils;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.support.ThreadedActionListener;
+import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.tasks.TaskManager;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.transport.CapturingTransport;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
+import static org.elasticsearch.test.ClusterServiceUtils.setState;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+
+/**
+ * Tests the routing performed by {@link TransportHealthNodeAction}: local execution, forwarding to a remote
+ * health node, the case where no health node exists, a failing health operation and task cancellation.
+ */
+public class TransportHealthNodeActionTests extends ESTestCase {
+    private static ThreadPool threadPool;
+
+    private ClusterService clusterService;
+    private TransportService transportService;
+    private CapturingTransport transport;
+    private DiscoveryNode localNode;
+    private DiscoveryNode remoteNode;
+    private DiscoveryNode[] allNodes;
+    private TaskManager taskManager;
+
+    @BeforeClass
+    public static void beforeClass() {
+        threadPool = new TestThreadPool("TransportHealthNodeActionTests");
+    }
+
+    /** Creates a capturing transport, a started transport service and a local/remote node pair for each test. */
+    @Before
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+        taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
+        transport = new CapturingTransport();
+        clusterService = createClusterService(threadPool);
+        transportService = transport.createTransportService(
+            clusterService.getSettings(),
+            threadPool,
+            TransportService.NOOP_TRANSPORT_INTERCEPTOR,
+            x -> clusterService.localNode(),
+            null,
+            Collections.emptySet()
+        );
+        transportService.start();
+        transportService.acceptIncomingRequests();
+        localNode = new DiscoveryNode(
+            "local_node",
+            buildNewFakeTransportAddress(),
+            Collections.emptyMap(),
+            Set.of(DiscoveryNodeRole.MASTER_ROLE, DiscoveryNodeRole.DATA_ROLE),
+            Version.CURRENT
+        );
+        remoteNode = new DiscoveryNode(
+            "remote_node",
+            buildNewFakeTransportAddress(),
+            Collections.emptyMap(),
+            Set.of(DiscoveryNodeRole.MASTER_ROLE, DiscoveryNodeRole.DATA_ROLE),
+            Version.CURRENT
+        );
+        allNodes = new DiscoveryNode[] { localNode, remoteNode };
+    }
+
+    @After
+    @Override
+    public void tearDown() throws Exception {
+        super.tearDown();
+        clusterService.close();
+        transportService.close();
+    }
+
+    @AfterClass
+    public static void afterClass() {
+        ThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
+        threadPool = null;
+    }
+
+    /** Minimal request whose task is a {@link CancellableTask}, so cancellation can be exercised. */
+    public static class Request extends ActionRequest {
+
+        Request() {}
+
+        Request(StreamInput in) throws IOException {
+            super(in);
+        }
+
+        @Override
+        public ActionRequestValidationException validate() {
+            return null;
+        }
+
+        @Override
+        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+            return new CancellableTask(id, type, action, "", parentTaskId, headers);
+        }
+    }
+
+    /** Response carrying a random identity so that two independently created responses do not compare equal by accident. */
+    static class Response extends ActionResponse {
+        private long identity = randomLong();
+
+        Response() {}
+
+        Response(StreamInput in) throws IOException {
+            super(in);
+            identity = in.readLong();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Response response = (Response) o;
+            return identity == response.identity;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(identity);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeLong(identity);
+        }
+    }
+
+    /** Test action whose health operation immediately answers with a fresh {@link Response}. */
+    class Action extends TransportHealthNodeAction<Request, Response> {
+        Action(String actionName, TransportService transportService, ClusterService clusterService, ThreadPool threadPool) {
+            this(actionName, transportService, clusterService, threadPool, ThreadPool.Names.SAME);
+        }
+
+        Action(
+            String actionName,
+            TransportService transportService,
+            ClusterService clusterService,
+            ThreadPool threadPool,
+            String executor
+        ) {
+            super(
+                actionName,
+                transportService,
+                clusterService,
+                threadPool,
+                new ActionFilters(new HashSet<>()),
+                Request::new,
+                Response::new,
+                executor
+            );
+        }
+
+        @Override
+        protected void doExecute(Task task, final Request request, ActionListener<Response> listener) {
+            // remove unneeded threading by wrapping listener with SAME to prevent super.doExecute from wrapping it with LISTENER
+            super.doExecute(task, request, new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.SAME, listener, false));
+        }
+
+        @Override
+        protected void healthOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) {
+            listener.onResponse(new Response());
+        }
+    }
+
+    /** Test action that blocks in {@code doExecute} until the given latch is released. */
+    class WaitForSignalAction extends Action {
+        private final CountDownLatch countDownLatch;
+
+        WaitForSignalAction(
+            String actionName,
+            TransportService transportService,
+            ClusterService clusterService,
+            ThreadPool threadPool,
+            CountDownLatch countDownLatch
+        ) {
+            super(actionName, transportService, clusterService, threadPool, ThreadPool.Names.SAME);
+            this.countDownLatch = countDownLatch;
+        }
+
+        @Override
+        protected void doExecute(Task task, final Request request, ActionListener<Response> listener) {
+            try {
+                countDownLatch.await();
+            } catch (InterruptedException e) {
+                // Restore the interrupt flag before failing so the interruption is not silently lost on this thread.
+                Thread.currentThread().interrupt();
+                fail("Something went wrong while waiting for the latch");
+            }
+            super.doExecute(task, request, listener);
+        }
+    }
+
+    /** Test action whose health operation throws instead of notifying the listener. */
+    class HealthOperationWithExceptionAction extends Action {
+
+        HealthOperationWithExceptionAction(
+            String actionName,
+            TransportService transportService,
+            ClusterService clusterService,
+            ThreadPool threadPool
+        ) {
+            super(actionName, transportService, clusterService, threadPool);
+        }
+
+        @Override
+        protected void healthOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) {
+            throw new RuntimeException("Simulated");
+        }
+    }
+
+    /** When the local node is the health node, the outcome of the health operation reaches the listener unchanged. */
+    public void testLocalHealthNode() throws ExecutionException, InterruptedException {
+        final boolean healthOperationFailure = randomBoolean();
+
+        Request request = new Request();
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+
+        final Exception exception = new Exception();
+        final Response response = new Response();
+
+        setState(clusterService, ClusterStateCreationUtils.state(localNode, localNode, localNode, allNodes));
+
+        ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool) {
+            @Override
+            protected void healthOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) {
+                if (healthOperationFailure) {
+                    listener.onFailure(exception);
+                } else {
+                    listener.onResponse(response);
+                }
+            }
+        }, null, request, listener);
+        assertTrue(listener.isDone());
+
+        if (healthOperationFailure) {
+            try {
+                listener.get();
+                fail("Expected exception but returned proper result");
+            } catch (ExecutionException ex) {
+                assertThat(ex.getCause(), equalTo(exception));
+            }
+        } else {
+            assertThat(listener.get(), equalTo(response));
+        }
+    }
+
+    /** Without a health node in the cluster state the listener fails with {@link HealthNodeNotDiscoveredException}. */
+    public void testHealthNodeNotAvailable() throws InterruptedException {
+        Request request = new Request();
+        setState(clusterService, ClusterStateCreationUtils.state(localNode, null, allNodes));
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool), null, request, listener);
+        assertTrue(listener.isDone());
+        try {
+            listener.get();
+            fail("HealthNodeNotDiscoveredException should be thrown");
+        } catch (ExecutionException ex) {
+            assertThat(ex.getCause(), instanceOf(HealthNodeNotDiscoveredException.class));
+        }
+    }
+
+    /** With a remote health node and no task, the request is sent to that node and its response is relayed. */
+    public void testDelegateToHealthNodeWithoutParentTask() throws ExecutionException, InterruptedException {
+        Request request = new Request();
+        setState(clusterService, ClusterStateCreationUtils.state(localNode, remoteNode, remoteNode, allNodes));
+
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool), null, request, listener);
+
+        assertThat(transport.capturedRequests().length, equalTo(1));
+        CapturingTransport.CapturedRequest capturedRequest = transport.capturedRequests()[0];
+        assertThat(capturedRequest.node(), equalTo(remoteNode));
+        assertThat(capturedRequest.request(), equalTo(request));
+        assertThat(capturedRequest.action(), equalTo("internal:testAction"));
+
+        Response response = new Response();
+        transport.handleResponse(capturedRequest.requestId(), response);
+        assertTrue(listener.isDone());
+        assertThat(listener.get(), equalTo(response));
+    }
+
+    /** Same as the test without a parent task, but the action is executed with a registered task. */
+    public void testDelegateToHealthNodeWithParentTask() throws ExecutionException, InterruptedException {
+        Request request = new Request();
+        setState(clusterService, ClusterStateCreationUtils.state(localNode, remoteNode, remoteNode, allNodes));
+
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request);
+        ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool), task, request, listener);
+
+        assertThat(transport.capturedRequests().length, equalTo(1));
+        CapturingTransport.CapturedRequest capturedRequest = transport.capturedRequests()[0];
+        assertThat(capturedRequest.node(), equalTo(remoteNode));
+        assertThat(capturedRequest.request(), equalTo(request));
+        assertThat(capturedRequest.action(), equalTo("internal:testAction"));
+
+        Response response = new Response();
+        transport.handleResponse(capturedRequest.requestId(), response);
+        assertTrue(listener.isDone());
+        assertThat(listener.get(), equalTo(response));
+    }
+
+    /** An exception thrown by the health operation is reported to the listener. */
+    public void testHealthNodeOperationWithException() throws InterruptedException {
+        Request request = new Request();
+        setState(clusterService, ClusterStateCreationUtils.state(localNode, localNode, localNode, allNodes));
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        ActionTestUtils.execute(
+            new HealthOperationWithExceptionAction("internal:testAction", transportService, clusterService, threadPool),
+            null,
+            request,
+            listener
+        );
+        assertTrue(listener.isDone());
+        try {
+            listener.get();
+            fail("A simulated RuntimeException should be thrown");
+        } catch (ExecutionException ex) {
+            assertThat(ex.getCause().getMessage(), equalTo("Simulated"));
+        }
+    }
+
+    /** A task cancelled before the action starts routing makes the listener fail with {@link TaskCancelledException}. */
+    public void testTaskCancellation() {
+        Request request = new Request();
+        final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request);
+
+        PlainActionFuture<Response> listener = new PlainActionFuture<>();
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+
+        threadPool.executor(ThreadPool.Names.MANAGEMENT)
+            .submit(
+                () -> ActionTestUtils.execute(
+                    new WaitForSignalAction("internal:testAction", transportService, clusterService, threadPool, countDownLatch),
+                    task,
+                    request,
+                    listener
+                )
+            );
+
+        taskManager.cancel(task, "", () -> {});
+        assertThat(task.isCancelled(), equalTo(true));
+
+        // Only release the action once the task is known to be cancelled.
+        countDownLatch.countDown();
+
+        expectThrows(TaskCancelledException.class, listener::actionGet);
+    }
+}

+ 77 - 0
server/src/test/java/org/elasticsearch/health/node/selection/HealthNodeTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.health.node.selection;
+
+import org.elasticsearch.Version;
+import org.elasticsearch.action.support.replication.ClusterStateCreationUtils;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Collections;
+import java.util.Set;
+
+import static org.elasticsearch.persistent.PersistentTasksExecutor.NO_NODE_FOUND;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+
+/**
+ * Tests the static lookups {@link HealthNode#findTask} and {@link HealthNode#findHealthNode} against
+ * hand-built cluster states.
+ */
+public class HealthNodeTests extends ESTestCase {
+
+    private final DiscoveryNode node1 = new DiscoveryNode(
+        "node_1",
+        buildNewFakeTransportAddress(),
+        Collections.emptyMap(),
+        Set.of(DiscoveryNodeRole.MASTER_ROLE, DiscoveryNodeRole.DATA_ROLE),
+        Version.CURRENT
+    );
+    private final DiscoveryNode node2 = new DiscoveryNode(
+        "node_2",
+        buildNewFakeTransportAddress(),
+        Collections.emptyMap(),
+        Set.of(DiscoveryNodeRole.MASTER_ROLE, DiscoveryNodeRole.DATA_ROLE),
+        Version.CURRENT
+    );
+    private final DiscoveryNode[] allNodes = new DiscoveryNode[] { node1, node2 };
+
+    /** A state built with a health node contains the health node task. */
+    public void testFindTask() {
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, node1, allNodes);
+        assertThat(HealthNode.findTask(state), notNullValue());
+    }
+
+    /** A state built without a health node contains no health node task. */
+    public void testFindNoTask() {
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, allNodes);
+        assertThat(HealthNode.findTask(state), nullValue());
+    }
+
+    /** The node the task is assigned to is returned as the health node. */
+    public void testFindHealthNode() {
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, node1, allNodes);
+        assertThat(HealthNode.findHealthNode(state), equalTo(node1));
+    }
+
+    /** Without the task there is no health node. */
+    public void testFindHealthNodeNoTask() {
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, allNodes);
+        assertThat(HealthNode.findHealthNode(state), nullValue());
+    }
+
+    /** A task that exists but carries the {@code NO_NODE_FOUND} assignment yields no health node. */
+    public void testFindHealthNodeNoAssignment() {
+        PersistentTasksCustomMetadata.Builder tasks = PersistentTasksCustomMetadata.builder();
+        tasks.addTask(HealthNode.TASK_NAME, HealthNode.TASK_NAME, HealthNodeTaskParams.INSTANCE, NO_NODE_FOUND);
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, allNodes)
+            .copyAndUpdateMetadata(b -> b.putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build()));
+        assertThat(HealthNode.findHealthNode(state), nullValue());
+    }
+
+    /** A task assigned to a node that is not part of the cluster yields no health node. */
+    public void testFindHealthNodeMissingNode() {
+        // Assign the task to node_2 while only node_1 is in the cluster, so the lookup by executor node id comes back empty.
+        // The node list is passed as an explicit array to select the overload that takes a health node unambiguously.
+        ClusterState state = ClusterStateCreationUtils.state(node1, node1, node2, new DiscoveryNode[] { node1 });
+        assertThat(HealthNode.findHealthNode(state), nullValue());
+    }
+}

+ 37 - 1
test/framework/src/main/java/org/elasticsearch/action/support/replication/ClusterStateCreationUtils.java

@@ -26,8 +26,11 @@ import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.cluster.routing.UnassignedInfo;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.health.node.selection.HealthNode;
+import org.elasticsearch.health.node.selection.HealthNodeTaskParams;
 import org.elasticsearch.index.shard.IndexLongFieldRange;
 import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.ArrayList;
@@ -41,6 +44,7 @@ import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_CREATION_
 import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS;
 import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
 import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_VERSION_CREATED;
+import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
 import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.test.ESTestCase.randomInt;
 import static org.elasticsearch.test.ESTestCase.randomIntBetween;
@@ -424,6 +428,24 @@ public class ClusterStateCreationUtils {
      * @return cluster state
      */
     public static ClusterState state(DiscoveryNode localNode, DiscoveryNode masterNode, DiscoveryNode... allNodes) {
+        return state(localNode, masterNode, null, allNodes);
+    }
+
+    /**
+     * Creates a cluster state where local node, master and health node can be specified
+     *
+     * @param localNode  node in allNodes that is the local node
+     * @param masterNode node in allNodes that is the master node. Can be null if no master exists
+     * @param healthNode node in allNodes that is the health node. Can be null if no health node exists
+     * @param allNodes   all nodes in the cluster
+     * @return cluster state
+     */
+    public static ClusterState state(
+        DiscoveryNode localNode,
+        DiscoveryNode masterNode,
+        DiscoveryNode healthNode,
+        DiscoveryNode... allNodes
+    ) {
         DiscoveryNodes.Builder discoBuilder = DiscoveryNodes.builder();
         for (DiscoveryNode node : allNodes) {
             discoBuilder.add(node);
@@ -436,7 +458,11 @@ public class ClusterStateCreationUtils {
 
         ClusterState.Builder state = ClusterState.builder(new ClusterName("test"));
         state.nodes(discoBuilder);
-        state.metadata(Metadata.builder().generateClusterUuidIfNeeded());
+        Metadata.Builder metadataBuilder = Metadata.builder().generateClusterUuidIfNeeded();
+        if (healthNode != null) {
+            addHealthNode(metadataBuilder, healthNode);
+        }
+        state.metadata(metadataBuilder);
         return state.build();
     }
 
@@ -455,4 +481,14 @@ public class ClusterStateCreationUtils {
         strings.remove(selection);
         return selection;
     }
+
+    /**
+     * Stores a persistent tasks custom metadata containing only the health node task, assigned to the given node.
+     *
+     * @param metadataBuilder builder that receives the persistent tasks custom metadata
+     * @param healthNode      node whose id becomes the executor node of the assignment
+     * @return the value returned by {@code metadataBuilder.putCustom}
+     */
+    private static Metadata.Builder addHealthNode(Metadata.Builder metadataBuilder, DiscoveryNode healthNode) {
+        PersistentTasksCustomMetadata.Assignment healthNodeAssignment = new PersistentTasksCustomMetadata.Assignment(
+            healthNode.getId(),
+            randomAlphaOfLength(10)
+        );
+        PersistentTasksCustomMetadata persistentTasks = PersistentTasksCustomMetadata.builder()
+            .addTask(HealthNode.TASK_NAME, HealthNode.TASK_NAME, HealthNodeTaskParams.INSTANCE, healthNodeAssignment)
+            .build();
+        return metadataBuilder.putCustom(PersistentTasksCustomMetadata.TYPE, persistentTasks);
+    }
 }