Przeglądaj źródła

Fix handling of time exceeded exception in fetch phase (#116676)

The fetch phase is subject to timeouts like any other search phase. Timeouts
may happen when low level cancellation is enabled (true by default), hence the
directory reader is wrapped into ExitableDirectoryReader and a timeout is
provided to the search request.

The exception that is used is TimeExceededException, but it is an internal
exception that should never be returned to the user. When that is thrown, we
need to catch it and throw error or mark the response as timed out depending
on whether partial results are allowed or not.
Luca Cavanna 11 miesięcy temu
rodzic
commit
e4f4c95442

+ 5 - 0
docs/changelog/116676.yaml

@@ -0,0 +1,5 @@
+pr: 116676
+summary: Fix handling of time exceeded exception in fetch phase
+area: Search
+type: bug
+issues: []

+ 10 - 1
server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java

@@ -191,7 +191,16 @@ public final class FetchPhase {
             }
         };
 
-        SearchHit[] hits = docsIterator.iterate(context.shardTarget(), context.searcher().getIndexReader(), docIdsToLoad);
+        SearchHit[] hits = docsIterator.iterate(
+            context.shardTarget(),
+            context.searcher().getIndexReader(),
+            docIdsToLoad,
+            context.request().allowPartialSearchResults()
+        );
+
+        if (docsIterator.isTimedOut()) {
+            context.queryResult().searchTimedOut(true);
+        }
 
         if (context.isCancelled()) {
             for (SearchHit hit : hits) {

+ 50 - 16
server/src/main/java/org/elasticsearch/search/fetch/FetchPhaseDocsIterator.java

@@ -13,7 +13,10 @@ import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.ReaderUtil;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.query.SearchTimeoutException;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -27,6 +30,12 @@ import java.util.Arrays;
  */
 abstract class FetchPhaseDocsIterator {
 
+    private boolean timedOut = false;
+
+    public boolean isTimedOut() {
+        return timedOut;
+    }
+
     /**
      * Called when a new leaf reader is reached
      * @param ctx           the leaf reader for this set of doc ids
@@ -44,7 +53,7 @@ abstract class FetchPhaseDocsIterator {
     /**
      * Iterate over a set of docsIds within a particular shard and index reader
      */
-    public final SearchHit[] iterate(SearchShardTarget shardTarget, IndexReader indexReader, int[] docIds) {
+    public final SearchHit[] iterate(SearchShardTarget shardTarget, IndexReader indexReader, int[] docIds, boolean allowPartialResults) {
         SearchHit[] searchHits = new SearchHit[docIds.length];
         DocIdToIndex[] docs = new DocIdToIndex[docIds.length];
         for (int index = 0; index < docIds.length; index++) {
@@ -58,30 +67,55 @@ abstract class FetchPhaseDocsIterator {
             LeafReaderContext ctx = indexReader.leaves().get(leafOrd);
             int endReaderIdx = endReaderIdx(ctx, 0, docs);
             int[] docsInLeaf = docIdsInLeaf(0, endReaderIdx, docs, ctx.docBase);
-            setNextReader(ctx, docsInLeaf);
-            for (int i = 0; i < docs.length; i++) {
-                if (i >= endReaderIdx) {
-                    leafOrd = ReaderUtil.subIndex(docs[i].docId, indexReader.leaves());
-                    ctx = indexReader.leaves().get(leafOrd);
-                    endReaderIdx = endReaderIdx(ctx, i, docs);
-                    docsInLeaf = docIdsInLeaf(i, endReaderIdx, docs, ctx.docBase);
-                    setNextReader(ctx, docsInLeaf);
+            try {
+                setNextReader(ctx, docsInLeaf);
+            } catch (ContextIndexSearcher.TimeExceededException timeExceededException) {
+                if (allowPartialResults) {
+                    timedOut = true;
+                    return SearchHits.EMPTY;
                 }
-                currentDoc = docs[i].docId;
-                assert searchHits[docs[i].index] == null;
-                searchHits[docs[i].index] = nextDoc(docs[i].docId);
+                throw new SearchTimeoutException(shardTarget, "Time exceeded");
             }
-        } catch (Exception e) {
-            for (SearchHit searchHit : searchHits) {
-                if (searchHit != null) {
-                    searchHit.decRef();
+            for (int i = 0; i < docs.length; i++) {
+                try {
+                    if (i >= endReaderIdx) {
+                        leafOrd = ReaderUtil.subIndex(docs[i].docId, indexReader.leaves());
+                        ctx = indexReader.leaves().get(leafOrd);
+                        endReaderIdx = endReaderIdx(ctx, i, docs);
+                        docsInLeaf = docIdsInLeaf(i, endReaderIdx, docs, ctx.docBase);
+                        setNextReader(ctx, docsInLeaf);
+                    }
+                    currentDoc = docs[i].docId;
+                    assert searchHits[docs[i].index] == null;
+                    searchHits[docs[i].index] = nextDoc(docs[i].docId);
+                } catch (ContextIndexSearcher.TimeExceededException timeExceededException) {
+                    if (allowPartialResults) {
+                        timedOut = true;
+                        SearchHit[] partialSearchHits = new SearchHit[i];
+                        System.arraycopy(searchHits, 0, partialSearchHits, 0, i);
+                        return partialSearchHits;
+                    }
+                    purgeSearchHits(searchHits);
+                    throw new SearchTimeoutException(shardTarget, "Time exceeded");
                 }
             }
+        } catch (SearchTimeoutException e) {
+            throw e;
+        } catch (Exception e) {
+            purgeSearchHits(searchHits);
             throw new FetchPhaseExecutionException(shardTarget, "Error running fetch phase for doc [" + currentDoc + "]", e);
         }
         return searchHits;
     }
 
+    private static void purgeSearchHits(SearchHit[] searchHits) {
+        for (SearchHit searchHit : searchHits) {
+            if (searchHit != null) {
+                searchHit.decRef();
+            }
+        }
+    }
+
     private static int endReaderIdx(LeafReaderContext currentReaderContext, int index, DocIdToIndex[] docs) {
         int firstInNextReader = currentReaderContext.docBase + currentReaderContext.reader().maxDoc();
         int i = index + 1;

+ 185 - 0
server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java

@@ -8,35 +8,65 @@
  */
 package org.elasticsearch.action.search;
 
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryCachingPolicy;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.store.MockDirectoryWrapper;
+import org.apache.lucene.util.Accountable;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.NoopCircuitBreaker;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.index.IndexSettings;
+import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
+import org.elasticsearch.index.mapper.IdLoader;
+import org.elasticsearch.index.mapper.MapperMetrics;
+import org.elasticsearch.index.mapper.MappingLookup;
+import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.fetch.FetchSearchResult;
+import org.elasticsearch.search.fetch.FetchSubPhase;
+import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
 import org.elasticsearch.search.fetch.QueryFetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
+import org.elasticsearch.search.fetch.StoredFieldsSpec;
+import org.elasticsearch.search.internal.AliasFilter;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.profile.ProfileResult;
 import org.elasticsearch.search.profile.SearchProfileQueryPhaseResult;
 import org.elasticsearch.search.profile.SearchProfileShardResult;
 import org.elasticsearch.search.query.QuerySearchResult;
+import org.elasticsearch.search.query.SearchTimeoutException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.InternalAggregationTestCase;
+import org.elasticsearch.test.TestSearchContext;
 import org.elasticsearch.transport.Transport;
 
+import java.io.IOException;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
@@ -747,4 +777,159 @@ public class FetchSearchPhaseTests extends ESTestCase {
     private static ProfileResult fetchProfile(boolean profiled) {
         return profiled ? new ProfileResult("fetch", "fetch", Map.of(), Map.of(), FETCH_PROFILE_TIME, List.of()) : null;
     }
+
+    public void testFetchTimeoutWithPartialResults() throws IOException {
+        Directory dir = newDirectory();
+        RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+        w.addDocument(new Document());
+        w.addDocument(new Document());
+        w.addDocument(new Document());
+        IndexReader r = w.getReader();
+        w.close();
+        ContextIndexSearcher contextIndexSearcher = createSearcher(r);
+        try (SearchContext searchContext = createSearchContext(contextIndexSearcher, true)) {
+            FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
+            fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null);
+            assertTrue(searchContext.queryResult().searchTimedOut());
+            assertEquals(1, searchContext.fetchResult().hits().getHits().length);
+        } finally {
+            r.close();
+            dir.close();
+        }
+    }
+
+    public void testFetchTimeoutNoPartialResults() throws IOException {
+        Directory dir = newDirectory();
+        RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+        w.addDocument(new Document());
+        w.addDocument(new Document());
+        w.addDocument(new Document());
+        IndexReader r = w.getReader();
+        w.close();
+        ContextIndexSearcher contextIndexSearcher = createSearcher(r);
+
+        try (SearchContext searchContext = createSearchContext(contextIndexSearcher, false)) {
+            FetchPhase fetchPhase = createFetchPhase(contextIndexSearcher);
+            expectThrows(SearchTimeoutException.class, () -> fetchPhase.execute(searchContext, new int[] { 0, 1, 2 }, null));
+            assertNull(searchContext.fetchResult().hits());
+        } finally {
+            r.close();
+            dir.close();
+        }
+    }
+
+    private static ContextIndexSearcher createSearcher(IndexReader reader) throws IOException {
+        return new ContextIndexSearcher(reader, null, null, new QueryCachingPolicy() {
+            @Override
+            public void onUse(Query query) {}
+
+            @Override
+            public boolean shouldCache(Query query) {
+                return false;
+            }
+        }, randomBoolean());
+    }
+
+    private static FetchPhase createFetchPhase(ContextIndexSearcher contextIndexSearcher) {
+        return new FetchPhase(Collections.singletonList(fetchContext -> new FetchSubPhaseProcessor() {
+            boolean processCalledOnce = false;
+
+            @Override
+            public void setNextReader(LeafReaderContext readerContext) {}
+
+            @Override
+            public void process(FetchSubPhase.HitContext hitContext) {
+                // we throw only once one doc has been fetched, so we can test partial results are returned
+                if (processCalledOnce) {
+                    contextIndexSearcher.throwTimeExceededException();
+                } else {
+                    processCalledOnce = true;
+                }
+            }
+
+            @Override
+            public StoredFieldsSpec storedFieldsSpec() {
+                return StoredFieldsSpec.NO_REQUIREMENTS;
+            }
+        }));
+    }
+
+    private static SearchContext createSearchContext(ContextIndexSearcher contextIndexSearcher, boolean allowPartialResults) {
+        IndexSettings indexSettings = new IndexSettings(
+            IndexMetadata.builder("index")
+                .settings(Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()))
+                .numberOfShards(1)
+                .numberOfReplicas(0)
+                .creationDate(System.currentTimeMillis())
+                .build(),
+            Settings.EMPTY
+        );
+        BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() {
+            @Override
+            public void onCache(ShardId shardId, Accountable accountable) {
+
+            }
+
+            @Override
+            public void onRemoval(ShardId shardId, Accountable accountable) {
+
+            }
+        });
+
+        SearchExecutionContext searchExecutionContext = new SearchExecutionContext(
+            0,
+            0,
+            indexSettings,
+            bitsetFilterCache,
+            null,
+            null,
+            MappingLookup.EMPTY,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            Collections.emptyMap(),
+            null,
+            MapperMetrics.NOOP
+        );
+        TestSearchContext searchContext = new TestSearchContext(searchExecutionContext, null, contextIndexSearcher) {
+            private final FetchSearchResult fetchSearchResult = new FetchSearchResult();
+            private final ShardSearchRequest request = new ShardSearchRequest(
+                OriginalIndices.NONE,
+                new SearchRequest().allowPartialSearchResults(allowPartialResults),
+                new ShardId("index", "indexUUID", 0),
+                0,
+                1,
+                AliasFilter.EMPTY,
+                1f,
+                0L,
+                null
+            );
+
+            @Override
+            public IdLoader newIdLoader() {
+                return new IdLoader.StoredIdLoader();
+            }
+
+            @Override
+            public FetchSearchResult fetchResult() {
+                return fetchSearchResult;
+            }
+
+            @Override
+            public ShardSearchRequest request() {
+                return request;
+            }
+        };
+        searchContext.addReleasable(searchContext.fetchResult()::decRef);
+        searchContext.setTask(new SearchShardTask(-1, "type", "action", "description", null, Collections.emptyMap()));
+        return searchContext;
+    }
 }

+ 2 - 2
server/src/test/java/org/elasticsearch/search/fetch/FetchPhaseDocsIteratorTests.java

@@ -77,7 +77,7 @@ public class FetchPhaseDocsIteratorTests extends ESTestCase {
             }
         };
 
-        SearchHit[] hits = it.iterate(null, reader, docs);
+        SearchHit[] hits = it.iterate(null, reader, docs, randomBoolean());
 
         assertThat(hits.length, equalTo(docs.length));
         for (int i = 0; i < hits.length; i++) {
@@ -125,7 +125,7 @@ public class FetchPhaseDocsIteratorTests extends ESTestCase {
             }
         };
 
-        Exception e = expectThrows(FetchPhaseExecutionException.class, () -> it.iterate(null, reader, docs));
+        Exception e = expectThrows(FetchPhaseExecutionException.class, () -> it.iterate(null, reader, docs, randomBoolean()));
         assertThat(e.getMessage(), containsString("Error running fetch phase for doc [" + badDoc + "]"));
         assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
 

+ 16 - 1
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -52,6 +52,8 @@ import org.apache.lucene.util.NumericUtils;
 import org.apache.lucene.util.packed.PackedInts;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.TriConsumer;
@@ -141,8 +143,10 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceType;
 import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.fetch.subphase.FetchDocValuesPhase;
 import org.elasticsearch.search.fetch.subphase.FetchSourcePhase;
+import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.ContextIndexSearcher;
 import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.internal.SubSearchContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.InternalAggregationTestCase;
@@ -463,7 +467,18 @@ public abstract class AggregatorTestCase extends ESTestCase {
             .when(subContext)
             .getNestedDocuments();
         when(ctx.getSearchExecutionContext()).thenReturn(subContext);
-
+        ShardSearchRequest request = new ShardSearchRequest(
+            OriginalIndices.NONE,
+            new SearchRequest().allowPartialSearchResults(randomBoolean()),
+            new ShardId("index", "indexUUID", 0),
+            0,
+            1,
+            AliasFilter.EMPTY,
+            1f,
+            0L,
+            null
+        );
+        when(ctx.request()).thenReturn(request);
         IndexShard indexShard = mock(IndexShard.class);
         when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
         when(indexShard.indexSettings()).thenReturn(indexSettings);