Browse Source

Fix expiration time in ES|QL async (#135209) (#135241)

Currently, we incorrectly use the initial keep-alive value to compute the
expiration time when storing the document for the async response. Instead, we
should use the latest expiration time from the search task, which get requests
can update while the query is still running.

Closes #135169
Nhat Nguyen 2 weeks ago
parent
commit
5041157c0a

+ 6 - 0
docs/changelog/135209.yaml

@@ -0,0 +1,6 @@
+pr: 135209
+summary: Fix expiration time in ES|QL async
+area: ES|QL
+type: bug
+issues:
+ - 135169

+ 5 - 7
x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementService.java

@@ -177,7 +177,6 @@ public class AsyncTaskManagementService<
     public void asyncExecute(
         Request request,
         TimeValue waitForCompletionTimeout,
-        TimeValue keepAlive,
         boolean keepOnCompletion,
         ActionListener<Response> listener
     ) {
@@ -190,7 +189,7 @@ public class AsyncTaskManagementService<
                 operation.execute(
                     request,
                     searchTask,
-                    wrapStoringListener(searchTask, waitForCompletionTimeout, keepAlive, keepOnCompletion, listener)
+                    wrapStoringListener(searchTask, waitForCompletionTimeout, keepOnCompletion, listener)
                 );
                 operationStarted = true;
             } finally {
@@ -205,7 +204,6 @@ public class AsyncTaskManagementService<
     private ActionListener<Response> wrapStoringListener(
         T searchTask,
         TimeValue waitForCompletionTimeout,
-        TimeValue keepAlive,
         boolean keepOnCompletion,
         ActionListener<Response> listener
     ) {
@@ -227,7 +225,7 @@ public class AsyncTaskManagementService<
                 if (keepOnCompletion) {
                     storeResults(
                         searchTask,
-                        new StoredAsyncResponse<>(response, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
+                        new StoredAsyncResponse<>(response, searchTask.getExpirationTimeMillis()),
                         ActionListener.running(() -> acquiredListener.onResponse(response))
                     );
                 } else {
@@ -239,7 +237,7 @@ public class AsyncTaskManagementService<
                 // We finished after timeout - saving results
                 storeResults(
                     searchTask,
-                    new StoredAsyncResponse<>(response, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
+                    new StoredAsyncResponse<>(response, searchTask.getExpirationTimeMillis()),
                     ActionListener.running(response::decRef)
                 );
             }
@@ -251,7 +249,7 @@ public class AsyncTaskManagementService<
                 if (keepOnCompletion) {
                     storeResults(
                         searchTask,
-                        new StoredAsyncResponse<>(e, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()),
+                        new StoredAsyncResponse<>(e, searchTask.getExpirationTimeMillis()),
                         ActionListener.running(() -> acquiredListener.onFailure(e))
                     );
                 } else {
@@ -261,7 +259,7 @@ public class AsyncTaskManagementService<
                 }
             } else {
                 // We finished after timeout - saving exception
-                storeResults(searchTask, new StoredAsyncResponse<>(e, threadPool.absoluteTimeInMillis() + keepAlive.getMillis()));
+                storeResults(searchTask, new StoredAsyncResponse<>(e, searchTask.getExpirationTimeMillis()));
             }
         });
     }

+ 63 - 57
x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementServiceTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.LegacyActionRequest;
 import org.elasticsearch.action.support.ActionTestUtils;
+import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -40,6 +41,7 @@ import java.util.concurrent.atomic.AtomicReference;
 
 import static org.elasticsearch.xpack.esql.core.async.AsyncTaskManagementService.addCompletionListener;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
 
@@ -52,9 +54,11 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
 
     public static class TestRequest extends LegacyActionRequest {
         private final String string;
+        private final TimeValue keepAlive;
 
-        public TestRequest(String string) {
+        public TestRequest(String string, TimeValue keepAlive) {
             this.string = string;
+            this.keepAlive = keepAlive;
         }
 
         @Override
@@ -129,7 +133,7 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
                 headers,
                 originHeaders,
                 asyncExecutionId,
-                TimeValue.timeValueDays(5)
+                request.keepAlive
             );
         }
 
@@ -172,7 +176,7 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
         );
         results = new AsyncResultsService<>(
             store,
-            true,
+            false,
             TestTask.class,
             (task, listener, timeout) -> addCompletionListener(transportService.getThreadPool(), task, listener, timeout),
             transportService.getTaskManager(),
@@ -212,23 +216,17 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
         boolean success = randomBoolean();
         boolean keepOnCompletion = randomBoolean();
         CountDownLatch latch = new CountDownLatch(1);
-        TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die");
-        service.asyncExecute(
-            request,
-            TimeValue.timeValueMinutes(1),
-            TimeValue.timeValueMinutes(10),
-            keepOnCompletion,
-            ActionListener.wrap(r -> {
-                assertThat(success, equalTo(true));
-                assertThat(r.string, equalTo("response for [" + request.string + "]"));
-                assertThat(r.id, notNullValue());
-                latch.countDown();
-            }, e -> {
-                assertThat(success, equalTo(false));
-                assertThat(e.getMessage(), equalTo("test exception"));
-                latch.countDown();
-            })
-        );
+        TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die", TimeValue.timeValueDays(1));
+        service.asyncExecute(request, TimeValue.timeValueMinutes(1), keepOnCompletion, ActionListener.wrap(r -> {
+            assertThat(success, equalTo(true));
+            assertThat(r.string, equalTo("response for [" + request.string + "]"));
+            assertThat(r.id, notNullValue());
+            latch.countDown();
+        }, e -> {
+            assertThat(success, equalTo(false));
+            assertThat(e.getMessage(), equalTo("test exception"));
+            latch.countDown();
+        }));
         assertThat(latch.await(10, TimeUnit.SECONDS), equalTo(true));
     }
 
