Selaa lähdekoodia

ContextIndexSearcher#search should return only when all threads are finished (#95909)

Ignacio Vera 2 vuotta sitten
vanhempi
commit
c4158d78db

+ 154 - 2
server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java

@@ -16,7 +16,9 @@ import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.ConjunctionUtils;
+import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
@@ -33,6 +35,7 @@ import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.BitSetIterator;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.SparseFixedBitSet;
+import org.apache.lucene.util.ThreadInterruptedException;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.lucene.util.CombinedBitSet;
 import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -43,11 +46,19 @@ import org.elasticsearch.search.profile.query.QueryProfiler;
 import org.elasticsearch.search.profile.query.QueryTimingType;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ThreadPoolExecutor;
 
 /**
  * Context-aware extension of {@link IndexSearcher}.
@@ -63,6 +74,10 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
     private QueryProfiler profiler;
     private final MutableQueryTimeout cancellable;
 
+    private final QueueSizeBasedExecutor queueSizeBasedExecutor;
+    private final LeafSlice[] leafSlices;
+
+    /** constructor for non-concurrent search */
     public ContextIndexSearcher(
         IndexReader reader,
         Similarity similarity,
@@ -70,7 +85,19 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
         QueryCachingPolicy queryCachingPolicy,
         boolean wrapWithExitableDirectoryReader
     ) throws IOException {
-        this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout(), wrapWithExitableDirectoryReader);
+        this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout(), wrapWithExitableDirectoryReader, null);
+    }
+
+    /** constructor for concurrent search */
+    public ContextIndexSearcher(
+        IndexReader reader,
+        Similarity similarity,
+        QueryCache queryCache,
+        QueryCachingPolicy queryCachingPolicy,
+        boolean wrapWithExitableDirectoryReader,
+        ThreadPoolExecutor executor
+    ) throws IOException {
+        this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout(), wrapWithExitableDirectoryReader, executor);
     }
 
     private ContextIndexSearcher(
@@ -79,13 +106,17 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
         QueryCache queryCache,
         QueryCachingPolicy queryCachingPolicy,
         MutableQueryTimeout cancellable,
-        boolean wrapWithExitableDirectoryReader
+        boolean wrapWithExitableDirectoryReader,
+        ThreadPoolExecutor executor
     ) throws IOException {
+        // concurrency is handle in this class so don't pass the executor to the parent class
         super(wrapWithExitableDirectoryReader ? new ExitableDirectoryReader((DirectoryReader) reader, cancellable) : reader);
         setSimilarity(similarity);
         setQueryCache(queryCache);
         setQueryCachingPolicy(queryCachingPolicy);
         this.cancellable = cancellable;
+        this.queueSizeBasedExecutor = executor != null ? new QueueSizeBasedExecutor(executor) : null;
+        this.leafSlices = executor == null ? null : slices(leafContexts);
     }
 
     public void setProfiler(QueryProfiler profiler) {
@@ -162,6 +193,79 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
         }
     }
 
