Procházet zdrojové kódy

[Transform] Make transform `_preview` request cancellable (#91313)

Przemysław Witek před 2 roky
rodič
revize
a8a684ebab

+ 6 - 0
docs/changelog/91313.yaml

@@ -0,0 +1,6 @@
+pr: 91313
+summary: Make transform `_preview` request cancellable
+area: Transform
+type: bug
+issues:
+ - 91286

+ 9 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/action/PreviewTransformAction.java

@@ -17,6 +17,9 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -37,6 +40,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
 public class PreviewTransformAction extends ActionType<PreviewTransformAction.Response> {
@@ -135,6 +139,11 @@ public class PreviewTransformAction extends ActionType<PreviewTransformAction.Re
             Request other = (Request) obj;
             return Objects.equals(config, other.config);
         }
+
+        @Override
+        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+            return new CancellableTask(id, type, action, format("preview_transform[%s]", config.getId()), parentTaskId, headers);
+        }
     }
 
     public static class Response extends ActionResponse implements ToXContentObject {

+ 12 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/transform/action/PreviewTransformActionRequestTests.java

@@ -11,6 +11,9 @@ import org.elasticsearch.action.support.master.AcknowledgedRequest;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xcontent.DeprecationHandler;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -22,9 +25,11 @@ import org.elasticsearch.xpack.core.transform.transforms.TransformConfigTests;
 import org.elasticsearch.xpack.core.transform.transforms.pivot.PivotConfigTests;
 
 import java.io.IOException;
+import java.util.Map;
 
 import static org.elasticsearch.xpack.core.transform.transforms.SourceConfigTests.randomSourceConfig;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 
 public class PreviewTransformActionRequestTests extends AbstractSerializingTransformTestCase<Request> {
@@ -132,4 +137,11 @@ public class PreviewTransformActionRequestTests extends AbstractSerializingTrans
             assertThat(request.getConfig().getDestination().getPipeline(), is(equalTo(expectedDestPipeline)));
         }
     }
+
+    public void testCreateTask() {
+        Request request = createTestInstance();
+        Task task = request.createTask(123, "type", "action", TaskId.EMPTY_TASK_ID, Map.of());
+        assertThat(task, is(instanceOf(CancellableTask.class)));
+        assertThat(task.getDescription(), is(equalTo("preview_transform[transform-preview]")));
+    }
 }

+ 20 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/transform/preview_transforms.yml

@@ -156,6 +156,26 @@ setup:
   - match: { generated_dest_index.mappings.properties.by-hour.type: "date" }
   - match: { generated_dest_index.mappings.properties.avg_response.type: "double" }
 
+---
+"Test preview transform with timeout":
+  - do:
+      transform.preview_transform:
+        timeout: "10s"
+        body: >
+          {
+            "source": { "index": "airline-data" },
+            "pivot": {
+              "group_by": {
+                "airline": {"terms": {"field": "airline"}},
+                "by-hour": {"date_histogram": {"fixed_interval": "1h", "field": "time"}}},
+              "aggs": {
+                "avg_response": {"avg": {"field": "responsetime"}},
+                "time.max": {"max": {"field": "time"}},
+                "time.min": {"min": {"field": "time"}}
+              }
+            }
+          }
+
 ---
 "Test preview transform with disabled mapping deduction":
   - do:

+ 15 - 4
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportPreviewTransformAction.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.ingest.SimulatePipelineResponse;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.client.internal.ParentTaskAssigningClient;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -25,10 +26,12 @@ import org.elasticsearch.common.logging.HeaderWarning;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.ingest.IngestService;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.RemoteClusterLicenseChecker;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xcontent.ToXContent;
@@ -112,6 +115,7 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
 
     @Override
     protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
+        TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
         final ClusterState clusterState = clusterService.state();
         TransformNodes.throwIfNoTransformNodes(clusterState);
 