@@ -252,20 +250,14 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
         boolean timeoutOnFirstAttempt = randomBoolean();
         boolean waitForCompletion = randomBoolean();
         CountDownLatch latch = new CountDownLatch(1);
-        TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die");
+        TestRequest request = new TestRequest(success ? randomAlphaOfLength(10) : "die", TimeValue.timeValueDays(1));
         AtomicReference<TestResponse> responseHolder = new AtomicReference<>();
-        service.asyncExecute(
-            request,
-            TimeValue.timeValueMillis(1),
-            TimeValue.timeValueMinutes(10),
-            keepOnCompletion,
-            ActionTestUtils.assertNoFailureListener(r -> {
-                assertThat(r.string, nullValue());
-                assertThat(r.id, notNullValue());
-                assertThat(responseHolder.getAndSet(r), nullValue());
-                latch.countDown();
-            })
-        );
+        service.asyncExecute(request, TimeValue.timeValueMillis(1), keepOnCompletion, ActionTestUtils.assertNoFailureListener(r -> {
+            assertThat(r.string, nullValue());
+            assertThat(r.id, notNullValue());
+            assertThat(responseHolder.getAndSet(r), nullValue());
+            latch.countDown();
+        }));
         assertThat(latch.await(20, TimeUnit.SECONDS), equalTo(true));
 
         if (timeoutOnFirstAttempt) {
@@ -281,17 +273,11 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
         if (waitForCompletion) {
             // now we are waiting for the task to finish
             logger.trace("Waiting for response to complete");
-            AtomicReference<StoredAsyncResponse<TestResponse>> responseRef = new AtomicReference<>();
-            CountDownLatch getResponseCountDown = getResponse(
-                responseHolder.get().id,
-                TimeValue.timeValueSeconds(5),
-                ActionTestUtils.assertNoFailureListener(responseRef::set)
-            );
+            var getFuture = getResponse(responseHolder.get().id, TimeValue.timeValueSeconds(5), TimeValue.MINUS_ONE);
 
             executionLatch.countDown();
-            assertThat(getResponseCountDown.await(10, TimeUnit.SECONDS), equalTo(true));
+            var response = safeGet(getFuture);
 
-            StoredAsyncResponse<TestResponse> response = responseRef.get();
             if (success) {
                 assertThat(response.getException(), nullValue());
                 assertThat(response.getResponse(), notNullValue());
@@ -326,26 +312,46 @@ public class AsyncTaskManagementServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    public void testUpdateKeepAliveToTask() throws Exception {
+        long now = System.currentTimeMillis();
+        CountDownLatch executionLatch = new CountDownLatch(1);
+        AsyncTaskManagementService<TestRequest, TestResponse, TestTask> service = createManagementService(new TestOperation() {
+            @Override
+            public void execute(TestRequest request, TestTask task, ActionListener<TestResponse> listener) {
+                executorService.submit(() -> {
+                    try {
+                        assertThat(executionLatch.await(10, TimeUnit.SECONDS), equalTo(true));
+                    } catch (InterruptedException ex) {
+                        throw new AssertionError(ex);
+                    }
+                    super.execute(request, task, listener);
+                });
+            }
+        });
+        TestRequest request = new TestRequest(randomAlphaOfLength(10), TimeValue.timeValueHours(1));
+        PlainActionFuture<TestResponse> submitResp = new PlainActionFuture<>();
+        try {
+            service.asyncExecute(request, TimeValue.timeValueMillis(1), true, submitResp);
+            String id = submitResp.get().id;
+            assertThat(id, notNullValue());
+            TimeValue keepAlive = TimeValue.timeValueDays(between(1, 10));
+            var resp1 = safeGet(getResponse(id, TimeValue.ZERO, keepAlive));
+            assertThat(resp1.getExpirationTime(), greaterThanOrEqualTo(now + keepAlive.millis()));
+        } finally {
+            executionLatch.countDown();
+        }
+    }
+
     private StoredAsyncResponse<TestResponse> getResponse(String id, TimeValue timeout) throws InterruptedException {
-        AtomicReference<StoredAsyncResponse<TestResponse>> response = new AtomicReference<>();
-        assertThat(
-            getResponse(id, timeout, ActionTestUtils.assertNoFailureListener(response::set)).await(10, TimeUnit.SECONDS),
-            equalTo(true)
-        );
-        return response.get();
+        return safeGet(getResponse(id, timeout, TimeValue.MINUS_ONE));
     }
 
-    private CountDownLatch getResponse(String id, TimeValue timeout, ActionListener<StoredAsyncResponse<TestResponse>> listener) {
-        CountDownLatch responseLatch = new CountDownLatch(1);
+    private PlainActionFuture<StoredAsyncResponse<TestResponse>> getResponse(String id, TimeValue timeout, TimeValue keepAlive) {
+        PlainActionFuture<StoredAsyncResponse<TestResponse>> future = new PlainActionFuture<>();
         GetAsyncResultRequest getResultsRequest = new GetAsyncResultRequest(id).setWaitForCompletionTimeout(timeout);
-        results.retrieveResult(getResultsRequest, ActionListener.wrap(r -> {
-            listener.onResponse(r);
-            responseLatch.countDown();
-        }, e -> {
-            listener.onFailure(e);
-            responseLatch.countDown();
-        }));
-        return responseLatch;
+        getResultsRequest.setKeepAlive(keepAlive);
+        results.retrieveResult(getResultsRequest, future);
+        return future;
     }
 
 }

+ 1 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java

@@ -312,6 +312,7 @@ public class AsyncEsqlQueryActionIT extends AbstractPausableIntegTestCase {
                 assertThat(resp.isRunning(), is(false));
             }
         });
+        assertThat(getExpirationFromDoc(asyncId), greaterThanOrEqualTo(nowInMillis + keepAlive.getMillis()));
         // update the keepAlive after the query has completed
         int iters = between(1, 5);
         for (int i = 0; i < iters; i++) {

+ 1 - 7
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java

@@ -191,13 +191,7 @@ public class TransportEsqlQueryAction extends HandledTransportAction<EsqlQueryRe
     private void doExecuteForked(Task task, EsqlQueryRequest request, ActionListener<EsqlQueryResponse> listener) {
         assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH);
         if (requestIsAsync(request)) {
-            asyncTaskManagementService.asyncExecute(
-                request,
-                request.waitForCompletionTimeout(),
-                request.keepAlive(),
-                request.keepOnCompletion(),
-                listener
-            );
+            asyncTaskManagementService.asyncExecute(request, request.waitForCompletionTimeout(), request.keepOnCompletion(), listener);
         } else {
             innerExecute(task, request, listener);
         }