|
@@ -29,6 +29,7 @@ import java.nio.charset.StandardCharsets;
|
|
|
import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.HashSet;
|
|
|
import java.util.LinkedHashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
@@ -102,17 +103,28 @@ public final class ThreadContext implements Writeable {
|
|
|
* @return a stored context that will restore the current context to its state at the point this method was called
|
|
|
*/
|
|
|
public StoredContext stashContext() {
|
|
|
+ return stashContextPreservingRequestHeaders(Collections.emptySet());
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Just like {@link #stashContext()} but preserves request headers specified via {@code requestHeaders},
|
|
|
+ * if these exist in the context before stashing.
|
|
|
+ */
|
|
|
+ public StoredContext stashContextPreservingRequestHeaders(Set<String> requestHeaders) {
|
|
|
final ThreadContextStruct context = threadLocal.get();
|
|
|
|
|
|
/*
|
|
|
+ * When the context is stashed, it should be empty, except for headers that were specified to be preserved via `requestHeaders`
|
|
|
+ * and a set of default headers such as X-Opaque-ID, which are always copied (specified via `Task.HEADERS_TO_COPY`).
|
|
|
+ *
|
|
|
* X-Opaque-ID should be preserved in a threadContext in order to propagate this across threads.
|
|
|
* This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user.
|
|
|
- * The same is applied to Task.TRACE_ID.
|
|
|
- * Otherwise when context is stashed, it should be empty.
|
|
|
+ * The same is applied to Task.TRACE_ID and other values specified in `Task.HEADERS_TO_COPY`.
|
|
|
*/
|
|
|
+ final Set<String> requestHeadersToCopy = getRequestHeadersToCopy(requestHeaders);
|
|
|
boolean hasHeadersToCopy = false;
|
|
|
if (context.requestHeaders.isEmpty() == false) {
|
|
|
- for (String header : HEADERS_TO_COPY) {
|
|
|
+ for (String header : requestHeadersToCopy) {
|
|
|
if (context.requestHeaders.containsKey(header)) {
|
|
|
hasHeadersToCopy = true;
|
|
|
break;
|
|
@@ -124,7 +136,7 @@ public final class ThreadContext implements Writeable {
|
|
|
|
|
|
ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT;
|
|
|
if (hasHeadersToCopy) {
|
|
|
- Map<String, String> copiedHeaders = getHeadersToCopy(context);
|
|
|
+ Map<String, String> copiedHeaders = getHeadersPresentInContext(context, requestHeadersToCopy);
|
|
|
threadContextStruct = DEFAULT_CONTEXT.putHeaders(copiedHeaders);
|
|
|
}
|
|
|
if (hasTransientHeadersToCopy) {
|
|
@@ -140,6 +152,10 @@ public final class ThreadContext implements Writeable {
|
|
|
return storedOriginalContext(context);
|
|
|
}
|
|
|
|
|
|
+ public StoredContext stashContextPreservingRequestHeaders(final String... requestHeaders) {
|
|
|
+ return stashContextPreservingRequestHeaders(Set.of(requestHeaders));
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* When using a {@link org.elasticsearch.tracing.Tracer} to capture activity in Elasticsearch, when a parent span is already
|
|
|
* in progress, it is necessary to start a new context before beginning a child span. This method creates a context,
|
|
@@ -232,9 +248,18 @@ public final class ThreadContext implements Writeable {
|
|
|
return () -> threadLocal.set(originalContext);
|
|
|
}
|
|
|
|
|
|
- private static Map<String, String> getHeadersToCopy(ThreadContextStruct context) {
|
|
|
- Map<String, String> map = Maps.newMapWithExpectedSize(HEADERS_TO_COPY.size());
|
|
|
- for (String header : HEADERS_TO_COPY) {
|
|
|
+ private static Set<String> getRequestHeadersToCopy(Set<String> requestHeaders) {
|
|
|
+ if (requestHeaders.isEmpty()) {
|
|
|
+ return HEADERS_TO_COPY;
|
|
|
+ }
|
|
|
+ final Set<String> allRequestHeadersToCopy = new HashSet<>(requestHeaders);
|
|
|
+ allRequestHeadersToCopy.addAll(HEADERS_TO_COPY);
|
|
|
+ return Set.copyOf(allRequestHeadersToCopy);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static Map<String, String> getHeadersPresentInContext(ThreadContextStruct context, Set<String> headers) {
|
|
|
+ Map<String, String> map = Maps.newMapWithExpectedSize(headers.size());
|
|
|
+ for (String header : headers) {
|
|
|
final String value = context.requestHeaders.get(header);
|
|
|
if (value != null) {
|
|
|
map.put(header, value);
|