|  | @@ -13,6 +13,7 @@ import org.elasticsearch.action.search.SearchResponseSections;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.ShardSearchFailure;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.unit.TimeValue;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.util.concurrent.AtomicArray;
 | 
	
		
			
				|  |  | +import org.elasticsearch.common.util.concurrent.ThreadContext;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.SearchHits;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.InternalAggregation;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.InternalAggregations;
 | 
	
	
		
			
				|  | @@ -21,11 +22,13 @@ import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.util.ArrayList;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  | +import java.util.Map;
 | 
	
		
			
				|  |  |  import java.util.function.Supplier;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static java.util.Collections.singletonList;
 | 
	
		
			
				|  |  |  import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
 | 
	
		
			
				|  |  |  import static org.elasticsearch.search.aggregations.InternalAggregations.topLevelReduce;
 | 
	
		
			
				|  |  | +import static org.elasticsearch.xpack.search.AsyncSearchIndexService.restoreResponseHeadersContext;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  /**
 | 
	
		
			
				|  |  |   * A mutable search response that allows to update and create partial response synchronously.
 | 
	
	
		
			
				|  | @@ -39,12 +42,14 @@ class MutableSearchResponse {
 | 
	
		
			
				|  |  |      private final Clusters clusters;
 | 
	
		
			
				|  |  |      private final AtomicArray<ShardSearchFailure> shardFailures;
 | 
	
		
			
				|  |  |      private final Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier;
 | 
	
		
			
				|  |  | +    private final ThreadContext threadContext;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private boolean isPartial;
 | 
	
		
			
				|  |  |      private boolean isFinalReduce;
 | 
	
		
			
				|  |  |      private int successfulShards;
 | 
	
		
			
				|  |  |      private SearchResponseSections sections;
 | 
	
		
			
				|  |  |      private ElasticsearchException failure;
 | 
	
		
			
				|  |  | +    private Map<String, List<String>> responseHeaders;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      private boolean frozen;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -55,15 +60,20 @@ class MutableSearchResponse {
 | 
	
		
			
				|  |  |       * @param skippedShards The number of skipped shards, or -1 to indicate a failure.
 | 
	
		
			
				|  |  |       * @param clusters The remote clusters statistics.
 | 
	
		
			
				|  |  |       * @param aggReduceContextSupplier A supplier to run final reduce on partial aggregations.
 | 
	
		
			
				|  |  | +     * @param threadContext The thread context to retrieve the final response headers.
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  | -    MutableSearchResponse(int totalShards, int skippedShards, Clusters clusters,
 | 
	
		
			
				|  |  | -            Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
 | 
	
		
			
				|  |  | +    MutableSearchResponse(int totalShards,
 | 
	
		
			
				|  |  | +                          int skippedShards,
 | 
	
		
			
				|  |  | +                          Clusters clusters,
 | 
	
		
			
				|  |  | +                          Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier,
 | 
	
		
			
				|  |  | +                          ThreadContext threadContext) {
 | 
	
		
			
				|  |  |          this.totalShards = totalShards;
 | 
	
		
			
				|  |  |          this.skippedShards = skippedShards;
 | 
	
		
			
				|  |  |          this.clusters = clusters;
 | 
	
		
			
				|  |  |          this.aggReduceContextSupplier = aggReduceContextSupplier;
 | 
	
		
			
				|  |  |          this.shardFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards-skippedShards);
 | 
	
		
			
				|  |  |          this.isPartial = true;
 | 
	
		
			
				|  |  | +        this.threadContext = threadContext;
 | 
	
		
			
				|  |  |          this.sections = totalShards == -1 ? null : new InternalSearchResponse(
 | 
	
		
			
				|  |  |              new SearchHits(SearchHits.EMPTY, new TotalHits(0, GREATER_THAN_OR_EQUAL_TO), Float.NaN),
 | 
	
		
			
				|  |  |              null, null, null, false, null, 0);
 | 
	
	
		
			
				|  | @@ -93,6 +103,8 @@ class MutableSearchResponse {
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  |      synchronized void updateFinalResponse(int successfulShards, SearchResponseSections newSections) {
 | 
	
		
			
				|  |  |          failIfFrozen();
 | 
	
		
			
				|  |  | +        // copy the response headers from the current context
 | 
	
		
			
				|  |  | +        this.responseHeaders = threadContext.getResponseHeaders();
 | 
	
		
			
				|  |  |          this.successfulShards = successfulShards;
 | 
	
		
			
				|  |  |          this.sections = newSections;
 | 
	
		
			
				|  |  |          this.isPartial = false;
 | 
	
	
		
			
				|  | @@ -106,6 +118,8 @@ class MutableSearchResponse {
 | 
	
		
			
				|  |  |       */
 | 
	
		
			
				|  |  |      synchronized void updateWithFailure(Exception exc) {
 | 
	
		
			
				|  |  |          failIfFrozen();
 | 
	
		
			
				|  |  | +        // copy the response headers from the current context
 | 
	
		
			
				|  |  | +        this.responseHeaders = threadContext.getResponseHeaders();
 | 
	
		
			
				|  |  |          this.isPartial = true;
 | 
	
		
			
				|  |  |          this.failure = ElasticsearchException.guessRootCauses(exc)[0];
 | 
	
		
			
				|  |  |          this.frozen = true;
 | 
	
	
		
			
				|  | @@ -146,6 +160,20 @@ class MutableSearchResponse {
 | 
	
		
			
				|  |  |              frozen == false, task.getStartTime(), expirationTime);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    /**
 | 
	
		
			
				|  |  | +     * Creates an {@link AsyncSearchResponse} based on the current state of the mutable response.
 | 
	
		
			
				|  |  | +     * This method also restores the response headers in the current thread context if the final response is available.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    synchronized AsyncSearchResponse toAsyncSearchResponseWithHeaders(AsyncSearchTask task, long expirationTime) {
 | 
	
		
			
				|  |  | +        AsyncSearchResponse resp = toAsyncSearchResponse(task, expirationTime);
 | 
	
		
			
				|  |  | +        if (responseHeaders != null) {
 | 
	
		
			
				|  |  | +            restoreResponseHeadersContext(threadContext, responseHeaders);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return resp;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private void failIfFrozen() {
 | 
	
		
			
				|  |  |          if (frozen) {
 | 
	
		
			
				|  |  |              throw new IllegalStateException("invalid update received after the completion of the request");
 |