Forráskód Böngészése

Merge pull request #19440 from rjernst/rest_headers

Plugins: Make rest headers registration pull based
Ryan Ernst 9 éve
szülő
commit
9b6e2a8e2f

+ 4 - 1
core/src/main/java/org/elasticsearch/action/ActionModule.java

@@ -20,10 +20,12 @@
 package org.elasticsearch.action;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainAction;
 import org.elasticsearch.action.admin.cluster.allocation.TransportClusterAllocationExplainAction;
@@ -335,7 +337,8 @@ public class ActionModule extends AbstractModule {
         actionFilters = setupActionFilters(actionPlugins, ingestEnabled);
         autoCreateIndex = transportClient ? null : new AutoCreateIndex(settings, resolver);
         destructiveOperations = new DestructiveOperations(settings, clusterSettings);
-        restController = new RestController(settings);
+        Set<String> headers = actionPlugins.stream().flatMap(p -> p.getRestHeaders().stream()).collect(Collectors.toSet());
+        restController = new RestController(settings, headers);
     }
 
     public Map<String, ActionHandler<?, ?>> getActions() {

+ 13 - 6
core/src/main/java/org/elasticsearch/plugins/ActionPlugin.java

@@ -28,11 +28,11 @@ import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.rest.RestHandler;
 
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
-import static java.util.Collections.emptyList;
-
 /**
  * An additional extension point for {@link Plugin}s that extends Elasticsearch's scripting functionality. Implement it like this:
  * <pre>{@code
@@ -50,22 +50,29 @@ public interface ActionPlugin {
      * Actions added by this plugin.
      */
     default List<ActionHandler<? extends ActionRequest<?>, ? extends ActionResponse>> getActions() {
-        return emptyList();
+        return Collections.emptyList();
     }
     /**
      * Action filters added by this plugin.
      */
     default List<Class<? extends ActionFilter>> getActionFilters() {
-        return emptyList();
+        return Collections.emptyList();
     }
     /**
      * Rest handlers added by this plugin.
      */
     default List<Class<? extends RestHandler>> getRestHandlers() {
-        return emptyList();
+        return Collections.emptyList();
+    }
+
+    /**
+     * Returns headers which should be copied through rest requests on to internal requests.
+     */
+    default Collection<String> getRestHeaders() {
+        return Collections.emptyList();
     }
 
-    public static final class ActionHandler<Request extends ActionRequest<Request>, Response extends ActionResponse> {
+    final class ActionHandler<Request extends ActionRequest<Request>, Response extends ActionResponse> {
         private final GenericAction<Request, Response> action;
         private final Class<? extends TransportAction<Request, Response>> transportAction;
         private final Class<?>[] supportTransportActions;

+ 3 - 2
core/src/main/java/org/elasticsearch/rest/BaseRestHandler.java

@@ -24,14 +24,15 @@ import org.elasticsearch.common.component.AbstractComponent;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.plugins.ActionPlugin;
 
 /**
  * Base handler for REST requests.
  * <p>
  * This handler makes sure that the headers &amp; context of the handled {@link RestRequest requests} are copied over to
  * the transport requests executed by the associated client. While the context is fully copied over, not all the headers
- * are copied, but a selected few. It is possible to control what headers are copied over by registering them using
- * {@link org.elasticsearch.rest.RestController#registerRelevantHeaders(String...)}
+ * are copied, but a selected few. It is possible to control what headers are copied over by returning them in
+ * {@link ActionPlugin#getRestHeaders()}.
  */
 public abstract class BaseRestHandler extends AbstractComponent implements RestHandler {
     public static final Setting<Boolean> MULTI_ALLOW_EXPLICIT_INDEX =

+ 7 - 26
core/src/main/java/org/elasticsearch/rest/RestController.java

@@ -28,16 +28,17 @@ import org.elasticsearch.common.path.PathTrie;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.plugins.ActionPlugin;
 import org.elasticsearch.rest.support.RestUtils;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
-import static java.util.Collections.emptySet;
 import static java.util.Collections.unmodifiableSet;
 import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
 import static org.elasticsearch.rest.RestStatus.OK;
@@ -55,13 +56,15 @@ public class RestController extends AbstractLifecycleComponent {
 
     private final RestHandlerFilter handlerFilter = new RestHandlerFilter();
 
-    private Set<String> relevantHeaders = emptySet();
+    /** Rest headers that are copied to internal requests made during a rest request. */
+    private final Set<String> headersToCopy;
 
     // non volatile since the assumption is that pre processors are registered on startup
     private RestFilter[] filters = new RestFilter[0];
 
-    public RestController(Settings settings) {
+    public RestController(Settings settings, Set<String> headersToCopy) {
         super(settings);
+        this.headersToCopy = headersToCopy;
     }
 
     @Override
@@ -79,28 +82,6 @@ public class RestController extends AbstractLifecycleComponent {
         }
     }
 
-    /**
-     * Controls which REST headers get copied over from a {@link org.elasticsearch.rest.RestRequest} to
-     * its corresponding {@link org.elasticsearch.transport.TransportRequest}(s).
-     *
-     * By default no headers get copied but it is possible to extend this behaviour via plugins by calling this method.
-     */
-    public synchronized void registerRelevantHeaders(String... headers) {
-        Set<String> newRelevantHeaders = new HashSet<>(relevantHeaders.size() + headers.length);
-        newRelevantHeaders.addAll(relevantHeaders);
-        Collections.addAll(newRelevantHeaders, headers);
-        relevantHeaders = unmodifiableSet(newRelevantHeaders);
-    }
-
-    /**
-     * Returns the REST headers that get copied over from a {@link org.elasticsearch.rest.RestRequest} to
-     * its corresponding {@link org.elasticsearch.transport.TransportRequest}(s).
-     * By default no headers get copied but it is possible to extend this behaviour via plugins by calling {@link #registerRelevantHeaders(String...)}.
-     */
-    public Set<String> relevantHeaders() {
-        return relevantHeaders;
-    }
-
     /**
      * Registers a pre processor to be executed before the rest request is actually handled.
      */
@@ -213,7 +194,7 @@ public class RestController extends AbstractLifecycleComponent {
             return;
         }
         try (ThreadContext.StoredContext t = threadContext.stashContext()) {
-            for (String key : relevantHeaders) {
+            for (String key : headersToCopy) {
                 String httpHeader = request.header(key);
                 if (httpHeader != null) {
                     threadContext.putHeader(key, httpHeader);

+ 2 - 1
core/src/test/java/org/elasticsearch/http/HttpServerTests.java

@@ -18,6 +18,7 @@
  */
 package org.elasticsearch.http;
 
+import java.util.Collections;
 import java.util.Map;
 
 import org.elasticsearch.common.breaker.CircuitBreaker;
@@ -59,7 +60,7 @@ public class HttpServerTests extends ESTestCase {
         inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
 
         HttpServerTransport httpServerTransport = new TestHttpServerTransport();
-        RestController restController = new RestController(settings);
+        RestController restController = new RestController(settings, Collections.emptySet());
         restController.registerHandler(RestRequest.Method.GET, "/",
             (request, channel, client) -> channel.sendResponse(
                 new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));

+ 10 - 46
core/src/test/java/org/elasticsearch/rest/RestControllerTests.java

@@ -19,6 +19,13 @@
 
 package org.elasticsearch.rest;
 
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
 import org.elasticsearch.client.node.NodeClient;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.settings.Settings;
@@ -26,16 +33,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.rest.FakeRestRequest;
 
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.TimeUnit;
-
-import static org.hamcrest.CoreMatchers.equalTo;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doCallRealMethod;
@@ -44,41 +41,10 @@ import static org.mockito.Mockito.verify;
 
 public class RestControllerTests extends ESTestCase {
 
-    public void testRegisterRelevantHeaders() throws InterruptedException {
-
-        final RestController restController = new RestController(Settings.EMPTY);
-
-        int iterations = randomIntBetween(1, 5);
-
-        Set<String> headers = new HashSet<>();
-        ExecutorService executorService = Executors.newFixedThreadPool(iterations);
-        for (int i = 0; i < iterations; i++) {
-            int headersCount = randomInt(10);
-            final Set<String> newHeaders = new HashSet<>();
-            for (int j = 0; j < headersCount; j++) {
-                String usefulHeader = randomRealisticUnicodeOfLengthBetween(1, 30);
-                newHeaders.add(usefulHeader);
-            }
-            headers.addAll(newHeaders);
-
-            executorService.submit((Runnable) () -> restController.registerRelevantHeaders(newHeaders.toArray(new String[newHeaders.size()])));
-        }
-
-        executorService.shutdown();
-        assertThat(executorService.awaitTermination(1, TimeUnit.SECONDS), equalTo(true));
-        String[] relevantHeaders = restController.relevantHeaders().toArray(new String[restController.relevantHeaders().size()]);
-        assertThat(relevantHeaders.length, equalTo(headers.size()));
-
-        Arrays.sort(relevantHeaders);
-        String[] headersArray = new String[headers.size()];
-        headersArray = headers.toArray(headersArray);
-        Arrays.sort(headersArray);
-        assertThat(relevantHeaders, equalTo(headersArray));
-    }
-
     public void testApplyRelevantHeaders() throws Exception {
         final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
-        final RestController restController = new RestController(Settings.EMPTY) {
+        Set<String> headers = new HashSet<>(Arrays.asList("header.1", "header.2"));
+        final RestController restController = new RestController(Settings.EMPTY, headers) {
             @Override
             boolean checkRequestParameters(RestRequest request, RestChannel channel) {
                 return true;
@@ -89,11 +55,9 @@ public class RestControllerTests extends ESTestCase {
                 assertEquals("true", threadContext.getHeader("header.1"));
                 assertEquals("true", threadContext.getHeader("header.2"));
                 assertNull(threadContext.getHeader("header.3"));
-
             }
         };
         threadContext.putHeader("header.3", "true");
-        restController.registerRelevantHeaders("header.1", "header.2");
         Map<String, String> restHeaders = new HashMap<>();
         restHeaders.put("header.1", "true");
         restHeaders.put("header.2", "true");
@@ -105,7 +69,7 @@ public class RestControllerTests extends ESTestCase {
     }
 
     public void testCanTripCircuitBreaker() throws Exception {
-        RestController controller = new RestController(Settings.EMPTY);
+        RestController controller = new RestController(Settings.EMPTY, Collections.emptySet());
         // trip circuit breaker by default
         controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
         controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));

+ 2 - 2
core/src/test/java/org/elasticsearch/rest/RestFilterChainTests.java

@@ -40,7 +40,7 @@ import static org.hamcrest.CoreMatchers.equalTo;
 public class RestFilterChainTests extends ESTestCase {
     public void testRestFilters() throws Exception {
 
-        RestController restController = new RestController(Settings.EMPTY);
+        RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
 
         int numFilters = randomInt(10);
         Set<Integer> orders = new HashSet<>(numFilters);
@@ -121,7 +121,7 @@ public class RestFilterChainTests extends ESTestCase {
             }
         });
 
-        RestController restController = new RestController(Settings.EMPTY);
+        RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
         restController.registerFilter(testFilter);
 
         restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() {

+ 2 - 1
core/src/test/java/org/elasticsearch/rest/action/cat/RestIndicesActionTests.java

@@ -58,6 +58,7 @@ import org.elasticsearch.test.ESTestCase;
 
 import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 import static java.util.Collections.emptyList;
@@ -70,7 +71,7 @@ public class RestIndicesActionTests extends ESTestCase {
 
     public void testBuildTable() {
         final Settings settings = Settings.EMPTY;
-        final RestController restController = new RestController(settings);
+        final RestController restController = new RestController(settings, Collections.emptySet());
         final RestIndicesAction action = new RestIndicesAction(settings, restController, new IndexNameExpressionResolver(settings));
 
         // build a (semi-)random table

+ 2 - 1
core/src/test/java/org/elasticsearch/rest/action/cat/RestRecoveryActionTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.snapshots.Snapshot;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
@@ -50,7 +51,7 @@ public class RestRecoveryActionTests extends ESTestCase {
 
     public void testRestRecoveryAction() {
         final Settings settings = Settings.EMPTY;
-        final RestController restController = new RestController(settings);
+        final RestController restController = new RestController(settings, Collections.emptySet());
         final RestRecoveryAction action = new RestRecoveryAction(settings, restController, restController);
         final int totalShards = randomIntBetween(1, 32);
         final int successfulShards = Math.max(0, totalShards - randomIntBetween(1, 2));

+ 17 - 14
qa/smoke-test-http/src/test/java/org/elasticsearch/http/ContextAndHeaderTransportIT.java

@@ -46,7 +46,6 @@ import org.elasticsearch.index.query.TermsQueryBuilder;
 import org.elasticsearch.indices.TermsLookup;
 import org.elasticsearch.plugins.ActionPlugin;
 import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.rest.RestController;
 import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.junit.After;
@@ -75,7 +74,7 @@ import static org.hamcrest.Matchers.is;
 @ClusterScope(scope = SUITE)
 public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
     private static final List<RequestAndHeaders> requests =  new CopyOnWriteArrayList<>();
-    private String randomHeaderKey = randomAsciiOfLength(10);
+    private static final String CUSTOM_HEADER = "SomeCustomHeader";
     private String randomHeaderValue = randomAsciiOfLength(20);
     private String queryIndex = "query-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT);
     private String lookupIndex = "lookup-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT);
@@ -97,6 +96,7 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
     protected Collection<Class<? extends Plugin>> nodePlugins() {
         ArrayList<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins());
         plugins.add(ActionLoggingPlugin.class);
+        plugins.add(CustomHeadersPlugin.class);
         return plugins;
     }
 
@@ -219,21 +219,18 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
     }
 
     public void testThatRelevantHttpHeadersBecomeRequestHeaders() throws Exception {
-        String relevantHeaderName = "relevant_" + randomHeaderKey;
-        for (RestController restController : internalCluster().getInstances(RestController.class)) {
-            restController.registerRelevantHeaders(relevantHeaderName);
-        }
+        final String IRRELEVANT_HEADER = "SomeIrrelevantHeader";
 
         try (Response response = getRestClient().performRequest(
                 "GET", "/" + queryIndex + "/_search",
-                new BasicHeader(randomHeaderKey, randomHeaderValue), new BasicHeader(relevantHeaderName, randomHeaderValue))) {
+                new BasicHeader(CUSTOM_HEADER, randomHeaderValue), new BasicHeader(IRRELEVANT_HEADER, randomHeaderValue))) {
             assertThat(response.getStatusLine().getStatusCode(), equalTo(200));
             List<RequestAndHeaders> searchRequests = getRequests(SearchRequest.class);
             assertThat(searchRequests, hasSize(greaterThan(0)));
             for (RequestAndHeaders requestAndHeaders : searchRequests) {
-                assertThat(requestAndHeaders.headers.containsKey(relevantHeaderName), is(true));
+                assertThat(requestAndHeaders.headers.containsKey(CUSTOM_HEADER), is(true));
                 // was not specified, thus is not included
-                assertThat(requestAndHeaders.headers.containsKey(randomHeaderKey), is(false));
+                assertThat(requestAndHeaders.headers.containsKey(IRRELEVANT_HEADER), is(false));
             }
         }
     }
@@ -273,21 +270,21 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
     }
 
     private void assertRequestContainsHeader(ActionRequest request, Map<String, String> context) {
-        String msg = String.format(Locale.ROOT, "Expected header %s to be in request %s", randomHeaderKey, request.getClass().getName());
+        String msg = String.format(Locale.ROOT, "Expected header %s to be in request %s", CUSTOM_HEADER, request.getClass().getName());
         if (request instanceof IndexRequest) {
             IndexRequest indexRequest = (IndexRequest) request;
-            msg = String.format(Locale.ROOT, "Expected header %s to be in index request %s/%s/%s", randomHeaderKey,
+            msg = String.format(Locale.ROOT, "Expected header %s to be in index request %s/%s/%s", CUSTOM_HEADER,
                 indexRequest.index(), indexRequest.type(), indexRequest.id());
         }
-        assertThat(msg, context.containsKey(randomHeaderKey), is(true));
-        assertThat(context.get(randomHeaderKey).toString(), is(randomHeaderValue));
+        assertThat(msg, context.containsKey(CUSTOM_HEADER), is(true));
+        assertThat(context.get(CUSTOM_HEADER).toString(), is(randomHeaderValue));
     }
 
     /**
      * a transport client that adds our random header
      */
     private Client transportClient() {
-        return internalCluster().transportClient().filterWithHeader(Collections.singletonMap(randomHeaderKey, randomHeaderValue));
+        return internalCluster().transportClient().filterWithHeader(Collections.singletonMap(CUSTOM_HEADER, randomHeaderValue));
     }
 
     public static class ActionLoggingPlugin extends Plugin implements ActionPlugin {
@@ -347,4 +344,10 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
             this.request = request;
         }
     }
+
+    public static class CustomHeadersPlugin extends Plugin implements ActionPlugin {
+        public Collection<String> getRestHeaders() {
+            return Collections.singleton(CUSTOM_HEADER);
+        }
+    }
 }