Browse Source

Ignore cancellation exceptions (#117657) (#118169)

Today, when an ES|QL task encounters an exception, we trigger a 
cancellation on the root task, causing child tasks to fail due to
cancellation. We chose not to include cancellation exceptions in the 
output, as they are unhelpful and add noise during problem analysis.
However, these exceptions are still slipping through via
RefCountingListener. This change addresses the issue by introducing
EsqlRefCountingListener, ensuring that cancellation exceptions are not
returned when a more relevant failure exists.
Nhat Nguyen 10 months ago
parent
commit
25fd1be7d6

+ 5 - 0
docs/changelog/117657.yaml

@@ -0,0 +1,5 @@
+pr: 117657
+summary: Ignore cancellation exceptions
+area: ES|QL
+type: bug
+issues: []

+ 47 - 0
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.compute;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.RefCountingRunnable;
+import org.elasticsearch.compute.operator.FailureCollector;
+import org.elasticsearch.core.Releasable;
+
+/**
+ * Similar to {@link org.elasticsearch.action.support.RefCountingListener},
+ * but prefers non-task-cancelled exceptions over task-cancelled ones as they are more useful for diagnosing issues.
+ * @see FailureCollector
+ */
+public final class EsqlRefCountingListener implements Releasable {
+    private final FailureCollector failureCollector;
+    private final RefCountingRunnable refs;
+
+    public EsqlRefCountingListener(ActionListener<Void> delegate) {
+        this.failureCollector = new FailureCollector();
+        this.refs = new RefCountingRunnable(() -> {
+            Exception error = failureCollector.getFailure();
+            if (error != null) {
+                delegate.onFailure(error);
+            } else {
+                delegate.onResponse(null);
+            }
+        });
+    }
+
+    public ActionListener<Void> acquire() {
+        return refs.acquireListener().delegateResponse((l, e) -> {
+            failureCollector.unwrapAndCollect(e);
+            l.onFailure(e);
+        });
+    }
+
+    @Override
+    public void close() {
+        refs.close();
+    }
+}

+ 24 - 24
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java

@@ -13,9 +13,8 @@ import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.transport.TransportException;
 
-import java.util.List;
 import java.util.Queue;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.Semaphore;
 
 /**
  * {@code FailureCollector} is responsible for collecting exceptions that occur in the compute engine.
@@ -26,12 +25,11 @@ import java.util.concurrent.atomic.AtomicInteger;
  */
 public final class FailureCollector {
     private final Queue<Exception> cancelledExceptions = ConcurrentCollections.newQueue();
-    private final AtomicInteger cancelledExceptionsCount = new AtomicInteger();
+    private final Semaphore cancelledExceptionsPermits;
 
     private final Queue<Exception> nonCancelledExceptions = ConcurrentCollections.newQueue();
-    private final AtomicInteger nonCancelledExceptionsCount = new AtomicInteger();
+    private final Semaphore nonCancelledExceptionsPermits;
 
-    private final int maxExceptions;
     private volatile boolean hasFailure = false;
     private Exception finalFailure = null;
 
@@ -43,7 +41,8 @@ public final class FailureCollector {
         if (maxExceptions <= 0) {
             throw new IllegalArgumentException("maxExceptions must be at least one");
         }
-        this.maxExceptions = maxExceptions;
+        this.cancelledExceptionsPermits = new Semaphore(maxExceptions);
+        this.nonCancelledExceptionsPermits = new Semaphore(maxExceptions);
     }
 
     private static Exception unwrapTransportException(TransportException te) {
@@ -60,13 +59,12 @@ public final class FailureCollector {
     public void unwrapAndCollect(Exception e) {
         e = e instanceof TransportException te ? unwrapTransportException(te) : e;
         if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) {
-            if (cancelledExceptionsCount.incrementAndGet() <= maxExceptions) {
+            if (nonCancelledExceptions.isEmpty() && cancelledExceptionsPermits.tryAcquire()) {
                 cancelledExceptions.add(e);
             }
-        } else {
-            if (nonCancelledExceptionsCount.incrementAndGet() <= maxExceptions) {
-                nonCancelledExceptions.add(e);
-            }
+        } else if (nonCancelledExceptionsPermits.tryAcquire()) {
+            nonCancelledExceptions.add(e);
+            cancelledExceptions.clear();
         }
         hasFailure = true;
     }
@@ -99,20 +97,22 @@ public final class FailureCollector {
     private Exception buildFailure() {
         assert hasFailure;
         assert Thread.holdsLock(this);
-        int total = 0;
         Exception first = null;
-        for (var exceptions : List.of(nonCancelledExceptions, cancelledExceptions)) {
-            for (Exception e : exceptions) {
-                if (first == null) {
-                    first = e;
-                    total++;
-                } else if (first != e) {
-                    first.addSuppressed(e);
-                    total++;
-                }
-                if (total >= maxExceptions) {
-                    return first;
-                }
+        for (Exception e : nonCancelledExceptions) {
+            if (first == null) {
+                first = e;
+            } else if (first != e) {
+                first.addSuppressed(e);
+            }
+        }
+        if (first != null) {
+            return first;
+        }
+        for (Exception e : cancelledExceptions) {
+            if (first == null) {
+                first = e;
+            } else if (first != e) {
+                first.addSuppressed(e);
             }
         }
         assert first != null;

+ 9 - 8
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java

@@ -9,9 +9,10 @@ package org.elasticsearch.compute.operator.exchange;
 
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.support.RefCountingListener;
+import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.compute.EsqlRefCountingListener;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.FailureCollector;
 import org.elasticsearch.compute.operator.IsBlockedResult;
@@ -54,20 +55,20 @@ public final class ExchangeSourceHandler {
         this.outstandingSinks = new PendingInstances(() -> buffer.finish(false));
         this.outstandingSources = new PendingInstances(() -> buffer.finish(true));
         buffer.addCompletionListener(ActionListener.running(() -> {
-            final ActionListener<Void> listener = ActionListener.assertAtLeastOnce(completionListener).delegateFailure((l, unused) -> {
+            final ActionListener<Void> listener = ActionListener.assertAtLeastOnce(completionListener);
+            try (RefCountingRunnable refs = new RefCountingRunnable(() -> {
                 final Exception e = failure.getFailure();
                 if (e != null) {
-                    l.onFailure(e);
+                    listener.onFailure(e);
                 } else {
-                    l.onResponse(null);
+                    listener.onResponse(null);
                 }
-            });
-            try (RefCountingListener refs = new RefCountingListener(listener)) {
+            })) {
                 for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) {
                     // Create an outstanding instance and then finish to complete the completionListener
                     // if we haven't registered any instances of exchange sinks or exchange sources before.
                     pending.trackNewInstance();
-                    pending.completion.addListener(refs.acquire());
+                    pending.completion.addListener(refs.acquireListener());
                     pending.finishInstance();
                 }
             }
@@ -269,7 +270,7 @@ public final class ExchangeSourceHandler {
 
             @Override
             protected void doRun() {
-                try (RefCountingListener refs = new RefCountingListener(sinkListener)) {
+                try (EsqlRefCountingListener refs = new EsqlRefCountingListener(sinkListener)) {
                     for (int i = 0; i < instances; i++) {
                         var fetcher = new RemoteSinkFetcher(remoteSink, failFast, refs.acquire());
                         fetcher.fetchPage();

+ 9 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.compute.operator;
 
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.cluster.node.DiscoveryNodeUtils;
 import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.breaker.CircuitBreaker;
@@ -86,6 +87,14 @@ public class FailureCollectorTests extends ESTestCase {
         assertNotNull(failure);
         assertThat(failure, Matchers.in(nonCancelledExceptions));
         assertThat(failure.getSuppressed().length, lessThan(maxExceptions));
+        assertTrue(
+            "cancellation exceptions must be ignored",
+            ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TaskCancelledException).isEmpty()
+        );
+        assertTrue(
+            "remote transport exception must be unwrapped",
+            ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TransportException).isEmpty()
+        );
     }
 
     public void testEmpty() {

+ 18 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.esql;
 import org.apache.lucene.document.InetAddressPoint;
 import org.apache.lucene.sandbox.document.HalfFloatPoint;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.NoopCircuitBreaker;
@@ -30,7 +31,9 @@ import org.elasticsearch.geo.GeometryTestUtils;
 import org.elasticsearch.geo.ShapeTestUtils;
 import org.elasticsearch.index.IndexMode;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.transport.RemoteTransportException;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.esql.action.EsqlQueryResponse;
 import org.elasticsearch.xpack.esql.analysis.EnrichResolution;
@@ -129,6 +132,8 @@ import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassificatio
 import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.PATTERN;
 import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.VALUE;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 
 public final class EsqlTestUtils {
 
@@ -784,4 +789,17 @@ public final class EsqlTestUtils {
     public static QueryParam paramAsPattern(String name, Object value) {
         return new QueryParam(name, value, NULL, PATTERN);
     }
+
+    /**
+     * Asserts that:
+     * 1. Cancellation exceptions are ignored when more relevant exceptions exist.
+     * 2. Transport exceptions are unwrapped, and the actual causes are reported to users.
+     */
+    public static void assertEsqlFailure(Exception e) {
+        assertNotNull(e);
+        var cancellationFailure = ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof TaskCancelledException).orElse(null);
+        assertNull("cancellation exceptions must be ignored", cancellationFailure);
+        ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof RemoteTransportException)
+            .ifPresent(transportFailure -> assertNull("remote transport exception must be unwrapped", transportFailure.getCause()));
+    }
 }

+ 1 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java

@@ -143,6 +143,7 @@ public class EnrichIT extends AbstractEsqlIntegTestCase {
                 return client.execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES);
             } catch (Exception e) {
                 logger.info("request failed", e);
+                EsqlTestUtils.assertEsqlFailure(e);
                 ensureBlocksReleased();
             } finally {
                 setRequestCircuitBreakerLimit(null);

+ 2 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java

@@ -23,6 +23,7 @@ import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.junit.annotations.TestLogging;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -85,6 +86,7 @@ public class EsqlActionBreakerIT extends EsqlActionIT {
         } catch (Exception e) {
             logger.info("request failed", e);
             ensureBlocksReleased();
+            EsqlTestUtils.assertEsqlFailure(e);
             throw e;
         } finally {
             setRequestCircuitBreakerLimit(null);

+ 11 - 1
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java

@@ -36,6 +36,7 @@ import org.elasticsearch.test.junit.annotations.TestLogging;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
 import org.junit.Before;
 
@@ -338,7 +339,15 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
          */
         assertThat(
             cancelException.getMessage(),
-            in(List.of("test cancel", "task cancelled", "request cancelled test cancel", "parent task was cancelled [test cancel]"))
+            in(
+                List.of(
+                    "test cancel",
+                    "task cancelled",
+                    "request cancelled test cancel",
+                    "parent task was cancelled [test cancel]",
+                    "cancelled on failure"
+                )
+            )
         );
         assertBusy(
             () -> assertThat(
@@ -434,6 +443,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
                 allowedFetching.countDown();
             }
             Exception failure = expectThrows(Exception.class, () -> future.actionGet().close());
+            EsqlTestUtils.assertEsqlFailure(failure);
             assertThat(failure.getMessage(), containsString("failed to fetch pages"));
             // If we proceed without waiting for pages, we might cancel the main request before starting the data-node request.
             // As a result, the exchange sinks on data-nodes won't be removed until the inactive_timeout elapses, which is

+ 2 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java

@@ -23,6 +23,7 @@ import org.elasticsearch.test.disruption.NetworkDisruption;
 import org.elasticsearch.test.disruption.ServiceDisruptionScheme;
 import org.elasticsearch.test.transport.MockTransportService;
 import org.elasticsearch.transport.TransportSettings;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -111,6 +112,7 @@ public class EsqlDisruptionIT extends EsqlActionIT {
             assertTrue("request must be failed or completed after clearing disruption", future.isDone());
             ensureBlocksReleased();
             logger.info("--> failed to execute esql query with disruption; retrying...", e);
+            EsqlTestUtils.assertEsqlFailure(e);
             return client().execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES);
         }
     }

+ 4 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java

@@ -9,8 +9,8 @@ package org.elasticsearch.xpack.esql.plugin;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.RefCountingListener;
+import org.elasticsearch.compute.EsqlRefCountingListener;
 import org.elasticsearch.compute.operator.DriverProfile;
-import org.elasticsearch.compute.operator.FailureCollector;
 import org.elasticsearch.compute.operator.ResponseHeadersCollector;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
@@ -39,8 +39,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 final class ComputeListener implements Releasable {
     private static final Logger LOGGER = LogManager.getLogger(ComputeService.class);
 
-    private final RefCountingListener refs;
-    private final FailureCollector failureCollector = new FailureCollector();
+    private final EsqlRefCountingListener refs;
     private final AtomicBoolean cancelled = new AtomicBoolean();
     private final CancellableTask task;
     private final TransportService transportService;
@@ -105,7 +104,7 @@ final class ComputeListener implements Releasable {
             : "clusterAlias and executionInfo must both be null or both non-null";
 
         // listener that executes after all the sub-listeners refs (created via acquireCompute) have completed
-        this.refs = new RefCountingListener(1, ActionListener.wrap(ignored -> {
+        this.refs = new EsqlRefCountingListener(delegate.delegateFailure((l, ignored) -> {
             responseHeaders.finish();
             ComputeResponse result;
 
@@ -131,7 +130,7 @@ final class ComputeListener implements Releasable {
                 }
             }
             delegate.onResponse(result);
-        }, e -> delegate.onFailure(failureCollector.getFailure())));
+        }));
     }
 
     private static void setFinalStatusAndShardCounts(String clusterAlias, EsqlExecutionInfo executionInfo) {
@@ -191,7 +190,6 @@ final class ComputeListener implements Releasable {
      */
     ActionListener<Void> acquireAvoid() {
         return refs.acquire().delegateResponse((l, e) -> {
-            failureCollector.unwrapAndCollect(e);
             try {
                 if (cancelled.compareAndSet(false, true)) {
                     LOGGER.debug("cancelling ESQL task {} on failure", task);

+ 3 - 3
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java

@@ -16,11 +16,11 @@ import org.elasticsearch.action.search.SearchShardsGroup;
 import org.elasticsearch.action.search.SearchShardsRequest;
 import org.elasticsearch.action.search.SearchShardsResponse;
 import org.elasticsearch.action.support.ChannelActionListener;
-import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.compute.EsqlRefCountingListener;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.Page;
 import org.elasticsearch.compute.operator.Driver;
@@ -375,7 +375,7 @@ public class ComputeService {
         var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink());
         // SearchShards API can_match is done in lookupDataNodes
         lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> {
-            try (RefCountingListener refs = new RefCountingListener(lookupListener)) {
+            try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) {
                 // update ExecutionInfo with shard counts (total and skipped)
                 executionInfo.swapCluster(
                     clusterAlias,
@@ -436,7 +436,7 @@ public class ComputeService {
     ) {
         var queryPragmas = configuration.pragmas();
         var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink());
-        try (RefCountingListener refs = new RefCountingListener(linkExchangeListeners)) {
+        try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) {
             for (RemoteCluster cluster : clusters) {
                 final var childSessionId = newChildSession(sessionId);
                 ExchangeService.openExchange(