@@ -137,6 +141,8 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
             validateConfigResponse -> useSecondaryAuthIfAvailable(
                 securityContext,
                 () -> getPreview(
+                    parentTaskId,
+                    request.timeout(),
                     config.getId(), // note: @link{PreviewTransformAction} sets an id, so this is never null
                     function,
                     config.getSource(),
@@ -175,7 +181,7 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
                 securityContext,
                 indexNameExpressionResolver,
                 clusterState,
-                client,
+                new ParentTaskAssigningClient(client, parentTaskId),
                 config,
                 // We don't want to check privileges for a dummy (placeholder) index and the placeholder is inserted as config.dest.index
                 // early in the REST action so the only possibility we have here is string comparison.
@@ -189,6 +195,8 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
 
     @SuppressWarnings("unchecked")
     private void getPreview(
+        TaskId parentTaskId,
+        TimeValue timeout,
         String transformId,
         Function function,
         SourceConfig source,
@@ -197,6 +205,8 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
         SyncConfig syncConfig,
         ActionListener<Response> listener
     ) {
+        Client parentTaskAssigningClient = new ParentTaskAssigningClient(client, parentTaskId);
+
         final SetOnce<Map<String, String>> mappings = new SetOnce<>();
 
         ActionListener<SimulatePipelineResponse> pipelineResponseActionListener = ActionListener.wrap(simulatePipelineResponse -> {
@@ -256,7 +266,7 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
                     builder.endObject();
                     var pipelineRequest = new SimulatePipelineRequest(BytesReference.bytes(builder), XContentType.JSON);
                     pipelineRequest.setId(pipeline);
-                    client.execute(SimulatePipelineAction.INSTANCE, pipelineRequest, pipelineResponseActionListener);
+                    parentTaskAssigningClient.execute(SimulatePipelineAction.INSTANCE, pipelineRequest, pipelineResponseActionListener);
                 }
             }
         }, listener::onFailure);
@@ -264,7 +274,8 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
         ActionListener<Map<String, String>> deduceMappingsListener = ActionListener.wrap(deducedMappings -> {
             mappings.set(deducedMappings);
             function.preview(
-                client,
+                parentTaskAssigningClient,
+                timeout,
                 ClientHelper.getPersistableSafeSecurityHeaders(threadPool.getThreadContext(), clusterService.state()),
                 source,
                 deducedMappings,
@@ -273,6 +284,6 @@ public class TransportPreviewTransformAction extends HandledTransportAction<Requ
             );
         }, listener::onFailure);
 
-        function.deduceMappings(client, source, deduceMappingsListener);
+        function.deduceMappings(parentTaskAssigningClient, source, deduceMappingsListener);
     }
 }

+ 1 - 1
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/action/TransportValidateTransformAction.java

@@ -136,7 +136,7 @@ public class TransportValidateTransformAction extends HandledTransportAction<Req
             if (request.isDeferValidation()) {
                 validateQueryListener.onResponse(true);
             } else {
-                function.validateQuery(client, config.getSource(), validateQueryListener);
+                function.validateQuery(client, config.getSource(), request.timeout(), validateQueryListener);
             }
         }, listener::onFailure);
 

+ 4 - 1
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/rest/action/RestPreviewTransformAction.java

@@ -10,11 +10,13 @@ package org.elasticsearch.xpack.transform.rest.action;
 import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
+import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.BaseRestHandler;
 import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestCancellableNodeClient;
 import org.elasticsearch.rest.action.RestToXContentListener;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.transform.TransformField;
@@ -47,7 +49,7 @@ public class RestPreviewTransformAction extends BaseRestHandler {
     }
 
     @Override
