Explorar o código

Implement Exitable DirectoryReader (#52822)

Implement an Exitable DirectoryReader that wraps the original
DirectoryReader so that when a search task is cancelled the
DirectoryReaders also stop their work fast. This is usuful for
expensive operations like wilcard/prefix queries where the
DirectoryReaders can spend lots of time and consume resources,
as previously their work wouldn't stop even though the original
search task was cancelled (e.g. because of timeout or dropped client
connection).
Marios Trivyzas %!s(int64=5) %!d(string=hai) anos
pai
achega
67acaf61f3

+ 1 - 1
server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java

@@ -155,7 +155,7 @@ final class DefaultSearchContext extends SearchContext {
     DefaultSearchContext(long id, ShardSearchRequest request, SearchShardTarget shardTarget,
                          Engine.Searcher engineSearcher, ClusterService clusterService, IndexService indexService,
                          IndexShard indexShard, BigArrays bigArrays, LongSupplier relativeTimeSupplier, TimeValue timeout,
-                         FetchPhase fetchPhase) {
+                         FetchPhase fetchPhase) throws IOException {
         this.id = id;
         this.request = request;
         this.fetchPhase = fetchPhase;

+ 4 - 4
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -398,7 +398,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
             }, listener::onFailure));
     }
 
-    private void onMatchNoDocs(SearchRewriteContext rewriteContext, ActionListener<SearchPhaseResult> listener) {
+    private void onMatchNoDocs(SearchRewriteContext rewriteContext, ActionListener<SearchPhaseResult> listener) throws IOException {
         // creates a lightweight search context that we use to inform context listeners
         // before closing
         SearchContext searchContext = createSearchContext(rewriteContext, defaultSearchTimeout);
@@ -609,7 +609,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         }
     }
 
-    final SearchContext createAndPutContext(SearchRewriteContext rewriteContext) {
+    final SearchContext createAndPutContext(SearchRewriteContext rewriteContext) throws IOException {
         SearchContext context = createContext(rewriteContext);
         onNewContext(context);
         boolean success = false;
@@ -644,7 +644,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         }
     }
 
