Ver código fonte

[ML] Preserve response headers in Datafeed preview (#103923)

Fixes a bug where the datafeed preview API would lose the 
`X-elastic-product` response header if security was disable.
David Kyle 1 ano atrás
pai
commit
8efc72c5d7

+ 5 - 0
docs/changelog/103923.yaml

@@ -0,0 +1,5 @@
+pr: 103923
+summary: Preserve response headers in Datafeed preview
+area: Machine Learning
+type: bug
+issues: []

+ 11 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedJobsRestIT.java

@@ -681,7 +681,6 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
         options.addHeader("Authorization", BASIC_AUTH_VALUE_ML_ADMIN);
         getFeed.setOptions(options);
         ResponseException e = expectThrows(ResponseException.class, () -> client().performRequest(getFeed));
-
         assertThat(e.getMessage(), containsString("[indices:data/read/field_caps] is unauthorized for user [ml_admin]"));
     }
 
@@ -722,7 +721,12 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
         options.addHeader("es-secondary-authorization", BASIC_AUTH_VALUE_ML_ADMIN_WITH_SOME_DATA_ACCESS);
         getFeed.setOptions(options);
         // Should not fail as secondary auth has permissions.
-        client().performRequest(getFeed);
+        var response = client().performRequest(getFeed);
+        assertXProductResponseHeader(response);
+    }
+
+    private void assertXProductResponseHeader(Response response) {
+        assertEquals("Elasticsearch", response.getHeader("X-elastic-product"));
     }
 
     public void testLookbackOnlyGivenAggregationsWithHistogram() throws Exception {
@@ -1518,6 +1522,7 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
 
         public void execute() throws Exception {
             Response jobResponse = createJob(jobId, airlineVariant);
+            assertXProductResponseHeader(jobResponse);
             assertThat(jobResponse.getStatusLine().getStatusCode(), equalTo(200));
             String datafeedId = "datafeed-" + jobId;
             new DatafeedBuilder(datafeedId, jobId, dataIndex).setScriptedFields(scriptedFields).build();
@@ -1529,6 +1534,7 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
             Response jobStatsResponse = client().performRequest(
                 new Request("GET", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats")
             );
+            assertXProductResponseHeader(jobStatsResponse);
             String jobStatsResponseAsString = EntityUtils.toString(jobStatsResponse.getEntity());
             if (shouldSucceedInput) {
                 assertThat(jobStatsResponseAsString, containsString("\"input_record_count\":2"));
@@ -1556,12 +1562,14 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
         options.addHeader("Authorization", authHeader);
         request.setOptions(options);
         Response startDatafeedResponse = client().performRequest(request);
+        assertXProductResponseHeader(startDatafeedResponse);
         assertThat(EntityUtils.toString(startDatafeedResponse.getEntity()), containsString("\"started\":true"));
         assertBusy(() -> {
             try {
                 Response datafeedStatsResponse = client().performRequest(
                     new Request("GET", MachineLearning.BASE_PATH + "datafeeds/" + datafeedId + "/_stats")
                 );
+                assertXProductResponseHeader(datafeedStatsResponse);
                 assertThat(EntityUtils.toString(datafeedStatsResponse.getEntity()), containsString("\"state\":\"stopped\""));
             } catch (Exception e) {
                 throw new RuntimeException(e);
@@ -1575,6 +1583,7 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
                 Response jobStatsResponse = client().performRequest(
                     new Request("GET", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats")
                 );
+                assertXProductResponseHeader(jobStatsResponse);
                 assertThat(EntityUtils.toString(jobStatsResponse.getEntity()), containsString("\"state\":\"closed\""));
             } catch (Exception e) {
                 throw new RuntimeException(e);

+ 163 - 0
x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DatafeedWithoutSecurityRestIT.java

@@ -0,0 +1,163 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.not;
+
+public class DatafeedWithoutSecurityRestIT extends ESRestTestCase {
+
+    @Before
+    public void setUpData() throws Exception {
+        addAirlineData();
+    }
+
+    /**
+     * The main purpose this test is to ensure the X-elastic-product
+     * header is returned when security is disabled. The vast majority
+     * of Datafeed test coverage is in DatafeedJobsRestIT but that
+     * suite runs with security enabled.
+     */
+    public void testPreviewMissingHeader() throws Exception {
+        String jobId = "missing-header";
+        Request createJobRequest = new Request("PUT", "/_ml/anomaly_detectors/" + jobId);
+        createJobRequest.setJsonEntity("""
+            {
+              "description": "Aggs job",
+              "analysis_config": {
+                "bucket_span": "1h",
+                "detectors": [
+                  {
+                    "function": "count",
+                    "partition_field_name": "airline"
+                  }
+                ],
+                "influencers": [
+                  "airline"
+                ],
+                "model_prune_window": "30d"
+              },
+                  "model_plot_config": {
+                    "enabled": false,
+                    "annotations_enabled": false
+                  },
+                  "analysis_limits": {
+                    "model_memory_limit": "11mb",
+                    "categorization_examples_limit": 4
+                  },
+              "data_description" : {"time_field": "time stamp"}
+            }""");
+        client().performRequest(createJobRequest);
+
+        String datafeedId = "datafeed-" + jobId;
+        Request createDatafeedRequest = new Request("PUT", "/_ml/datafeeds/" + datafeedId);
+        createDatafeedRequest.setJsonEntity("""
+            {
+                "job_id": "missing-header",
+                "query": {
+                    "bool": {
+                        "must": [
+                            {
+                                "match_all": {}
+                            }
+                        ]
+                    }
+                },
+                "indices": [
+                    "airline-data"
+                ],
+                "scroll_size": 1000
+            }
+            """);
+        client().performRequest(createDatafeedRequest);
+
+        Request getFeed = new Request("GET", "/_ml/datafeeds/" + datafeedId + "/_preview");
+        RequestOptions.Builder options = getFeed.getOptions().toBuilder();
+        getFeed.setOptions(options);
+        var previewResponse = client().performRequest(getFeed);
+        assertXProductResponseHeader(previewResponse);
+
+        client().performRequest(new Request("POST", "/_ml/anomaly_detectors/" + jobId + "/_open"));
+        Request startRequest = new Request("POST", "/_ml/datafeeds/" + datafeedId + "/_start");
+        Response startDatafeedResponse = client().performRequest(startRequest);
+        assertXProductResponseHeader(startDatafeedResponse);
+    }
+
+    private void assertXProductResponseHeader(Response response) {
+        assertEquals("Elasticsearch", response.getHeader("X-elastic-product"));
+    }
+
+    private void addAirlineData() throws IOException {
+        StringBuilder bulk = new StringBuilder();
+
+        // Create index with source = enabled, doc_values = enabled, stored = false + multi-field
+        Request createAirlineDataRequest = new Request("PUT", "/airline-data");
+        // space in 'time stamp' is intentional
+        createAirlineDataRequest.setJsonEntity("""
+            {
+              "mappings": {
+                "runtime": {
+                  "airline_lowercase_rt": {
+                    "type": "keyword",
+                    "script": {
+                      "source": "emit(params._source.airline.toLowerCase())"
+                    }
+                  }
+                },
+                "properties": {
+                  "time stamp": {
+                    "type": "date"
+                  },
+                  "airline": {
+                    "type": "text",
+                    "fields": {
+                      "text": {
+                        "type": "text"
+                      },
+                      "keyword": {
+                        "type": "keyword"
+                      }
+                    }
+                  },
+                  "responsetime": {
+                    "type": "float"
+                  }
+                }
+              }
+            }""");
+        client().performRequest(createAirlineDataRequest);
+
+        bulk.append("""
+            {"index": {"_index": "airline-data", "_id": 1}}
+            {"time stamp":"2016-06-01T00:00:00Z","airline":"AAA","responsetime":135.22}
+            {"index": {"_index": "airline-data", "_id": 2}}
+            {"time stamp":"2016-06-01T01:59:00Z","airline":"AAA","responsetime":541.76}
+            """);
+
+        bulkIndex(bulk.toString());
+    }
+
+    private void bulkIndex(String bulk) throws IOException {
+        Request bulkRequest = new Request("POST", "/_bulk");
+        bulkRequest.setJsonEntity(bulk);
+        bulkRequest.addParameter("refresh", "true");
+        bulkRequest.addParameter("pretty", null);
+        String bulkResponse = EntityUtils.toString(client().performRequest(bulkRequest).getEntity());
+        assertThat(bulkResponse, not(containsString("\"errors\": false")));
+    }
+
+}

+ 12 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java

@@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.ParentTaskAssigningClient;
 import org.elasticsearch.client.internal.node.NodeClient;
@@ -153,11 +154,20 @@ public class TransportExplainDataFrameAnalyticsAction extends HandledTransportAc
                 );
             });
         } else {
+            var responseHeaderPreservingListener = ContextPreservingActionListener.wrapPreservingContext(
+                listener,
+                threadPool.getThreadContext()
+            );
             extractedFieldsDetectorFactory.createFromSource(
                 request.getConfig(),
                 ActionListener.wrap(
-                    extractedFieldsDetector -> explain(parentTaskId, request.getConfig(), extractedFieldsDetector, listener),
-                    listener::onFailure
+                    extractedFieldsDetector -> explain(
+                        parentTaskId,
+                        request.getConfig(),
+                        extractedFieldsDetector,
+                        responseHeaderPreservingListener
+                    ),
+                    responseHeaderPreservingListener::onFailure
                 )
             );
         }

+ 6 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDataFrameAnalyticsAction.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.action;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.ParentTaskAssigningClient;
 import org.elasticsearch.client.internal.node.NodeClient;
@@ -98,7 +99,11 @@ public class TransportPreviewDataFrameAnalyticsAction extends HandledTransportAc
                 preview(task, config, listener);
             });
         } else {
-            preview(task, request.getConfig(), listener);
+            preview(
+                task,
+                request.getConfig(),
+                ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext())
+            );
         }
     }
 