+    @Override
+    public <C extends Collector, T> T search(Query query, CollectorManager<C, T> collectorManager) throws IOException {
+        final C firstCollector = collectorManager.newCollector();
+        // Take advantage of the few extra rewrite rules of ConstantScoreQuery when score are not needed.
+        query = firstCollector.scoreMode().needsScores() ? rewrite(query) : rewrite(new ConstantScoreQuery(query));
+        final Weight weight = createWeight(query, firstCollector.scoreMode(), 1);
+        return search(weight, collectorManager, firstCollector);
+    }
+
+    /**
+     * Similar to the lucene implementation but it will wait for all threads to fisinsh before returning even if an error is thrown.
+     * In that case, other exceptions will be ignored and the first exception is thrown after all threads are finished.
+     * */
+    private <C extends Collector, T> T search(Weight weight, CollectorManager<C, T> collectorManager, C firstCollector) throws IOException {
+        if (queueSizeBasedExecutor == null || leafSlices.length <= 1) {
+            search(leafContexts, weight, firstCollector);
+            return collectorManager.reduce(Collections.singletonList(firstCollector));
+        } else {
+            final List<C> collectors = new ArrayList<>(leafSlices.length);
+            collectors.add(firstCollector);
+            final ScoreMode scoreMode = firstCollector.scoreMode();
+            for (int i = 1; i < leafSlices.length; ++i) {
+                final C collector = collectorManager.newCollector();
+                collectors.add(collector);
+                if (scoreMode != collector.scoreMode()) {
+                    throw new IllegalStateException("CollectorManager does not always produce collectors with the same score mode");
+                }
+            }
+            final List<FutureTask<C>> listTasks = new ArrayList<>();
+            for (int i = 0; i < leafSlices.length; ++i) {
+                final LeafReaderContext[] leaves = leafSlices[i].leaves;
+                final C collector = collectors.get(i);
+                FutureTask<C> task = new FutureTask<>(() -> {
+                    search(Arrays.asList(leaves), weight, collector);
+                    return collector;
+                });
+
+                listTasks.add(task);
+            }
+
+            queueSizeBasedExecutor.invokeAll(listTasks);
+            RuntimeException exception = null;
+            final List<C> collectedCollectors = new ArrayList<>();
+            for (Future<C> future : listTasks) {
+                try {
+                    collectedCollectors.add(future.get());
+                    // TODO: when there is an exception and we don't want partial results, it would be great
+                    // to cancel the queries / threads
+                } catch (InterruptedException e) {
+                    if (exception == null) {
+                        exception = new ThreadInterruptedException(e);
+                    } else {
+                        // we ignore further exceptions
+                    }
+                } catch (ExecutionException e) {
+                    if (exception == null) {
+                        if (e.getCause() instanceof RuntimeException runtimeException) {
+                            exception = runtimeException;
+                        } else {
+                            exception = new RuntimeException(e.getCause());
+                        }
+                    } else {
+                        // we ignore further exceptions
+                    }
+                }
+            }
+            if (exception != null) {
+                throw exception;
+            }
+            return collectorManager.reduce(collectedCollectors);
+        }
+    }
+
     @Override
     public void search(List<LeafReaderContext> leaves, Weight weight, Collector collector) throws IOException {
         weight = wrapWeight(weight);
@@ -354,4 +458,52 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
             runnables.clear();
         }
     }
+
+    private static class QueueSizeBasedExecutor {
+        private static final double LIMITING_FACTOR = 1.5;
+
+        private final ThreadPoolExecutor threadPoolExecutor;
+
+        QueueSizeBasedExecutor(ThreadPoolExecutor threadPoolExecutor) {
+            this.threadPoolExecutor = threadPoolExecutor;
+        }
+
+        public void invokeAll(Collection<? extends Runnable> tasks) {
+            int i = 0;
+
+            for (Runnable task : tasks) {
+                boolean shouldExecuteOnCallerThread = false;
+
+                // Execute last task on caller thread
+                if (i == tasks.size() - 1) {
+                    shouldExecuteOnCallerThread = true;
+                }
+
+                if (threadPoolExecutor.getQueue().size() >= (threadPoolExecutor.getMaximumPoolSize() * LIMITING_FACTOR)) {
+                    shouldExecuteOnCallerThread = true;
+                }
+
+                processTask(task, shouldExecuteOnCallerThread);
+
+                ++i;
+            }
+        }
+
+        protected void processTask(final Runnable task, final boolean shouldExecuteOnCallerThread) {
+            if (task == null) {
+                throw new IllegalArgumentException("Input is null");
+            }
+
+            if (shouldExecuteOnCallerThread == false) {
+                try {
+                    threadPoolExecutor.execute(task);
+
+                    return;
+                } catch (@SuppressWarnings("unused") RejectedExecutionException e) {
+                    // Execute on caller thread
+                }
+            }
+            task.run();
+        }
+    }
 }

+ 103 - 0
server/src/test/java/org/elasticsearch/search/internal/ContextIndexSearcherTests.java

