瀏覽代碼

Allow preserving headers on context stash (#94680)

This PR adds functionality to stash a thread context but preserve
specific request headers. The
[need](https://github.com/elastic/elasticsearch/pull/94665#discussion_r1146126754)
for this method most recently came up in the context of work on the
remote cluster security model.
Nikolaj Volgushev 2 年之前
父節點
當前提交
3fe122ba2d

+ 5 - 0
docs/changelog/94680.yaml

@@ -0,0 +1,5 @@
+pr: 94680
+summary: Allow preserving specific headers on thread context stash
+area: Infra/Core
+type: enhancement
+issues: []

+ 32 - 7
server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java

@@ -29,6 +29,7 @@ import java.nio.charset.StandardCharsets;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
@@ -102,17 +103,28 @@ public final class ThreadContext implements Writeable {
      * @return a stored context that will restore the current context to its state at the point this method was called
      */
     public StoredContext stashContext() {
+        return stashContextPreservingRequestHeaders(Collections.emptySet());
+    }
+
+    /**
+     * Just like {@link #stashContext()} but preserves request headers specified via {@code requestHeaders},
+     * if these exist in the context before stashing.
+     */
+    public StoredContext stashContextPreservingRequestHeaders(Set<String> requestHeaders) {
         final ThreadContextStruct context = threadLocal.get();
 
         /*
+         * When the context is stashed, it should be empty, except for headers that were specified to be preserved via `requestHeaders`
+         * and a set of default headers such as X-Opaque-ID, which are always copied (specified via `Task.HEADERS_TO_COPY`).
+         *
          * X-Opaque-ID should be preserved in a threadContext in order to propagate this across threads.
          * This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
-         * The same is applied to Task.TRACE_ID.
-         * Otherwise when context is stashed, it should be empty.
+         * The same is applied to Task.TRACE_ID and other values specified in `Task.HEADERS_TO_COPY`.
          */
+        final Set<String> requestHeadersToCopy = getRequestHeadersToCopy(requestHeaders);
         boolean hasHeadersToCopy = false;
         if (context.requestHeaders.isEmpty() == false) {
-            for (String header : HEADERS_TO_COPY) {
+            for (String header : requestHeadersToCopy) {
                 if (context.requestHeaders.containsKey(header)) {
                     hasHeadersToCopy = true;
                     break;
@@ -124,7 +136,7 @@ public final class ThreadContext implements Writeable {
 
         ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;
         if (hasHeadersToCopy) {
-            Map<String, String> copiedHeaders = getHeadersToCopy(context);
+            Map<String, String> copiedHeaders = getHeadersPresentInContext(context, requestHeadersToCopy);
             threadContextStruct = DEFAULT_CONTEXT.putHeaders(copiedHeaders);
         }
         if (hasTransientHeadersToCopy) {
@@ -140,6 +152,10 @@ public final class ThreadContext implements Writeable {
         return storedOriginalContext(context);
     }
 
+    public StoredContext stashContextPreservingRequestHeaders(final String... requestHeaders) {
+        return stashContextPreservingRequestHeaders(Set.of(requestHeaders));
+    }
+
     /**
      * When using a {@link org.elasticsearch.tracing.Tracer} to capture activity in Elasticsearch, when a parent span is already
      * in progress, it is necessary to start a new context before beginning a child span. This method creates a context,
@@ -232,9 +248,18 @@ public final class ThreadContext implements Writeable {
         return () -> threadLocal.set(originalContext);
     }
 
-    private static Map<String, String> getHeadersToCopy(ThreadContextStruct context) {
-        Map<String, String> map = Maps.newMapWithExpectedSize(HEADERS_TO_COPY.size());
-        for (String header : HEADERS_TO_COPY) {
+    private static Set<String> getRequestHeadersToCopy(Set<String> requestHeaders) {
+        if (requestHeaders.isEmpty()) {
+            return HEADERS_TO_COPY;
+        }
+        final Set<String> allRequestHeadersToCopy = new HashSet<>(requestHeaders);
+        allRequestHeadersToCopy.addAll(HEADERS_TO_COPY);
+        return Set.copyOf(allRequestHeadersToCopy);
+    }
+
+    private static Map<String, String> getHeadersPresentInContext(ThreadContextStruct context, Set<String> headers) {
+        Map<String, String> map = Maps.newMapWithExpectedSize(headers.size());
+        for (String header : headers) {
             final String value = context.requestHeaders.get(header);
             if (value != null) {
                 map.put(header, value);

+ 49 - 0
server/src/test/java/org/elasticsearch/common/util/concurrent/ThreadContextTests.java

@@ -27,6 +27,7 @@ import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import static com.carrotsearch.randomizedtesting.RandomizedTest.randomAsciiLettersOfLengthBetween;
+import static org.elasticsearch.tasks.Task.HEADERS_TO_COPY;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
@@ -57,6 +58,54 @@ public class ThreadContextTests extends ESTestCase {
         assertEquals("1", threadContext.getHeader("default"));
     }
 
+    public void testStashContextPreservesDefaultHeadersToCopy() {
+        for (String header : HEADERS_TO_COPY) {
+            ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
+            threadContext.putHeader(header, "bar");
+            try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
+                assertEquals("bar", threadContext.getHeader(header));
+            }
+        }
+    }
+
+    public void testStashContextPreservingRequestHeaders() {
+        Settings build = Settings.builder().put("request.headers.default", "1").build();
+        ThreadContext threadContext = new ThreadContext(build);
+        threadContext.putHeader("foo", "bar");
+        threadContext.putHeader("bar", "foo");
+        threadContext.putTransient("ctx.foo", 1);
+        assertEquals("bar", threadContext.getHeader("foo"));
+        assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
+        assertEquals("1", threadContext.getHeader("default"));
+        try (ThreadContext.StoredContext ignored = threadContext.stashContextPreservingRequestHeaders("foo", "ctx.foo", "missing")) {
+            assertEquals("bar", threadContext.getHeader("foo"));
+            // only request headers preserved, not transient
+            assertNull(threadContext.getTransient("ctx.foo"));
+            assertNull(threadContext.getHeader("bar"));
+            assertEquals("1", threadContext.getHeader("default"));
+            assertNull(threadContext.getHeader("missing"));
+        }
+
+        assertEquals("bar", threadContext.getHeader("foo"));
+        assertEquals("foo", threadContext.getHeader("bar"));
+        assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo"));
+        assertEquals("1", threadContext.getHeader("default"));
+    }
+
+    public void testStashContextPreservingHeadersWithDefaultHeadersToCopy() {
+        for (String header : HEADERS_TO_COPY) {
+            ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
+            threadContext.putHeader(header, "bar");
+            try (ThreadContext.StoredContext ignored = threadContext.stashContextPreservingRequestHeaders()) {
+                assertEquals("bar", threadContext.getHeader(header));
+            }
+            // Also works if we pass it explicitly
+            try (ThreadContext.StoredContext ignored = threadContext.stashContextPreservingRequestHeaders(header)) {
+                assertEquals("bar", threadContext.getHeader(header));
+            }
+        }
+    }
+
     public void testNewContextWithClearedTransients() {
         ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
         threadContext.putTransient("foo", "bar");