+ 15 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDatafeedAction.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
 import org.elasticsearch.action.fieldcaps.TransportFieldCapabilitiesAction;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.ParentTaskAssigningClient;
@@ -132,6 +133,17 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction<Previ
         PreviewDatafeedAction.Request request,
         ActionListener<PreviewDatafeedAction.Response> listener
     ) {
+        // The datafeed preview runs in its own context with the provided
+        // headers for auth. When security is not enabled the context
+        // preserving listener is required to restore the request/response
+        // headers. If security is enabled the context wrapping done in
+        // SecondaryAuthorizationUtils::useSecondaryAuthIfAvailable is
+        // sufficient to preserve the context.
+        var responseHeaderPreservingListener = ContextPreservingActionListener.wrapPreservingContext(
+            listener,
+            threadPool.getThreadContext()
+        );
+
         final QueryBuilder extraFilters = request.getStartTime().isPresent() || request.getEndTime().isPresent()
             ? null
             : QueryBuilders.boolQuery().mustNot(QueryBuilders.termsQuery(DataTierFieldMapper.NAME, "data_frozen", "data_cold"));
@@ -153,18 +165,16 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction<Previ
                 xContentRegistry,
                 // Fake DatafeedTimingStatsReporter that does not have access to results index
                 new DatafeedTimingStatsReporter(new DatafeedTimingStats(datafeedConfig.getJobId()), (ts, refreshPolicy, listener1) -> {}),
-                listener.delegateFailure(
+                responseHeaderPreservingListener.delegateFailure(
                     (l, dataExtractorFactory) -> isDateNanos(
                         previewDatafeedConfig,
                         job.getDataDescription().getTimeField(),
-                        listener.delegateFailure((l2, isDateNanos) -> {
-                            final QueryBuilder hotOnly = QueryBuilders.boolQuery()
-                                .mustNot(QueryBuilders.termsQuery(DataTierFieldMapper.NAME, "data_frozen", "data_cold"));
+                        l.delegateFailure((l2, isDateNanos) -> {
                             final long start = request.getStartTime().orElse(0);
                             final long end = request.getEndTime()
                                 .orElse(isDateNanos ? DateUtils.MAX_NANOSECOND_INSTANT.toEpochMilli() : Long.MAX_VALUE);
                             DataExtractor dataExtractor = dataExtractorFactory.newExtractor(start, end);
-                            threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> previewDatafeed(dataExtractor, l));
+                            threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> previewDatafeed(dataExtractor, l2));
                         })
                     )
                 )

+ 15 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java

@@ -13,6 +13,7 @@ import org.elasticsearch.ResourceAlreadyExistsException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.ParentTaskAssigningClient;
@@ -190,6 +191,13 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             return;
         }
 
+        // The datafeed will run its with the configured headers,
+        // preserve the response headers.
+        var responseHeaderPreservingListener = ContextPreservingActionListener.wrapPreservingContext(
+            listener,
+            threadPool.getThreadContext()
+        );
+
         AtomicReference<DatafeedConfig> datafeedConfigHolder = new AtomicReference<>();
         PersistentTasksCustomMetadata tasks = state.getMetadata().custom(PersistentTasksCustomMetadata.TYPE);
 
@@ -197,7 +205,7 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             new ActionListener<>() {
                 @Override
                 public void onResponse(PersistentTasksCustomMetadata.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask) {
-                    waitForDatafeedStarted(persistentTask.getId(), params, listener);
+                    waitForDatafeedStarted(persistentTask.getId(), params, responseHeaderPreservingListener);
                 }
 
                 @Override
@@ -209,7 +217,7 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
                             RestStatus.CONFLICT
                         );
                     }
