ソースを参照

Driver should wait for async actions (#102508)

Today, the driver may close an operator while some of its async actions 
are still running. In fact, a Driver doesn't wait for async actions,
such as the enrich lookup action, to complete before finishing itself.
This change enables the tracking of async actions in the DriverContext
so that the Driver can register a listener to wait for async actions
before completing. Another change in this pull request is to discard the
async result of an AsyncOperator after it's closed.


Closes #102264
Closes #102459
Nhat Nguyen 1 年間 前
コミット
a419993068

+ 24 - 13
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AsyncOperator.java

@@ -31,9 +31,11 @@ public abstract class AsyncOperator implements Operator {
 
     private final Map<Long, Page> buffers = ConcurrentCollections.newConcurrentMap();
     private final AtomicReference<Exception> failure = new AtomicReference<>();
+    private final DriverContext driverContext;
 
     private final int maxOutstandingRequests;
     private boolean finished = false;
+    private volatile boolean closed = false;
 
     /*
      * The checkpoint tracker is used to maintain the order of emitted pages after passing through this async operator.
@@ -51,7 +53,8 @@ public abstract class AsyncOperator implements Operator {
      *
      * @param maxOutstandingRequests the maximum number of outstanding requests
      */
-    public AsyncOperator(int maxOutstandingRequests) {
+    public AsyncOperator(DriverContext driverContext, int maxOutstandingRequests) {
+        this.driverContext = driverContext;
         this.maxOutstandingRequests = maxOutstandingRequests;
     }
 
@@ -68,14 +71,24 @@ public abstract class AsyncOperator implements Operator {
             return;
         }
         final long seqNo = checkpoint.generateSeqNo();
-        performAsync(input, ActionListener.wrap(output -> {
-            buffers.put(seqNo, output);
-            onSeqNoCompleted(seqNo);
-        }, e -> {
-            input.releaseBlocks();
-            onFailure(e);
-            onSeqNoCompleted(seqNo);
-        }));
+        driverContext.addAsyncAction();
+        boolean success = false;
+        try {
+            final ActionListener<Page> listener = ActionListener.wrap(output -> {
+                buffers.put(seqNo, output);
+                onSeqNoCompleted(seqNo);
+            }, e -> {
+                input.releaseBlocks();
+                onFailure(e);
+                onSeqNoCompleted(seqNo);
+            });
+            performAsync(input, ActionListener.runAfter(listener, driverContext::removeAsyncAction));
+            success = true;
+        } finally {
+            if (success == false) {
+                driverContext.removeAsyncAction();
+            }
+        }
     }
 
     /**
@@ -112,7 +125,7 @@ public abstract class AsyncOperator implements Operator {
         if (checkpoint.getPersistedCheckpoint() < checkpoint.getProcessedCheckpoint()) {
             notifyIfBlocked();
         }
-        if (failure.get() != null) {
+        if (closed || failure.get() != null) {
             discardPages();
         }
     }
@@ -152,6 +165,7 @@ public abstract class AsyncOperator implements Operator {
     @Override
     public final void close() {
         finish();
+        closed = true;
         discardPages();
         doClose();
     }
@@ -159,9 +173,6 @@ public abstract class AsyncOperator implements Operator {
     @Override
     public void finish() {
         finished = true;
-        if (failure.get() != null) {
-            discardPages();
-        }
     }
 
     @Override

+ 33 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

@@ -9,6 +9,7 @@ package org.elasticsearch.compute.operator;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
+import org.elasticsearch.action.support.RefCountingListener;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
@@ -59,7 +60,7 @@ public class Driver implements Releasable, Describable {
     private final AtomicReference<SubscribableListener<Void>> blocked = new AtomicReference<>();
 
     private final AtomicBoolean started = new AtomicBoolean();
-    private final SubscribableListener<Void> completionListener = new SubscribableListener<>();
+    private final CompletionListener completionListener;
 
     /**
      * Status reported to the tasks API. We write the status at most once every
@@ -97,6 +98,7 @@ public class Driver implements Releasable, Describable {
         this.activeOperators.add(sink);
         this.statusNanos = statusInterval.nanos();
         this.releasable = releasable;
+        this.completionListener = new CompletionListener(driverContext);
         this.status = new AtomicReference<>(new DriverStatus(sessionId, System.currentTimeMillis(), DriverStatus.Status.QUEUED, List.of()));
     }
 
@@ -393,4 +395,34 @@ public class Driver implements Releasable, Describable {
             activeOperators.stream().map(o -> new DriverStatus.OperatorStatus(o.toString(), o.status())).toList()
         );
     }
+
+    /**
+     * A listener that is notified when both the Driver and its DriverContext are completed.
+     */
+    private static class CompletionListener implements ActionListener<Void> {
+        private final SubscribableListener<Void> completionListener;
+        private final ActionListener<Void> driverListener;
+
+        CompletionListener(DriverContext driverContext) {
+            this.completionListener = new SubscribableListener<>();
+            try (var refs = new RefCountingListener(1, completionListener)) {
+                driverListener = refs.acquire();
+                driverContext.waitForAsyncActions(refs.acquire());
+            }
+        }
+
+        void addListener(ActionListener<Void> listener) {
+            completionListener.addListener(listener);
+        }
+
+        @Override
+        public void onResponse(Void unused) {
+            driverListener.onResponse(null);
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            driverListener.onFailure(e);
+        }
+    }
 }

+ 53 - 1
x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.compute.operator;
 
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.compute.data.BlockFactory;
@@ -17,6 +19,8 @@ import java.util.Collections;
 import java.util.IdentityHashMap;
 import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 /**
@@ -34,7 +38,11 @@ import java.util.concurrent.atomic.AtomicReference;
  * This allows to "transfer ownership" of a shared resource across operators (and even across
  * Drivers), while ensuring that the resource can be correctly released when no longer needed.
  *
- * Currently only supports releasables, but additional driver-local context can be added.
+ * DriverContext can also be used to track async actions. The driver may close an operator while
+ * some of its async actions are still running. To prevent the driver from finishing in this case,
+ * methods {@link #addAsyncAction()} and {@link #removeAsyncAction()} are provided for tracking
+ * such actions. Subsequently, the driver uses {@link #waitForAsyncActions(ActionListener)} to
+ * await the completion of all async actions before finalizing the Driver.
  */
 public class DriverContext {
 
@@ -47,6 +55,8 @@ public class DriverContext {
 
     private final BlockFactory blockFactory;
 
+    private final AsyncActions asyncActions = new AsyncActions();
+
     public DriverContext(BigArrays bigArrays, BlockFactory blockFactory) {
         Objects.requireNonNull(bigArrays);
         Objects.requireNonNull(blockFactory);
@@ -119,6 +129,7 @@ public class DriverContext {
         }
         // must be called by the thread executing the driver.
         // no more updates to this context.
+        asyncActions.finish();
         var itr = workingSet.iterator();
         workingSet = null;
         Set<Releasable> releasableSet = Collections.newSetFromMap(new IdentityHashMap<>());
@@ -135,4 +146,45 @@ public class DriverContext {
             throw new IllegalStateException("not finished");
         }
     }
+
+    public void waitForAsyncActions(ActionListener<Void> listener) {
+        asyncActions.addListener(listener);
+    }
+
+    public void addAsyncAction() {
+        asyncActions.addInstance();
+    }
+
+    public void removeAsyncAction() {
+        asyncActions.removeInstance();
+    }
+
+    private static class AsyncActions {
+        private final SubscribableListener<Void> completion = new SubscribableListener<>();
+        private final AtomicBoolean finished = new AtomicBoolean();
+        private final AtomicInteger instances = new AtomicInteger(1);
+
+        void addInstance() {
+            if (finished.get()) {
+                throw new IllegalStateException("DriverContext was finished already");
+            }
+            instances.incrementAndGet();
+        }
+
+        void removeInstance() {
+            if (instances.decrementAndGet() == 0) {
+                completion.onResponse(null);
+            }
+        }
+
+        void addListener(ActionListener<Void> listener) {
+            completion.addListener(listener);
+        }
+
+        void finish() {
+            if (finished.compareAndSet(false, true)) {
+                removeInstance();
+            }
+        }
+    }
 }

+ 14 - 6
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AsyncOperatorTests.java

@@ -97,7 +97,7 @@ public class AsyncOperatorTests extends ESTestCase {
             }
         };
         int maxConcurrentRequests = randomIntBetween(1, 10);
-        AsyncOperator asyncOperator = new AsyncOperator(maxConcurrentRequests) {
+        AsyncOperator asyncOperator = new AsyncOperator(driverContext, maxConcurrentRequests) {
             final LookupService lookupService = new LookupService(threadPool, driverContext.blockFactory(), dict, maxConcurrentRequests);
 
             @Override
@@ -110,7 +110,16 @@ public class AsyncOperatorTests extends ESTestCase {
 
             }
         };
-        Iterator<Long> it = ids.iterator();
+        List<Operator> intermediateOperators = new ArrayList<>();
+        intermediateOperators.add(asyncOperator);
+        final Iterator<Long> it;
+        if (randomBoolean()) {
+            int limit = between(1, ids.size());
+            it = ids.subList(0, limit).iterator();
+            intermediateOperators.add(new LimitOperator(limit));
+        } else {
+            it = ids.iterator();
+        }
         SinkOperator outputOperator = new PageConsumerOperator(page -> {
             try (Releasable ignored = page::releaseBlocks) {
                 assertThat(page.getBlockCount(), equalTo(2));
@@ -131,7 +140,7 @@ public class AsyncOperatorTests extends ESTestCase {
             }
         });
         PlainActionFuture<Void> future = new PlainActionFuture<>();
-        Driver driver = new Driver(driverContext, sourceOperator, List.of(asyncOperator), outputOperator, () -> assertFalse(it.hasNext()));
+        Driver driver = new Driver(driverContext, sourceOperator, intermediateOperators, outputOperator, () -> assertFalse(it.hasNext()));
         Driver.start(threadPool.getThreadContext(), threadPool.executor(ESQL_TEST_EXECUTOR), driver, between(1, 10000), future);
         future.actionGet();
     }
@@ -139,7 +148,7 @@ public class AsyncOperatorTests extends ESTestCase {
     public void testStatus() {
         DriverContext driverContext = driverContext();
         Map<Page, ActionListener<Page>> handlers = new HashMap<>();
-        AsyncOperator operator = new AsyncOperator(2) {
+        AsyncOperator operator = new AsyncOperator(driverContext, 2) {
             @Override
             protected void performAsync(Page inputPage, ActionListener<Page> listener) {
                 handlers.put(inputPage, listener);
@@ -185,7 +194,6 @@ public class AsyncOperatorTests extends ESTestCase {
         operator.close();
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/102264")
     public void testFailure() throws Exception {
         DriverContext driverContext = driverContext();
         final SequenceLongBlockSourceOperator sourceOperator = new SequenceLongBlockSourceOperator(
@@ -194,7 +202,7 @@ public class AsyncOperatorTests extends ESTestCase {
         );
         int maxConcurrentRequests = randomIntBetween(1, 10);
         AtomicBoolean failed = new AtomicBoolean();
-        AsyncOperator asyncOperator = new AsyncOperator(maxConcurrentRequests) {
+        AsyncOperator asyncOperator = new AsyncOperator(driverContext, maxConcurrentRequests) {
             @Override
             protected void performAsync(Page inputPage, ActionListener<Page> listener) {
                 ActionRunnable<Page> command = new ActionRunnable<>(listener) {

+ 20 - 0
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverContextTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.compute.operator;
 
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.MockBigArrays;
@@ -14,6 +15,7 @@ import org.elasticsearch.common.util.PageCacheRecycler;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.core.Releasable;
+import org.elasticsearch.core.Releasables;
 import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.FixedExecutorBuilder;
@@ -136,6 +138,24 @@ public class DriverContextTests extends ESTestCase {
         finishedReleasables.stream().flatMap(Set::stream).forEach(Releasable::close);
     }
 
+    public void testWaitForAsyncActions() {
+        DriverContext driverContext = new AssertingDriverContext();
+        driverContext.addAsyncAction();
+        driverContext.addAsyncAction();
+        PlainActionFuture<Void> future = new PlainActionFuture<>();
+        driverContext.waitForAsyncActions(future);
+        assertFalse(future.isDone());
+        driverContext.finish();
+        assertFalse(future.isDone());
+        IllegalStateException error = expectThrows(IllegalStateException.class, driverContext::addAsyncAction);
+        assertThat(error.getMessage(), equalTo("DriverContext was finished already"));
+        driverContext.removeAsyncAction();
+        assertFalse(future.isDone());
+        driverContext.removeAsyncAction();
+        assertTrue(future.isDone());
+        Releasables.closeExpectNoException(driverContext.getSnapshot());
+    }
+
     static TestDriver newTestDriver(int unused) {
         var driverContext = new AssertingDriverContext();
         return new TestDriver(driverContext, randomInt(128), driverContext.bigArrays());

+ 3 - 3
x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/DriverTests.java

@@ -53,7 +53,7 @@ public class DriverTests extends ESTestCase {
                     assertRunningWithRegularUser(threadPool);
                     return super.getOutput();
                 }
-            }, List.of(warning1, new SwitchContextOperator(threadPool), warning2), new PageConsumerOperator(page -> {
+            }, List.of(warning1, new SwitchContextOperator(driverContext, threadPool), warning2), new PageConsumerOperator(page -> {
                 assertRunningWithRegularUser(threadPool);
                 outPages.add(page);
             }), () -> {});
@@ -106,8 +106,8 @@ public class DriverTests extends ESTestCase {
     static class SwitchContextOperator extends AsyncOperator {
         private final ThreadPool threadPool;
 
-        SwitchContextOperator(ThreadPool threadPool) {
-            super(between(1, 3));
+        SwitchContextOperator(DriverContext driverContext, ThreadPool threadPool) {
+            super(driverContext, between(1, 3));
             this.threadPool = threadPool;
         }
 

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupOperator.java

@@ -56,6 +56,7 @@ public final class EnrichLookupOperator extends AsyncOperator {
         public Operator get(DriverContext driverContext) {
             return new EnrichLookupOperator(
                 sessionId,
+                driverContext,
                 parentTask,
                 maxOutstandingRequests,
                 inputChannel,
@@ -70,6 +71,7 @@ public final class EnrichLookupOperator extends AsyncOperator {
 
     public EnrichLookupOperator(
         String sessionId,
+        DriverContext driverContext,
         CancellableTask parentTask,
         int maxOutstandingRequests,
         int inputChannel,
@@ -79,7 +81,7 @@ public final class EnrichLookupOperator extends AsyncOperator {
         String matchField,
         List<NamedExpression> enrichFields
     ) {
-        super(maxOutstandingRequests);
+        super(driverContext, maxOutstandingRequests);
         this.sessionId = sessionId;
         this.parentTask = parentTask;
         this.inputChannel = inputChannel;