Browse Source

EQL: Add cascading search cancellation (#54843)

EQL search cancellation now propagates cancellation to underlying search
operations.

Relates to #49638
Igor Motov 5 years ago
parent
commit
9d38dcf401

+ 12 - 0
x-pack/plugin/eql/build.gradle

@@ -20,6 +20,18 @@ archivesBaseName = 'x-pack-eql'
 // All integration tests live in qa modules
 integTest.enabled = false
 
+task internalClusterTest(type: Test) {
+  mustRunAfter test
+  include '**/*IT.class'
+  /*
+   * We have to disable setting the number of available processors as tests in the same JVM randomize processors and will step on each
+   * other if we allow them to set the number of available processors as it's set-once in Netty.
+   */
+  systemProperty 'es.set.netty.runtime.available.processors', 'false'
+}
+
+check.dependsOn internalClusterTest
+
 dependencies {
   compileOnly project(path: xpackModule('core'), configuration: 'default')
   compileOnly(project(':modules:lang-painless')) {

+ 1 - 1
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchTask.java

@@ -23,7 +23,7 @@ public class EqlSearchTask extends CancellableTask {
 
     @Override
     public boolean shouldCancelChildrenOnCancellation() {
-        return false;
+        return true;
     }
 
     @Override

+ 5 - 3
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlSearchAction.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.time.DateUtils;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.XPackSettings;
@@ -50,11 +51,12 @@ public class TransportEqlSearchAction extends HandledTransportAction<EqlSearchRe
 
     @Override
     protected void doExecute(Task task, EqlSearchRequest request, ActionListener<EqlSearchResponse> listener) {
-        operation(planExecutor, (EqlSearchTask) task, request, username(securityContext), clusterName(clusterService), listener);
+        operation(planExecutor, (EqlSearchTask) task, request, username(securityContext), clusterName(clusterService),
+            clusterService.localNode().getId(), listener);
     }
 
     public static void operation(PlanExecutor planExecutor, EqlSearchTask task, EqlSearchRequest request, String username,
-                                 String clusterName, ActionListener<EqlSearchResponse> listener) {
+                                 String clusterName, String nodeId, ActionListener<EqlSearchResponse> listener) {
         // TODO: these should be sent by the client
         ZoneId zoneId = DateUtils.of("Z");
         QueryBuilder filter = request.filter();
@@ -68,7 +70,7 @@ public class TransportEqlSearchAction extends HandledTransportAction<EqlSearchRe
             .implicitJoinKey(request.implicitJoinKeyField());
 
         Configuration cfg = new Configuration(request.indices(), zoneId, username, clusterName, filter, timeout, request.fetchSize(),
-                includeFrozen, clientId, task);
+                includeFrozen, clientId, new TaskId(nodeId, task.getId()), task::isCancelled);
         planExecutor.eql(cfg, request.query(), params, wrap(r -> listener.onResponse(createResponse(r)), listener::onFailure));
     }
 

+ 13 - 6
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/session/Configuration.java

@@ -9,9 +9,10 @@ package org.elasticsearch.xpack.eql.session;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.xpack.eql.action.EqlSearchTask;
+import org.elasticsearch.tasks.TaskId;
 
 import java.time.ZoneId;
+import java.util.function.Supplier;
 
 public class Configuration extends org.elasticsearch.xpack.ql.session.Configuration {
 
@@ -20,13 +21,14 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
     private final int size;
     private final String clientId;
     private final boolean includeFrozenIndices;
-    private final EqlSearchTask task;
+    private final Supplier<Boolean> isCancelled;
+    private final TaskId taskId;
 
     @Nullable
-    private QueryBuilder filter;
+    private final QueryBuilder filter;
 
     public Configuration(String[] indices, ZoneId zi, String username, String clusterName, QueryBuilder filter, TimeValue requestTimeout,
-                         int size, boolean includeFrozen, String clientId, EqlSearchTask task) {
+                         int size, boolean includeFrozen, String clientId, TaskId taskId, Supplier<Boolean> isCancelled) {
 
         super(zi, username, clusterName);
 
@@ -36,7 +38,8 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
         this.size = size;
         this.clientId = clientId;
         this.includeFrozenIndices = includeFrozen;
-        this.task = task;
+        this.taskId = taskId;
+        this.isCancelled = isCancelled;
     }
 
     public String[] indices() {
@@ -64,6 +67,10 @@ public class Configuration extends org.elasticsearch.xpack.ql.session.Configurat
     }
 
     public boolean isCancelled() {
-        return task.isCancelled();
+        return isCancelled.get();
+    }
+
+    public TaskId getTaskId() {
+        return taskId;
     }
 }

+ 2 - 1
x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/session/EqlSession.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.eql.session;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.Client;
+import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.xpack.eql.analysis.Analyzer;
@@ -37,7 +38,7 @@ public class EqlSession {
     public EqlSession(Client client, Configuration cfg, IndexResolver indexResolver, PreAnalyzer preAnalyzer, Analyzer analyzer,
             Optimizer optimizer, Planner planner, PlanExecutor planExecutor) {
 
-        this.client = client;
+        this.client = new ParentTaskAssigningClient(client, cfg.getTaskId());
         this.configuration = cfg;
         this.indexResolver = indexResolver;
         this.preAnalyzer = preAnalyzer;

+ 4 - 2
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/EqlTestUtils.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.eql;
 
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.eql.action.EqlSearchAction;
 import org.elasticsearch.xpack.eql.action.EqlSearchTask;
 import org.elasticsearch.xpack.eql.session.Configuration;
@@ -27,7 +28,7 @@ public final class EqlTestUtils {
 
     public static final Configuration TEST_CFG = new Configuration(new String[]{"none"}, org.elasticsearch.xpack.ql.util.DateUtils.UTC,
             "nobody", "cluster", null, TimeValue.timeValueSeconds(30), -1, false, "",
-            new EqlSearchTask(-1, "", EqlSearchAction.NAME, () -> "", null, Collections.emptyMap()));
+            new TaskId(randomAlphaOfLength(10), randomNonNegativeLong()), () -> false);
 
     public static Configuration randomConfiguration() {
         return new Configuration(new String[]{randomAlphaOfLength(16)},
@@ -39,7 +40,8 @@ public final class EqlTestUtils {
             randomIntBetween(5, 100),
             randomBoolean(),
             randomAlphaOfLength(16),
-            randomTask());
+            new TaskId(randomAlphaOfLength(10), randomNonNegativeLong()),
+            () -> false);
     }
 
     public static EqlSearchTask randomTask() {

+ 42 - 0
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/AbstractEqlIntegTestCase.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.eql.action;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.license.LicenseService;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xpack.core.XPackSettings;
+import org.elasticsearch.xpack.eql.plugin.EqlPlugin;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.elasticsearch.test.ESIntegTestCase.Scope.SUITE;
+
+@ESIntegTestCase.ClusterScope(scope = SUITE, numDataNodes = 0, numClientNodes = 0, maxNumDataNodes = 0)
+public abstract class AbstractEqlIntegTestCase extends ESIntegTestCase {
+
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal) {
+        Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal));
+        settings.put(XPackSettings.SECURITY_ENABLED.getKey(), false);
+        settings.put(XPackSettings.MONITORING_ENABLED.getKey(), false);
+        settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false);
+        settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false);
+        settings.put(XPackSettings.MACHINE_LEARNING_ENABLED.getKey(), false);
+        settings.put(XPackSettings.SQL_ENABLED.getKey(), false);
+        settings.put(EqlPlugin.EQL_ENABLED_SETTING.getKey(), true);
+        settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial");
+        return settings.build();
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return Collections.singletonList(LocalStateEQLXPackPlugin.class);
+    }
+}

+ 280 - 0
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/EqlCancellationIT.java

@@ -0,0 +1,280 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.eql.action;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
+import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
+import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchPhaseExecutionException;
+import org.elasticsearch.action.support.ActionFilter;
+import org.elasticsearch.action.support.ActionFilterChain;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexModule;
+import org.elasticsearch.index.shard.SearchOperationListener;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.PluginsService;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.tasks.TaskInfo;
+import org.junit.After;
+
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+
+public class EqlCancellationIT extends AbstractEqlIntegTestCase {
+
+    private final ExecutorService executorService = Executors.newFixedThreadPool(1);
+
+    /**
+     * Shutdown the executor so we don't leak threads into other test runs.
+     */
+    @After
+    public void shutdownExec() {
+        executorService.shutdown();
+    }
+
+    public void testCancellation() throws Exception {
+        assertAcked(client().admin().indices().prepareCreate("test")
+            .setMapping("val", "type=integer", "event_type", "type=keyword", "@timestamp", "type=date")
+            .get());
+        createIndex("idx_unmapped");
+
+        int numDocs = randomIntBetween(6, 20);
+
+        List<IndexRequestBuilder> builders = new ArrayList<>();
+
+        for (int i = 0; i < numDocs; i++) {
+            int fieldValue = randomIntBetween(0, 10);
+            builders.add(client().prepareIndex("test").setSource(
+                jsonBuilder().startObject()
+                    .field("val", fieldValue).field("event_type", "my_event").field("@timestamp", "2020-04-09T12:35:48Z")
+                    .endObject()));
+        }
+
+        indexRandom(true, builders);
+        boolean cancelDuringSearch = randomBoolean();
+        List<SearchBlockPlugin> plugins = initBlockFactory(cancelDuringSearch, cancelDuringSearch == false);
+        EqlSearchRequest request = new EqlSearchRequest().indices("test").query("my_event where val=1").eventCategoryField("event_type");
+        String id = randomAlphaOfLength(10);
+        logger.trace("Preparing search");
+        // We might perform field caps on the same thread if it is local client, so we cannot use the standard mechanism
+        Future<EqlSearchResponse> future = executorService.submit(() ->
+            client().filterWithHeader(Collections.singletonMap(Task.X_OPAQUE_ID, id)).execute(EqlSearchAction.INSTANCE, request).get()
+        );
+        logger.trace("Waiting for block to be established");
+        if (cancelDuringSearch) {
+            awaitForBlockedSearches(plugins, "test");
+        } else {
+            awaitForBlockedFieldCaps(plugins);
+        }
+        logger.trace("Block is established");
+        ListTasksResponse tasks = client().admin().cluster().prepareListTasks().setActions(EqlSearchAction.NAME).get();
+        TaskId taskId = null;
+        for (TaskInfo task : tasks.getTasks()) {
+            if (id.equals(task.getHeaders().get(Task.X_OPAQUE_ID))) {
+                taskId = task.getTaskId();
+                break;
+            }
+        }
+        assertNotNull(taskId);
+        logger.trace("Cancelling task " + taskId);
+        CancelTasksResponse response = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).get();
+        assertThat(response.getTasks(), hasSize(1));
+        assertThat(response.getTasks().get(0).getAction(), equalTo(EqlSearchAction.NAME));
+        logger.trace("Task is cancelled " + taskId);
+        disableBlocks(plugins);
+        Exception exception = expectThrows(Exception.class, future::get);
+        Throwable inner = ExceptionsHelper.unwrap(exception, SearchPhaseExecutionException.class);
+        if (cancelDuringSearch) {
+            // Make sure we cancelled inside search
+            assertNotNull(inner);
+            assertThat(inner, instanceOf(SearchPhaseExecutionException.class));
+            assertThat(inner.getCause(), instanceOf(TaskCancelledException.class));
+        } else {
+            // Make sure we were not cancelled inside search
+            assertNull(inner);
+            assertThat(getNumberOfContexts(plugins), equalTo(0));
+            Throwable cancellationException = ExceptionsHelper.unwrap(exception, TaskCancelledException.class);
+            assertNotNull(cancellationException);
+        }
+    }
+
+    private List<SearchBlockPlugin> initBlockFactory(boolean searchBlock, boolean fieldCapsBlock) {
+        List<SearchBlockPlugin> plugins = new ArrayList<>();
+        for (PluginsService pluginsService : internalCluster().getDataNodeInstances(PluginsService.class)) {
+            plugins.addAll(pluginsService.filterPlugins(SearchBlockPlugin.class));
+        }
+        for (SearchBlockPlugin plugin : plugins) {
+            plugin.reset();
+            if (searchBlock) {
+                plugin.enableSearchBlock();
+            }
+            if (fieldCapsBlock) {
+                plugin.enableFieldCapBlock();
+            }
+        }
+        return plugins;
+    }
+
+    private void disableBlocks(List<SearchBlockPlugin> plugins) {
+        for (SearchBlockPlugin plugin : plugins) {
+            plugin.disableSearchBlock();
+            plugin.disableFieldCapBlock();
+        }
+    }
+
+    private void awaitForBlockedSearches(List<SearchBlockPlugin> plugins, String index) throws Exception {
+        int numberOfShards = getNumShards(index).numPrimaries;
+        assertBusy(() -> {
+            int numberOfBlockedPlugins = getNumberOfContexts(plugins);
+            logger.trace("The plugin blocked on {} out of {} shards", numberOfBlockedPlugins, numberOfShards);
+            assertThat(numberOfBlockedPlugins, greaterThan(0));
+        });
+    }
+
+    private int getNumberOfContexts(List<SearchBlockPlugin> plugins) throws Exception {
+        int count = 0;
+        for (SearchBlockPlugin plugin : plugins) {
+            count += plugin.contexts.get();
+        }
+        return count;
+    }
+
+    private int getNumberOfFieldCaps(List<SearchBlockPlugin> plugins) throws Exception {
+        int count = 0;
+        for (SearchBlockPlugin plugin : plugins) {
+            count += plugin.fieldCaps.get();
+        }
+        return count;
+    }
+
+    private void awaitForBlockedFieldCaps(List<SearchBlockPlugin> plugins) throws Exception {
+        assertBusy(() -> {
+            int numberOfBlockedPlugins = getNumberOfFieldCaps(plugins);
+            logger.trace("The plugin blocked on {} nodes", numberOfBlockedPlugins);
+            assertThat(numberOfBlockedPlugins, greaterThan(0));
+        });
+    }
+
+    public static class SearchBlockPlugin extends LocalStateEQLXPackPlugin {
+        protected final Logger logger = LogManager.getLogger(getClass());
+
+        private final AtomicInteger contexts = new AtomicInteger();
+
+        private final AtomicInteger fieldCaps = new AtomicInteger();
+
+        private final AtomicBoolean shouldBlockOnSearch = new AtomicBoolean(false);
+
+        private final AtomicBoolean shouldBlockOnFieldCapabilities = new AtomicBoolean(false);
+
+        private final String nodeId;
+
+        public void reset() {
+            contexts.set(0);
+            fieldCaps.set(0);
+        }
+
+        public void disableSearchBlock() {
+            shouldBlockOnSearch.set(false);
+        }
+
+        public void enableSearchBlock() {
+            shouldBlockOnSearch.set(true);
+        }
+
+
+        public void disableFieldCapBlock() {
+            shouldBlockOnFieldCapabilities.set(false);
+        }
+
+        public void enableFieldCapBlock() {
+            shouldBlockOnFieldCapabilities.set(true);
+        }
+
+        public SearchBlockPlugin(Settings settings, Path configPath) throws Exception {
+            super(settings, configPath);
+            nodeId = settings.get("node.name");
+        }
+
+        @Override
+        public void onIndexModule(IndexModule indexModule) {
+            super.onIndexModule(indexModule);
+            indexModule.addSearchOperationListener(new SearchOperationListener() {
+                @Override
+                public void onNewContext(SearchContext context) {
+                    contexts.incrementAndGet();
+                    try {
+                        logger.trace("blocking search on " + nodeId);
+                        assertBusy(() -> assertFalse(shouldBlockOnSearch.get()));
+                        logger.trace("unblocking search on " + nodeId);
+                    } catch (Exception e) {
+                        throw new RuntimeException(e);
+                    }
+                }
+            });
+        }
+
+        @Override
+        public List<ActionFilter> getActionFilters() {
+            List<ActionFilter> list = new ArrayList<>(super.getActionFilters());
+            list.add(new ActionFilter() {
+                @Override
+                public int order() {
+                    return 0;
+                }
+
+                @Override
+                public <Request extends ActionRequest, Response extends ActionResponse> void apply(
+                    Task task, String action, Request request, ActionListener<Response> listener,
+                    ActionFilterChain<Request, Response> chain) {
+                    if (action.equals(FieldCapabilitiesAction.NAME)) {
+                        try {
+                            fieldCaps.incrementAndGet();
+                            logger.trace("blocking field caps on " + nodeId);
+                            assertBusy(() -> assertFalse(shouldBlockOnFieldCapabilities.get()));
+                            logger.trace("unblocking field caps on " + nodeId);
+                        } catch (Exception e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                    chain.proceed(task, action, request, listener);
+                }
+            });
+            return list;
+        }
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return Collections.singletonList(SearchBlockPlugin.class);
+    }
+
+}

+ 31 - 0
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/LocalStateEQLXPackPlugin.java

@@ -0,0 +1,31 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.eql.action;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
+import org.elasticsearch.xpack.eql.plugin.EqlPlugin;
+import org.elasticsearch.xpack.ql.plugin.QlPlugin;
+
+import java.nio.file.Path;
+
+public class LocalStateEQLXPackPlugin extends LocalStateCompositeXPackPlugin {
+
+    public LocalStateEQLXPackPlugin(final Settings settings, final Path configPath) {
+        super(settings, configPath);
+        LocalStateEQLXPackPlugin thisVar = this;
+        plugins.add(new EqlPlugin(settings) {
+            @Override
+            protected XPackLicenseState getLicenseState() {
+                return thisVar.getLicenseState();
+            }
+        });
+        plugins.add(new QlPlugin());
+    }
+
+}

+ 96 - 12
x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/analysis/CancellationTests.java

@@ -8,9 +8,14 @@ package org.elasticsearch.xpack.eql.analysis;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchRequestBuilder;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.eql.action.EqlSearchRequest;
 import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
@@ -19,6 +24,7 @@ import org.elasticsearch.xpack.eql.execution.PlanExecutor;
 import org.elasticsearch.xpack.eql.plugin.TransportEqlSearchAction;
 import org.elasticsearch.xpack.ql.index.IndexResolver;
 import org.elasticsearch.xpack.ql.type.DefaultDataTypeRegistry;
+import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 
 import java.util.Collections;
@@ -48,7 +54,7 @@ public class CancellationTests extends ESTestCase {
         IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE);
         PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
         CountDownLatch countDownLatch = new CountDownLatch(1);
-        TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().query("foo where blah"), "", "",
+        TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().query("foo where blah"), "", "", "node_id",
             new ActionListener<>() {
                 @Override
                 public void onResponse(EqlSearchResponse eqlSearchResponse) {
@@ -64,18 +70,13 @@ public class CancellationTests extends ESTestCase {
             });
         countDownLatch.await();
         verify(task, times(1)).isCancelled();
+        verify(task, times(1)).getId();
+        verify(client, times(1)).settings();
+        verify(client, times(1)).threadPool();
         verifyNoMoreInteractions(client, task);
     }
 
-    public void testCancellationBeforeSearch() throws InterruptedException {
-        Client client = mock(Client.class);
-
-        AtomicBoolean cancelled = new AtomicBoolean(false);
-        EqlSearchTask task = mock(EqlSearchTask.class);
-        when(task.isCancelled()).then(invocationOnMock -> cancelled.get());
-
-        String[] indices = new String[]{"endgame"};
-
+    private Map<String, Map<String, FieldCapabilities>> fields(String[] indices) {
         FieldCapabilities fooField =
             new FieldCapabilities("foo", "integer", true, true, indices, null, null, emptyMap());
         FieldCapabilities categoryField =
@@ -86,10 +87,24 @@ public class CancellationTests extends ESTestCase {
         fields.put(fooField.getName(), singletonMap(fooField.getName(), fooField));
         fields.put(categoryField.getName(), singletonMap(categoryField.getName(), categoryField));
         fields.put(timestampField.getName(), singletonMap(timestampField.getName(), timestampField));
+        return fields;
+    }
+
+    public void testCancellationBeforeSearch() throws InterruptedException {
+        Client client = mock(Client.class);
+
+        AtomicBoolean cancelled = new AtomicBoolean(false);
+        EqlSearchTask task = mock(EqlSearchTask.class);
+        String nodeId = randomAlphaOfLength(10);
+        long taskId = randomNonNegativeLong();
+        when(task.isCancelled()).then(invocationOnMock -> cancelled.get());
+        when(task.getId()).thenReturn(taskId);
+
+        String[] indices = new String[]{"endgame"};
 
         FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
         when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices);
-        when(fieldCapabilitiesResponse.get()).thenReturn(fields);
+        when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices));
         doAnswer((Answer<Void>) invocation -> {
             @SuppressWarnings("unchecked")
             ActionListener<FieldCapabilitiesResponse> listener = (ActionListener<FieldCapabilitiesResponse>) invocation.getArguments()[1];
@@ -103,7 +118,71 @@ public class CancellationTests extends ESTestCase {
         PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
         CountDownLatch countDownLatch = new CountDownLatch(1);
         TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().indices("endgame")
-            .query("process where foo==3"), "", "", new ActionListener<>() {
+            .query("process where foo==3"), "", "", nodeId, new ActionListener<>() {
+            @Override
+            public void onResponse(EqlSearchResponse eqlSearchResponse) {
+                fail("Shouldn't be here");
+                countDownLatch.countDown();
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                assertThat(e, instanceOf(TaskCancelledException.class));
+                countDownLatch.countDown();
+            }
+        });
+        countDownLatch.await();
+        verify(client).fieldCaps(any(), any());
+        verify(task, times(2)).isCancelled();
+        verify(task, times(1)).getId();
+        verify(client, times(1)).settings();
+        verify(client, times(1)).threadPool();
+        verifyNoMoreInteractions(client, task);
+    }
+
+    public void testCancellationDuringSearch() throws InterruptedException {
+        Client client = mock(Client.class);
+
+        EqlSearchTask task = mock(EqlSearchTask.class);
+        String nodeId = randomAlphaOfLength(10);
+        long taskId = randomNonNegativeLong();
+        when(task.isCancelled()).thenReturn(false);
+        when(task.getId()).thenReturn(taskId);
+
+        String[] indices = new String[]{"endgame"};
+
+        // Emulation of field capabilities
+        FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class);
+        when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices);
+        when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices));
+        doAnswer((Answer<Void>) invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<FieldCapabilitiesResponse> listener = (ActionListener<FieldCapabilitiesResponse>) invocation.getArguments()[1];
+            listener.onResponse(fieldCapabilitiesResponse);
+            return null;
+        }).when(client).fieldCaps(any(), any());
+
+        // Emulation of search cancellation
+        ArgumentCaptor<SearchRequest> searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
+        when(client.prepareSearch(any())).thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE).setIndices(indices));
+        doAnswer((Answer<Void>) invocation -> {
+            @SuppressWarnings("unchecked")
+            SearchRequest request = (SearchRequest) invocation.getArguments()[1];
+            TaskId parentTask = request.getParentTask();
+            assertNotNull(parentTask);
+            assertEquals(taskId, parentTask.getId());
+            assertEquals(nodeId, parentTask.getNodeId());
+            @SuppressWarnings("unchecked")
+            ActionListener<SearchResponse> listener = (ActionListener<SearchResponse>) invocation.getArguments()[2];
+            listener.onFailure(new TaskCancelledException("cancelled"));
+            return null;
+        }).when(client).execute(any(), searchRequestCaptor.capture(), any());
+
+        IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE);
+        PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList()));
+        CountDownLatch countDownLatch = new CountDownLatch(1);
+        TransportEqlSearchAction.operation(planExecutor, task, new EqlSearchRequest().indices("endgame")
+            .query("process where foo==3"), "", "", nodeId, new ActionListener<>() {
             @Override
             public void onResponse(EqlSearchResponse eqlSearchResponse) {
                 fail("Shouldn't be here");
@@ -117,8 +196,13 @@ public class CancellationTests extends ESTestCase {
             }
         });
         countDownLatch.await();
+        // Final verification to ensure no more interaction
         verify(client).fieldCaps(any(), any());
+        verify(client).execute(any(), any(), any());
         verify(task, times(2)).isCancelled();
+        verify(task, times(1)).getId();
+        verify(client, times(1)).settings();
+        verify(client, times(1)).threadPool();
         verifyNoMoreInteractions(client, task);
     }