@@ -29,6 +29,8 @@ import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
 import org.apache.lucene.search.BoostQuery;
 import org.apache.lucene.search.BulkScorer;
+import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
 import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Explanation;
@@ -65,9 +67,14 @@ import org.elasticsearch.test.IndexSettingsModule;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.IdentityHashMap;
+import java.util.List;
 import java.util.Set;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.elasticsearch.search.internal.ContextIndexSearcher.intersectScorerAndBitSet;
 import static org.elasticsearch.search.internal.ExitableDirectoryReader.ExitableLeafReader;
@@ -162,6 +169,102 @@ public class ContextIndexSearcherTests extends ESTestCase {
         directory.close();
     }
 
+    public void testConcurrentSearchAllThreadsFinish() throws Exception {
+        final Directory directory = newDirectory();
+        IndexWriter iw = new IndexWriter(directory, new IndexWriterConfig(new StandardAnalyzer()).setMergePolicy(NoMergePolicy.INSTANCE));
+        final int numDocs = randomIntBetween(100, 200);
+        for (int i = 0; i < numDocs; i++) {
+            Document document = new Document();
+            document.add(new StringField("field", "value", Field.Store.NO));
+            iw.addDocument(document);
+            if (rarely()) {
+                iw.commit();
+            }
+        }
+
+        iw.close();
+        DirectoryReader directoryReader = DirectoryReader.open(directory);
+        ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(4);
+
+        AtomicInteger missingDocs = new AtomicInteger();
+        AtomicInteger visitDocs = new AtomicInteger(0);
+
+        CollectorManager<Collector, Void> collectorManager = new CollectorManager<>() {
+            boolean first = true;
+
+            @Override
+            public Collector newCollector() {
+                if (first) {
+                    first = false;
+                    return new Collector() {
+                        @Override
+                        public LeafCollector getLeafCollector(LeafReaderContext context) {
+                            missingDocs.set(context.reader().numDocs());
+                            throw new IllegalArgumentException("fake exception");
+                        }
+
+                        @Override
+                        public ScoreMode scoreMode() {
+                            return ScoreMode.COMPLETE;
+                        }
+                    };
+                } else {
+                    return new Collector() {
+                        @Override
+                        public LeafCollector getLeafCollector(LeafReaderContext context) {
+                            return new LeafBucketCollector() {
+                                @Override
+                                public void collect(int doc, long owningBucketOrd) {
+                                    while (true) {
+                                        int current = visitDocs.get();
+                                        if (visitDocs.compareAndSet(current, current + 1)) {
+                                            break;
+                                        }
+                                    }
+                                }
+                            };
+                        }
+
+                        @Override
+                        public ScoreMode scoreMode() {
+                            return ScoreMode.COMPLETE;
+                        }
+                    };
+                }
+            }
+
+            @Override
+            public Void reduce(Collection<Collector> collectors) {
+                return null;
+            }
+        };
+
+        ContextIndexSearcher searcher = new ContextIndexSearcher(
+            directoryReader,
+            IndexSearcher.getDefaultSimilarity(),
+            IndexSearcher.getDefaultQueryCache(),
+            IndexSearcher.getDefaultQueryCachingPolicy(),
+            randomBoolean(),
+            executor
+        ) {
+            @Override
+            protected LeafSlice[] slices(List<LeafReaderContext> leaves) {
+                return slices(leaves, 1, 1);
+            }
+        };
+
+        IllegalArgumentException exception = expectThrows(
+            IllegalArgumentException.class,
+            () -> searcher.search(new MatchAllDocsQuery(), collectorManager)
+        );
+        assertThat(exception.getMessage(), equalTo("fake exception"));
+
+        assertThat(visitDocs.get() + missingDocs.get(), equalTo(numDocs));
+        directoryReader.close();
+        directory.close();
+        executor.shutdown();
+    }
+
     public void testContextIndexSearcherSparseNoDeletions() throws IOException {
         doTestContextIndexSearcher(true, false);
     }