-    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) throws IOException {
         String transformId = restRequest.param(TransformField.ID.getPreferredName());
 
         if (Strings.isNullOrEmpty(transformId) && restRequest.hasContentOrSourceParam() == false) {
@@ -72,6 +74,7 @@ public class RestPreviewTransformAction extends BaseRestHandler {
             previewRequestHolder.set(PreviewTransformAction.Request.fromXContent(restRequest.contentOrSourceParamParser(), timeout));
         }
 
+        Client client = new RestCancellableNodeClient(nodeClient, restRequest.getHttpChannel());
         return channel -> {
             RestToXContentListener<PreviewTransformAction.Response> listener = new RestToXContentListener<>(channel);
 

+ 6 - 1
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/Function.java

@@ -11,6 +11,8 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -124,6 +126,7 @@ public interface Function {
      * Create a preview of the function.
      *
      * @param client a client instance for querying
+     * @param timeout search query timeout
      * @param headers headers to be used to query only for what the caller is allowed to
      * @param sourceConfig the source configuration
      * @param fieldTypeMap mapping of field types
@@ -132,6 +135,7 @@ public interface Function {
      */
     void preview(
         Client client,
+        @Nullable TimeValue timeout,
         Map<String, String> headers,
         SourceConfig sourceConfig,
         Map<String, String> fieldTypeMap,
@@ -175,9 +179,10 @@ public interface Function {
      *
      * @param client a client instance for querying the source
      * @param sourceConfig the source configuration
+     * @param timeout search query timeout
      * @param listener the result listener
      */
-    void validateQuery(Client client, SourceConfig sourceConfig, ActionListener<Boolean> listener);
+    void validateQuery(Client client, SourceConfig sourceConfig, @Nullable TimeValue timeout, ActionListener<Boolean> listener);
 
     /**
      * Create a change collector instance and return it

+ 8 - 5
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java

@@ -17,6 +17,7 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.aggregations.Aggregations;
@@ -63,6 +64,7 @@ public abstract class AbstractCompositeAggFunction implements Function {
     @Override
     public void preview(
         Client client,
+        TimeValue timeout,
         Map<String, String> headers,
         SourceConfig sourceConfig,
         Map<String, String> fieldTypeMap,
@@ -75,7 +77,7 @@ public abstract class AbstractCompositeAggFunction implements Function {
             ClientHelper.TRANSFORM_ORIGIN,
             client,
             SearchAction.INSTANCE,
-            buildSearchRequest(sourceConfig, null, numberOfBuckets),
+            buildSearchRequest(sourceConfig, timeout, numberOfBuckets),
             ActionListener.wrap(r -> {
                 try {
                     final Aggregations aggregations = r.getAggregations();
@@ -102,8 +104,8 @@ public abstract class AbstractCompositeAggFunction implements Function {
     }
 
     @Override
-    public void validateQuery(Client client, SourceConfig sourceConfig, ActionListener<Boolean> listener) {
-        SearchRequest searchRequest = buildSearchRequest(sourceConfig, null, TEST_QUERY_PAGE_SIZE);
+    public void validateQuery(Client client, SourceConfig sourceConfig, TimeValue timeout, ActionListener<Boolean> listener) {
+        SearchRequest searchRequest = buildSearchRequest(sourceConfig, timeout, TEST_QUERY_PAGE_SIZE);
         client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(response -> {
             if (response == null) {
                 listener.onFailure(new ValidationException().addValidationError("Unexpected null response from test query"));
@@ -173,9 +175,10 @@ public abstract class AbstractCompositeAggFunction implements Function {
         TransformProgress progress
     );
 
-    private SearchRequest buildSearchRequest(SourceConfig sourceConfig, Map<String, Object> position, int pageSize) {
+    private SearchRequest buildSearchRequest(SourceConfig sourceConfig, TimeValue timeout, int pageSize) {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(sourceConfig.getQueryConfig().getQuery())
-            .runtimeMappings(sourceConfig.getRuntimeMappings());
+            .runtimeMappings(sourceConfig.getRuntimeMappings())
+            .timeout(timeout);
         buildSearchQuery(sourceBuilder, null, pageSize);
         return new SearchRequest(sourceConfig.getIndex()).source(sourceBuilder).indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
     }

+ 1 - 1
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java

@@ -447,7 +447,7 @@ public class PivotTests extends ESTestCase {
     private static void validate(Client client, SourceConfig source, Function pivot, boolean expectValid) throws Exception {
         CountDownLatch latch = new CountDownLatch(1);
         final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-        pivot.validateQuery(client, source, ActionListener.wrap(validity -> {
+        pivot.validateQuery(client, source, null, ActionListener.wrap(validity -> {
             assertEquals(expectValid, validity);
             latch.countDown();
         }, e -> {