-    final SearchContext createContext(SearchRewriteContext rewriteContext) {
+    final SearchContext createContext(SearchRewriteContext rewriteContext) throws IOException {
         final DefaultSearchContext context = createSearchContext(rewriteContext, defaultSearchTimeout);
         try {
             if (rewriteContext.request != null && openScrollContexts.get() >= maxOpenScrollContext) {
@@ -695,7 +695,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         return createSearchContext(rewriteContext.wrapSearcher(), timeout);
     }
 
-    private DefaultSearchContext createSearchContext(SearchRewriteContext rewriteContext, TimeValue timeout) {
+    private DefaultSearchContext createSearchContext(SearchRewriteContext rewriteContext, TimeValue timeout) throws IOException {
         boolean success = false;
         try {
             final ShardSearchRequest request = rewriteContext.request;

+ 63 - 17
server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java

@@ -62,7 +62,9 @@ import org.elasticsearch.search.query.QuerySearchResult;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 
 /**
@@ -77,13 +79,26 @@ public class ContextIndexSearcher extends IndexSearcher {
 
     private AggregatedDfs aggregatedDfs;
     private QueryProfiler profiler;
-    private Runnable checkCancelled;
+    private MutableQueryTimeout cancellable;
 
-    public ContextIndexSearcher(IndexReader reader, Similarity similarity, QueryCache queryCache, QueryCachingPolicy queryCachingPolicy) {
-        super(reader);
+    public ContextIndexSearcher(IndexReader reader, Similarity similarity,
+                                QueryCache queryCache, QueryCachingPolicy queryCachingPolicy) throws IOException {
+        this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout());
+    }
+
+    // TODO: Make the 2nd constructor private so that the IndexReader is always wrapped.
+    // Some issues must be fixed:
+    //   - regarding tests deriving from AggregatorTestCase and more specifically the use of searchAndReduce and
+    //     the ShardSearcher sub-searchers.
+    //   - tests that use a MultiReader
+    public ContextIndexSearcher(IndexReader reader, Similarity similarity,
+                                QueryCache queryCache, QueryCachingPolicy queryCachingPolicy,
+                                MutableQueryTimeout cancellable) throws IOException {
+        super(cancellable != null ? new ExitableDirectoryReader((DirectoryReader) reader, cancellable) : reader);
         setSimilarity(similarity);
         setQueryCache(queryCache);
         setQueryCachingPolicy(queryCachingPolicy);
+        this.cancellable = cancellable != null ? cancellable : new MutableQueryTimeout();
     }
 
     public void setProfiler(QueryProfiler profiler) {
@@ -91,11 +106,19 @@ public class ContextIndexSearcher extends IndexSearcher {
     }
 
     /**
-     * Set a {@link Runnable} that will be run on a regular basis while
-     * collecting documents.
+     * Add a {@link Runnable} that will be run on a regular basis while accessing documents in the
+     * DirectoryReader but also while collecting them and check for query cancellation or timeout.
+     */
+    public Runnable addQueryCancellation(Runnable action) {
+        return this.cancellable.add(action);
+    }
+
+    /**
+     * Remove a {@link Runnable} that checks for query cancellation or timeout
+     * which is called while accessing documents in the DirectoryReader but also while collecting them.
      */
-    public void setCheckCancelled(Runnable checkCancelled) {
-        this.checkCancelled = checkCancelled;
+    public void removeQueryCancellation(Runnable action) {
+        this.cancellable.remove(action);
     }
 
     public void setAggregatedDfs(AggregatedDfs aggregatedDfs) {
@@ -139,12 +162,6 @@ public class ContextIndexSearcher extends IndexSearcher {
         }
     }
 
-    private void checkCancelled() {
-        if (checkCancelled != null) {
-            checkCancelled.run();
-        }
-    }
-
     @SuppressWarnings({"unchecked", "rawtypes"})
     public void search(List<LeafReaderContext> leaves, Weight weight, CollectorManager manager,
             QuerySearchResult result, DocValueFormat[] formats, TotalHits totalHits) throws IOException {
@@ -180,7 +197,7 @@ public class ContextIndexSearcher extends IndexSearcher {
      * the provided <code>ctx</code>.
      */
     private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collector) throws IOException {
-        checkCancelled();
+        cancellable.checkCancelled();
         weight = wrapWeight(weight);
         final LeafCollector leafCollector;
         try {
@@ -208,7 +225,7 @@ public class ContextIndexSearcher extends IndexSearcher {
             if (scorer != null) {
                 try {
                     intersectScorerAndBitSet(scorer, liveDocsBitSet, leafCollector,
-                        checkCancelled == null ? () -> { } : checkCancelled);
+                            this.cancellable.isEnabled() ? cancellable::checkCancelled: () -> {});
                 } catch (CollectionTerminatedException e) {
                     // collection was terminated prematurely
                     // continue with the following leaf
@@ -218,7 +235,7 @@ public class ContextIndexSearcher extends IndexSearcher {
     }
 
     private Weight wrapWeight(Weight weight) {
-        if (checkCancelled != null) {
+        if (cancellable.isEnabled()) {
             return new Weight(weight.getQuery()) {
                 @Override
                 public void extractTerms(Set<Term> terms) {
@@ -244,7 +261,7 @@ public class ContextIndexSearcher extends IndexSearcher {
                 public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
                     BulkScorer in = weight.bulkScorer(context);
                     if (in != null) {
-                        return new CancellableBulkScorer(in, checkCancelled);
+                        return new CancellableBulkScorer(in, cancellable::checkCancelled);
                     } else {
                         return null;
                     }
@@ -320,4 +337,33 @@ public class ContextIndexSearcher extends IndexSearcher {
         assert reader instanceof DirectoryReader : "expected an instance of DirectoryReader, got " + reader.getClass();
         return (DirectoryReader) reader;
     }
+
+    private static class MutableQueryTimeout implements ExitableDirectoryReader.QueryCancellation {
+
+        private final Set<Runnable> runnables = new HashSet<>();
+
+        private Runnable add(Runnable action) {
+            Objects.requireNonNull(action, "cancellation runnable should not be null");
+            if (runnables.add(action) == false) {
+                throw new IllegalArgumentException("Cancellation runnable already added");
+            }
+            return action;
+        }
+
+        private void remove(Runnable action) {
+            runnables.remove(action);
+        }
+
+        @Override
+        public void checkCancelled() {
+            for (Runnable timeout : runnables) {
+                timeout.run();
+            }
+        }
+
+        @Override
+        public boolean isEnabled() {
+            return runnables.isEmpty() == false;
+        }
+    }
 }

+ 289 - 0
server/src/main/java/org/elasticsearch/search/internal/ExitableDirectoryReader.java

@@ -0,0 +1,289 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.internal;
+
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FilterDirectoryReader;
+import org.apache.lucene.index.FilterLeafReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.suggest.document.CompletionTerms;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.automaton.CompiledAutomaton;
+
+import java.io.IOException;
+
+/**
+ * Wraps an {@link IndexReader} with a {@link QueryCancellation}
+ * which checks for cancelled or timed-out query.
+ */
+class ExitableDirectoryReader extends FilterDirectoryReader {
+
+    /**
+     * Used to check if query cancellation is actually enabled
+     * and if so use it to check if the query is cancelled or timed-out.
+     */
+    interface QueryCancellation {
+
+        /**
+         * Used to prevent unnecessary checks for cancellation
+         * @return true if query cancellation is enabled
+         */
+        boolean isEnabled();
+
+        /**
+         * Call to check if the query is cancelled or timed-out.
+         * If so a {@link RuntimeException} is thrown
+         */
+        void checkCancelled();
+    }
+
+    ExitableDirectoryReader(DirectoryReader in, QueryCancellation queryCancellation) throws IOException {
+        super(in, new SubReaderWrapper() {
+            @Override
+            public LeafReader wrap(LeafReader reader) {
+                return new ExitableLeafReader(reader, queryCancellation);
+            }
+        });
+    }
+
+    @Override
+    protected DirectoryReader doWrapDirectoryReader(DirectoryReader in) {
+        throw new UnsupportedOperationException("doWrapDirectoryReader() should never be invoked");
+    }
+
+    @Override
+    public CacheHelper getReaderCacheHelper() {
+        return in.getReaderCacheHelper();
+    }
+    /**
+     * Wraps a {@link FilterLeafReader} with a {@link QueryCancellation}.
+     */
+    static class ExitableLeafReader extends FilterLeafReader {
+
+        private final QueryCancellation queryCancellation;
+
+        private ExitableLeafReader(LeafReader leafReader, QueryCancellation queryCancellation) {
+            super(leafReader);
+            this.queryCancellation = queryCancellation;
+        }
+
+        @Override
+        public PointValues getPointValues(String field) throws IOException {
+            final PointValues pointValues = in.getPointValues(field);
+            if (pointValues == null) {
+                return null;
+            }
+            return queryCancellation.isEnabled() ? new ExitablePointValues(pointValues, queryCancellation) : pointValues;
+        }
+
+        @Override
+        public Terms terms(String field) throws IOException {
+            Terms terms = in.terms(field);
+            if (terms == null) {
+                return null;
+            }
+            // If we have a suggest CompletionQuery then the CompletionWeight#bulkScorer() will check that
+            // the terms are instanceof CompletionTerms (not generic FilterTerms) and will throw an exception
+            // if that's not the case.
+            return (queryCancellation.isEnabled() && terms instanceof CompletionTerms == false) ?
+                    new ExitableTerms(terms, queryCancellation) : terms;
+        }
+
+        @Override
+        public CacheHelper getCoreCacheHelper() {
+            return in.getCoreCacheHelper();
+        }
+
+        @Override
+        public CacheHelper getReaderCacheHelper() {
+            return in.getReaderCacheHelper();
+        }
+    }
+
+    /**
+     * Wrapper class for {@link FilterLeafReader.FilterTerms} that check for query cancellation or timeout.
+     */
+    static class ExitableTerms extends FilterLeafReader.FilterTerms {
+
+        private final QueryCancellation queryCancellation;
+
+        private ExitableTerms(Terms terms, QueryCancellation queryCancellation) {
+            super(terms);
+            this.queryCancellation = queryCancellation;
+        }
+
+        @Override
+        public TermsEnum intersect(CompiledAutomaton compiled, BytesRef startTerm) throws IOException {
+            return new ExitableTermsEnum(in.intersect(compiled, startTerm), queryCancellation);
+        }
+
+        @Override
+        public TermsEnum iterator() throws IOException {
+            return new ExitableTermsEnum(in.iterator(), queryCancellation);
+        }
+    }
+
+    /**
+     * Wrapper class for {@link FilterLeafReader.FilterTermsEnum} that is used by {@link ExitableTerms} for
+     * implementing an exitable enumeration of terms.
+     */
+    private static class ExitableTermsEnum extends FilterLeafReader.FilterTermsEnum {
+
+        private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = (1 << 4) - 1; // 15
+
+        private int calls;
+        private final QueryCancellation queryCancellation;
+
+        private ExitableTermsEnum(TermsEnum termsEnum, QueryCancellation queryCancellation) {
+            super(termsEnum);
+            this.queryCancellation = queryCancellation;
+            this.queryCancellation.checkCancelled();
+        }
+
+        private void checkAndThrowWithSampling() {
+            if ((calls++ & MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
+                queryCancellation.checkCancelled();
+            }
+        }
+
+        @Override
+        public BytesRef next() throws IOException {
+            checkAndThrowWithSampling();
+            return in.next();
+        }
+    }
+
+    /**
+     * Wrapper class for {@link PointValues} that checks for query cancellation or timeout.
+     */
+    static class ExitablePointValues extends PointValues {
+
+        private final PointValues in;
+        private final QueryCancellation queryCancellation;
+
+        private ExitablePointValues(PointValues in, QueryCancellation queryCancellation) {
+            this.in = in;
+            this.queryCancellation = queryCancellation;
+            this.queryCancellation.checkCancelled();
+        }
+
+        @Override
+        public void intersect(IntersectVisitor visitor) throws IOException {
+            queryCancellation.checkCancelled();
+            in.intersect(new ExitableIntersectVisitor(visitor, queryCancellation));
+        }
+
+        @Override
+        public long estimatePointCount(IntersectVisitor visitor) {
+            queryCancellation.checkCancelled();
+            return in.estimatePointCount(visitor);
+        }
+
+        @Override
+        public byte[] getMinPackedValue() throws IOException {
+            queryCancellation.checkCancelled();
+            return in.getMinPackedValue();
+        }
+
+        @Override
+        public byte[] getMaxPackedValue() throws IOException {
+            queryCancellation.checkCancelled();
+            return in.getMaxPackedValue();
+        }
+
+        @Override
+        public int getNumDimensions() throws IOException {
+            queryCancellation.checkCancelled();
+            return in.getNumDimensions();
+        }
+
+        @Override
+        public int getNumIndexDimensions() throws IOException {
+            queryCancellation.checkCancelled();
+            return in.getNumIndexDimensions();
+        }
+
+        @Override
+        public int getBytesPerDimension() throws IOException {
+            queryCancellation.checkCancelled();
+            return in.getBytesPerDimension();
+        }
+
+        @Override
+        public long size() {
+            queryCancellation.checkCancelled();
+            return in.size();
+        }
+
+        @Override
+        public int getDocCount() {
+            queryCancellation.checkCancelled();
+            return in.getDocCount();
+        }
+    }
+
+    private static class ExitableIntersectVisitor implements PointValues.IntersectVisitor {
+
+        private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = (1 << 4) - 1; // 15
+
+        private final PointValues.IntersectVisitor in;
+        private final QueryCancellation queryCancellation;
+        private int calls;
+
+        private ExitableIntersectVisitor(PointValues.IntersectVisitor in, QueryCancellation queryCancellation) {
+            this.in = in;
+            this.queryCancellation = queryCancellation;
+        }
+
+        private void checkAndThrowWithSampling() {
+            if ((calls++ & MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK) == 0) {
+                queryCancellation.checkCancelled();
+            }
+        }
+
+        @Override
+        public void visit(int docID) throws IOException {
+            checkAndThrowWithSampling();
+            in.visit(docID);
+        }
+
+        @Override
+        public void visit(int docID, byte[] packedValue) throws IOException {
+            checkAndThrowWithSampling();
+            in.visit(docID, packedValue);
+        }
+
+        @Override
+        public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            queryCancellation.checkCancelled();
+            return in.compare(minPackedValue, maxPackedValue);
+        }
+
+        @Override
+        public void grow(int count) {
+            queryCancellation.checkCancelled();
+            in.grow(count);
+        }
+    }
+}

+ 38 - 43
server/src/main/java/org/elasticsearch/search/query/QueryPhase.java

@@ -45,8 +45,8 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopFieldCollector;
 import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
-import org.elasticsearch.action.search.SearchShardTask;
 import org.apache.lucene.search.Weight;
+import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.common.Booleans;
 import org.elasticsearch.common.CheckedConsumer;
 import org.elasticsearch.common.lucene.Lucene;
@@ -54,8 +54,8 @@ import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.concurrent.EWMATrackingEsThreadPoolExecutor;
 import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
 import org.elasticsearch.index.IndexSortConfig;
-import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.DateFieldMapper.DateFieldType;
+import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchPhase;
 import org.elasticsearch.search.SearchService;
@@ -254,63 +254,58 @@ public class QueryPhase implements SearchPhase {
                 final long startTime = searchContext.getRelativeTimeInMillis();
                 final long timeout = searchContext.timeout().millis();
                 final long maxTime = startTime + timeout;
-                timeoutRunnable = () -> {
+                timeoutRunnable = searcher.addQueryCancellation(() -> {
                     final long time = searchContext.getRelativeTimeInMillis();
                     if (time > maxTime) {
                         throw new TimeExceededException();
                     }
-                };
+                });
             } else {
                 timeoutRunnable = null;
             }
 
-            final Runnable cancellationRunnable;
             if (searchContext.lowLevelCancellation()) {
                 SearchShardTask task = searchContext.getTask();
-                cancellationRunnable = () -> { if (task.isCancelled()) throw new TaskCancelledException("cancelled"); };
-            } else {
-                cancellationRunnable = null;
+                searcher.addQueryCancellation(() -> {
+                    if (task.isCancelled()) {
+                        throw new TaskCancelledException("cancelled");
+                    }
+                });
             }
 
-            final Runnable checkCancelled;
-            if (timeoutRunnable != null && cancellationRunnable != null) {
-                checkCancelled = () -> {
-                    timeoutRunnable.run();
-                    cancellationRunnable.run();
-                };
-            } else if (timeoutRunnable != null) {
-                checkCancelled = timeoutRunnable;
-            } else if (cancellationRunnable != null) {
-                checkCancelled = cancellationRunnable;
-            } else {
-                checkCancelled = null;
-            }
-            searcher.setCheckCancelled(checkCancelled);
+            try {
+                boolean shouldRescore;
+                // if we are optimizing sort and there are no other collectors
+                if (sortAndFormatsForRewrittenNumericSort != null && collectors.size() == 0 && searchContext.getProfilers() == null) {
+                    shouldRescore = searchWithCollectorManager(searchContext, searcher, query, leafSorter, timeoutSet);
+                } else {
+                    shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, timeoutSet);
+                }
 
-            boolean shouldRescore;
-            // if we are optimizing sort and there are no other collectors
-            if (sortAndFormatsForRewrittenNumericSort != null && collectors.size() == 0 && searchContext.getProfilers() == null) {
-                shouldRescore = searchWithCollectorManager(searchContext, searcher, query, leafSorter, timeoutSet);
-            } else {
-                shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, timeoutSet);
-            }
+                // if we rewrote numeric long or date sort, restore fieldDocs based on the original sort
+                if (sortAndFormatsForRewrittenNumericSort != null) {
+                    searchContext.sort(sortAndFormatsForRewrittenNumericSort); // restore SortAndFormats
+                    restoreTopFieldDocs(queryResult, sortAndFormatsForRewrittenNumericSort);
+                }
 
-            // if we rewrote numeric long or date sort, restore fieldDocs based on the original sort
-            if (sortAndFormatsForRewrittenNumericSort != null) {
-                searchContext.sort(sortAndFormatsForRewrittenNumericSort); // restore SortAndFormats
-                restoreTopFieldDocs(queryResult, sortAndFormatsForRewrittenNumericSort);
-            }
+                ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH);
+                assert executor instanceof EWMATrackingEsThreadPoolExecutor ||
+                    (executor instanceof EsThreadPoolExecutor == false /* in case thread pool is mocked out in tests */) :
+                    "SEARCH threadpool should have an executor that exposes EWMA metrics, but is of type " + executor.getClass();
+                if (executor instanceof EWMATrackingEsThreadPoolExecutor) {
+                    EWMATrackingEsThreadPoolExecutor rExecutor = (EWMATrackingEsThreadPoolExecutor) executor;
+                    queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize());
+                    queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA());
+                }
 
-            ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH);
-            assert executor instanceof EWMATrackingEsThreadPoolExecutor ||
-                (executor instanceof EsThreadPoolExecutor == false /* in case thread pool is mocked out in tests */) :
-                "SEARCH threadpool should have an executor that exposes EWMA metrics, but is of type " + executor.getClass();
-            if (executor instanceof EWMATrackingEsThreadPoolExecutor) {
-                EWMATrackingEsThreadPoolExecutor rExecutor = (EWMATrackingEsThreadPoolExecutor) executor;
-                queryResult.nodeQueueSize(rExecutor.getCurrentQueueSize());
-                queryResult.serviceTimeEWMA((long) rExecutor.getTaskExecutionEWMA());
+                return shouldRescore;
+            } finally {
+                // Search phase has finished, no longer need to check for timeout
+                // otherwise aggregation phase might get cancelled.
+                if (timeoutRunnable != null) {
+                   searcher.removeQueryCancellation(timeoutRunnable);
+                }
             }
-            return shouldRescore;
         } catch (Exception e) {
             throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute main query", e);
         }

+ 100 - 14
server/src/test/java/org/elasticsearch/search/SearchCancellationTests.java

@@ -20,16 +20,22 @@ package org.elasticsearch.search;
 
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.TotalHitCountCollector;
 import org.apache.lucene.store.Directory;
-import org.elasticsearch.core.internal.io.IOUtils;
 import org.apache.lucene.util.TestUtil;
+import org.apache.lucene.util.automaton.CompiledAutomaton;
+import org.apache.lucene.util.automaton.RegExp;
+import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.search.internal.ContextIndexSearcher;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.test.ESTestCase;
@@ -43,8 +49,11 @@ import static org.hamcrest.Matchers.equalTo;
 
 public class SearchCancellationTests extends ESTestCase {
 
-    static Directory dir;
-    static IndexReader reader;
+    private static final String STRING_FIELD_NAME = "foo";
+    private static final String POINT_FIELD_NAME = "point";
+
+    private static Directory dir;
+    private static IndexReader reader;
 
     @BeforeClass
     public static void setup() throws IOException {
@@ -61,9 +70,10 @@ public class SearchCancellationTests extends ESTestCase {
     }
 
     private static void indexRandomDocuments(RandomIndexWriter w, int numDocs) throws IOException {
-        for (int i = 0; i < numDocs; ++i) {
+        for (int i = 1; i <= numDocs; ++i) {
             Document doc = new Document();
-            doc.add(new StringField("foo", "bar", Field.Store.NO));
+            doc.add(new StringField(STRING_FIELD_NAME, "a".repeat(i), Field.Store.NO));
+            doc.add(new IntPoint(POINT_FIELD_NAME, i, i + 1));
             w.addDocument(doc);
         }
     }
@@ -75,21 +85,97 @@ public class SearchCancellationTests extends ESTestCase {
         reader = null;
     }
 
+    public void testAddingCancellationActions() throws IOException {
+        ContextIndexSearcher searcher = new ContextIndexSearcher(reader,
+                IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy());
+        NullPointerException npe = expectThrows(NullPointerException.class, () -> searcher.addQueryCancellation(null));
+        assertEquals("cancellation runnable should not be null", npe.getMessage());
+
+        Runnable r = () -> {};
+        searcher.addQueryCancellation(r);
+        IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> searcher.addQueryCancellation(r));
+        assertEquals("Cancellation runnable already added", iae.getMessage());
+    }
+
     public void testCancellableCollector() throws IOException {
-        TotalHitCountCollector collector = new TotalHitCountCollector();
-        AtomicBoolean cancelled = new AtomicBoolean();
+        TotalHitCountCollector collector1 = new TotalHitCountCollector();
+        Runnable cancellation = () -> { throw new TaskCancelledException("cancelled"); };
         ContextIndexSearcher searcher = new ContextIndexSearcher(reader,
             IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy());
-        searcher.setCheckCancelled(() -> {
+
+        searcher.search(new MatchAllDocsQuery(), collector1);
+        assertThat(collector1.getTotalHits(), equalTo(reader.numDocs()));
+
+        searcher.addQueryCancellation(cancellation);
+        expectThrows(TaskCancelledException.class,
+            () -> searcher.search(new MatchAllDocsQuery(), collector1));
+
+        searcher.removeQueryCancellation(cancellation);
+        TotalHitCountCollector collector2 = new TotalHitCountCollector();
+        searcher.search(new MatchAllDocsQuery(), collector2);
+        assertThat(collector2.getTotalHits(), equalTo(reader.numDocs()));
+    }
+
+    public void testCancellableDirectoryReader() throws IOException {
+        AtomicBoolean cancelled = new AtomicBoolean(true);
+        Runnable cancellation = () -> {
             if (cancelled.get()) {
                 throw new TaskCancelledException("cancelled");
-            }
-        });
-        searcher.search(new MatchAllDocsQuery(), collector);
-        assertThat(collector.getTotalHits(), equalTo(reader.numDocs()));
-        cancelled.set(true);
+        }};
+        ContextIndexSearcher searcher = new ContextIndexSearcher(reader,
+                IndexSearcher.getDefaultSimilarity(), IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy());
+        searcher.addQueryCancellation(cancellation);
+        CompiledAutomaton automaton = new CompiledAutomaton(new RegExp("a.*").toAutomaton());
+
+        expectThrows(TaskCancelledException.class,
+                () -> searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME).iterator());
+        expectThrows(TaskCancelledException.class,
+                () -> searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME).intersect(automaton, null));
+        expectThrows(TaskCancelledException.class,
+                () -> searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME));
         expectThrows(TaskCancelledException.class,
-            () -> searcher.search(new MatchAllDocsQuery(), collector));
+                () -> searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME));
+
+        cancelled.set(false); // Avoid exception during construction of the wrapper objects
+        Terms terms = searcher.getIndexReader().leaves().get(0).reader().terms(STRING_FIELD_NAME);
+        TermsEnum termsIterator = terms.iterator();
+        TermsEnum termsIntersect = terms.intersect(automaton, null);
+        PointValues pointValues1 = searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME);
+        cancelled.set(true);
+        expectThrows(TaskCancelledException.class, termsIterator::next);
+        expectThrows(TaskCancelledException.class, termsIntersect::next);
+        expectThrows(TaskCancelledException.class, pointValues1::getDocCount);
+        expectThrows(TaskCancelledException.class, pointValues1::getNumIndexDimensions);
+        expectThrows(TaskCancelledException.class, () -> pointValues1.intersect(new PointValuesIntersectVisitor()));
+
+        cancelled.set(false); // Avoid exception during construction of the wrapper objects
+        // Re-initialize objects so that we reset the `calls` counter used to avoid cancellation check
+        // on every iteration and assure that cancellation would normally happen if we hadn't removed the
+        // cancellation runnable.
+        termsIterator = terms.iterator();
+        termsIntersect = terms.intersect(automaton, null);
+        PointValues pointValues2 = searcher.getIndexReader().leaves().get(0).reader().getPointValues(POINT_FIELD_NAME);
+        cancelled.set(true);
+        searcher.removeQueryCancellation(cancellation);
+        termsIterator.next();
+        termsIntersect.next();
+        pointValues2.getDocCount();
+        pointValues2.getNumIndexDimensions();
+        pointValues2.intersect(new PointValuesIntersectVisitor());
     }
 
+    private static class PointValuesIntersectVisitor implements PointValues.IntersectVisitor {
+        @Override
+        public void visit(int docID) {
+        }
+
+        @Override
+        public void visit(int docID, byte[] packedValue) {
+        }
+
+        @Override
+        public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            return PointValues.Relation.CELL_CROSSES_QUERY;
+        }
+    }
 }

+ 19 - 1
server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java

@@ -22,6 +22,7 @@ package org.elasticsearch.search.internal;
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.IntPoint;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.FilterDirectoryReader;
@@ -76,6 +77,9 @@ import java.util.IdentityHashMap;
 import java.util.Set;
 
 import static org.elasticsearch.search.internal.ContextIndexSearcher.intersectScorerAndBitSet;
+import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableLeafReader;
+import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitablePointValues;
+import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableTerms;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 
@@ -191,6 +195,8 @@ public class ContextIndexSearcherTests extends ESTestCase {
         doc.add(fooField);
         StringField deleteField = new StringField("delete", "no", Field.Store.NO);
         doc.add(deleteField);
+        IntPoint pointField = new IntPoint("point", 1, 2);
+        doc.add(pointField);
         w.addDocument(doc);
         if (deletions) {
             // add a document that matches foo:bar but will be deleted
@@ -235,7 +241,19 @@ public class ContextIndexSearcherTests extends ESTestCase {
 
         ContextIndexSearcher searcher = new ContextIndexSearcher(filteredReader, IndexSearcher.getDefaultSimilarity(),
             IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy());
-        searcher.setCheckCancelled(() -> {});
+
+        // Assert wrapping
+        assertEquals(ExitableDirectoryReader.class, searcher.getIndexReader().getClass());
+        for (LeafReaderContext lrc : searcher.getIndexReader().leaves()) {
+            assertEquals(ExitableLeafReader.class, lrc.reader().getClass());
+            assertNotEquals(ExitableTerms.class, lrc.reader().terms("foo").getClass());
+            assertNotEquals(ExitablePointValues.class, lrc.reader().getPointValues("point").getClass());
+        }
+        searcher.addQueryCancellation(() -> {});
+        for (LeafReaderContext lrc : searcher.getIndexReader().leaves()) {
+            assertEquals(ExitableTerms.class, lrc.reader().terms("foo").getClass());
+            assertEquals(ExitablePointValues.class, lrc.reader().getPointValues("point").getClass());
+        }
 
         // Searching a non-existing term will trigger a null scorer
         assertEquals(0, searcher.count(new TermQuery(new Term("non_existing_field", "non_existing_value"))));

+ 3 - 3
server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java

@@ -828,12 +828,12 @@ public class QueryPhaseTests extends IndexShardTestCase {
 
     }
 
-    private static ContextIndexSearcher newContextSearcher(IndexReader reader) {
+    private static ContextIndexSearcher newContextSearcher(IndexReader reader) throws IOException {
         return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(),
             IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy());
     }
 
-    private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size) {
+    private static ContextIndexSearcher newEarlyTerminationContextSearcher(IndexReader reader, int size) throws IOException {
         return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(),
             IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()) {
 
@@ -846,7 +846,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
     }
 
     // used to check that numeric long or date sort optimization was run
-    private static ContextIndexSearcher newOptimizedContextSearcher(IndexReader reader, int queryType) {
+    private static ContextIndexSearcher newOptimizedContextSearcher(IndexReader reader, int queryType) throws IOException {
         return new ContextIndexSearcher(reader, IndexSearcher.getDefaultSimilarity(),
             IndexSearcher.getDefaultQueryCache(), IndexSearcher.getDefaultQueryCachingPolicy()) {
 

+ 3 - 3
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -222,7 +222,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                 IndexSettings indexSettings,
                                                 Query query,
                                                 MultiBucketConsumer bucketConsumer,
-                                                MappedFieldType... fieldTypes) {
+                                                MappedFieldType... fieldTypes) throws IOException {
         return createSearchContext(indexSearcher, indexSettings, query, bucketConsumer, new NoneCircuitBreakerService(), fieldTypes);
     }
 
@@ -231,7 +231,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
                                                 Query query,
                                                 MultiBucketConsumer bucketConsumer,
                                                 CircuitBreakerService circuitBreakerService,
-                                                MappedFieldType... fieldTypes) {
+                                                MappedFieldType... fieldTypes) throws IOException {
         QueryCache queryCache = new DisabledQueryCache(indexSettings);
         QueryCachingPolicy queryCachingPolicy = new QueryCachingPolicy() {
             @Override
@@ -245,7 +245,7 @@ public abstract class AggregatorTestCase extends ESTestCase {
             }
         };
         ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(indexSearcher.getIndexReader(),
-            indexSearcher.getSimilarity(), queryCache, queryCachingPolicy);
+            indexSearcher.getSimilarity(), queryCache, queryCachingPolicy, null);
 
         SearchContext searchContext = mock(SearchContext.class);
         when(searchContext.numberOfShards()).thenReturn(1);