-                    listener.onFailure(e);
+                    responseHeaderPreservingListener.onFailure(e);
                 }
             };
 
@@ -228,9 +236,9 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
                     ),
                     ActionListener.wrap(response -> {
                         if (response.isSuccess() == false) {
-                            listener.onFailure(createUnlicensedError(params.getDatafeedId(), response));
+                            responseHeaderPreservingListener.onFailure(createUnlicensedError(params.getDatafeedId(), response));
                         } else if (remoteClusterClient == false) {
-                            listener.onFailure(
+                            responseHeaderPreservingListener.onFailure(
                                 ExceptionsHelper.badRequestException(
                                     Messages.getMessage(
                                         Messages.DATAFEED_NEEDS_REMOTE_CLUSTER_SEARCH,
@@ -254,7 +262,7 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
                             createDataExtractor(task, job, datafeedConfigHolder.get(), params, waitForTaskListener);
                         }
                     },
-                        e -> listener.onFailure(
+                        e -> responseHeaderPreservingListener.onFailure(
                             createUnknownLicenseError(
                                 params.getDatafeedId(),
                                 RemoteClusterLicenseChecker.remoteIndices(params.getDatafeedIndices()),
@@ -273,7 +281,7 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             validate(job, datafeedConfigHolder.get(), tasks, xContentRegistry);
             auditDeprecations(datafeedConfigHolder.get(), job, auditor, xContentRegistry);
             createDataExtractor.accept(job);
-        }, listener::onFailure);
+        }, responseHeaderPreservingListener::onFailure);
 
         ActionListener<DatafeedConfig.Builder> datafeedListener = ActionListener.wrap(datafeedBuilder -> {
             DatafeedConfig datafeedConfig = datafeedBuilder.build();
@@ -283,7 +291,7 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
             datafeedConfigHolder.set(datafeedConfig);
 
             jobConfigProvider.getJob(datafeedConfig.getJobId(), null, jobListener);
-        }, listener::onFailure);
+        }, responseHeaderPreservingListener::onFailure);
 
         datafeedConfigProvider.getDatafeedConfig(params.getDatafeedId(), null, datafeedListener);
     }