|
@@ -14,12 +14,12 @@ import org.apache.lucene.search.FieldDoc;
|
|
|
import org.apache.lucene.search.LeafCollector;
|
|
|
import org.apache.lucene.search.Query;
|
|
|
import org.apache.lucene.search.ScoreDoc;
|
|
|
-import org.apache.lucene.search.ScoreMode;
|
|
|
import org.apache.lucene.search.Sort;
|
|
|
import org.apache.lucene.search.SortField;
|
|
|
import org.apache.lucene.search.TopDocsCollector;
|
|
|
import org.apache.lucene.search.TopFieldCollectorManager;
|
|
|
import org.apache.lucene.search.TopScoreDocCollectorManager;
|
|
|
+import org.apache.lucene.search.Weight;
|
|
|
import org.elasticsearch.common.Strings;
|
|
|
import org.elasticsearch.compute.data.BlockFactory;
|
|
|
import org.elasticsearch.compute.data.DocBlock;
|
|
@@ -36,6 +36,7 @@ import org.elasticsearch.search.sort.SortAndFormats;
|
|
|
import org.elasticsearch.search.sort.SortBuilder;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.io.UncheckedIOException;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.List;
|
|
@@ -43,9 +44,6 @@ import java.util.Optional;
|
|
|
import java.util.function.Function;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
|
|
|
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS_WITH_SCORES;
|
|
|
-
|
|
|
/**
|
|
|
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
|
|
|
*/
|
|
@@ -63,16 +61,16 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
int maxPageSize,
|
|
|
int limit,
|
|
|
List<SortBuilder<?>> sorts,
|
|
|
- boolean scoring
|
|
|
+ boolean needsScore
|
|
|
) {
|
|
|
- super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? TOP_DOCS_WITH_SCORES : TOP_DOCS);
|
|
|
+ super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
|
|
|
this.maxPageSize = maxPageSize;
|
|
|
this.sorts = sorts;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public SourceOperator get(DriverContext driverContext) {
|
|
|
- return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
|
|
|
+ return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore); // new operator per call: uses the driver's block factory plus this factory's page size, sorts, limit, slice queue and score flag
|
|
|
}
|
|
|
|
|
|
public int maxPageSize() {
|
|
@@ -88,8 +86,8 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
+ maxPageSize
|
|
|
+ ", limit = "
|
|
|
+ limit
|
|
|
- + ", scoreMode = "
|
|
|
- + scoreMode
|
|
|
+ + ", needsScore = "
|
|
|
+ + needsScore
|
|
|
+ ", sorts = ["
|
|
|
+ notPrettySorts
|
|
|
+ "]]";
|
|
@@ -108,7 +106,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
private PerShardCollector perShardCollector;
|
|
|
private final List<SortBuilder<?>> sorts;
|
|
|
private final int limit;
|
|
|
- private final ScoreMode scoreMode;
|
|
|
+ private final boolean needsScore;
|
|
|
|
|
|
public LuceneTopNSourceOperator(
|
|
|
BlockFactory blockFactory,
|
|
@@ -116,12 +114,12 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
List<SortBuilder<?>> sorts,
|
|
|
int limit,
|
|
|
LuceneSliceQueue sliceQueue,
|
|
|
- ScoreMode scoreMode
|
|
|
+ boolean needsScore
|
|
|
) {
|
|
|
super(blockFactory, maxPageSize, sliceQueue);
|
|
|
this.sorts = sorts;
|
|
|
this.limit = limit;
|
|
|
- this.scoreMode = scoreMode;
|
|
|
+ this.needsScore = needsScore;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -163,7 +161,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
try {
|
|
|
if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
|
|
|
// TODO: share the bottom between shardCollectors
|
|
|
- perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
|
|
|
+ perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
|
|
|
}
|
|
|
var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
|
|
|
scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
|
|
@@ -261,7 +259,7 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
}
|
|
|
|
|
|
private DoubleVector.Builder scoreVectorOrNull(int size) {
|
|
|
- if (scoreMode.needsScores()) {
|
|
|
+ if (needsScore) {
|
|
|
return blockFactory.newDoubleVectorFixedBuilder(size);
|
|
|
} else {
|
|
|
return null;
|
|
@@ -271,37 +269,11 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
@Override
|
|
|
protected void describe(StringBuilder sb) {
|
|
|
sb.append(", limit = ").append(limit);
|
|
|
- sb.append(", scoreMode = ").append(scoreMode);
|
|
|
+ sb.append(", needsScore = ").append(needsScore); // prints the boolean flag; the same flag decides whether scoreVectorOrNull allocates a score vector
|
|
|
String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
|
|
|
sb.append(", sorts = [").append(notPrettySorts).append("]");
|
|
|
}
|
|
|
|
|
|
- PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
|
|
|
- Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
|
|
|
- if (sortAndFormats.isEmpty()) {
|
|
|
- throw new IllegalStateException("sorts must not be disabled in TopN");
|
|
|
- }
|
|
|
- if (scoreMode.needsScores() == false) {
|
|
|
- return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
|
|
|
- } else {
|
|
|
- SortField[] sortFields = sortAndFormats.get().sort.getSort();
|
|
|
- if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
|
|
|
- // SORT _score DESC
|
|
|
- return new ScoringPerShardCollector(shardContext, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
|
|
|
- } else {
|
|
|
- // SORT ..., _score, ...
|
|
|
- var sort = new Sort();
|
|
|
- if (sortFields != null) {
|
|
|
- var l = new ArrayList<>(Arrays.asList(sortFields));
|
|
|
- l.add(SortField.FIELD_DOC);
|
|
|
- l.add(SortField.FIELD_SCORE);
|
|
|
- sort = new Sort(l.toArray(SortField[]::new));
|
|
|
- }
|
|
|
- return new ScoringPerShardCollector(shardContext, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
abstract static class PerShardCollector {
|
|
|
private final ShardContext shardContext;
|
|
|
private final TopDocsCollector<?> collector;
|
|
@@ -336,4 +308,45 @@ public final class LuceneTopNSourceOperator extends LuceneOperator {
|
|
|
super(shardContext, topDocsCollector);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private static Function<ShardContext, Weight> weightFunction( // adapts a per-shard Query function into a per-shard Weight function for TopN
|
|
|
+ Function<ShardContext, Query> queryFunction, // produces the (not yet rewritten) query for a shard
|
|
|
+ List<SortBuilder<?>> sorts, // sort spec; only used here to pick the score mode
|
|
|
+ boolean needsScore // whether the operator must also produce scores
|
|
|
+ ) {
|
|
|
+ return ctx -> { // evaluated each time the returned function is applied to a shard context
|
|
|
+ final var query = queryFunction.apply(ctx);
|
|
|
+ final var searcher = ctx.searcher();
|
|
|
+ try {
|
|
|
+ // We create a throwaway collector with a limit of 1 to determine the appropriate score mode to use:
+ // the score mode is read from the same kind of collector that newPerShardCollector builds for real collection.
|
|
|
+ var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode(); // only scoreMode() is read; the collector itself is discarded
|
|
|
+ return searcher.createWeight(searcher.rewrite(query), scoreMode, 1); // rewrite first, then build the weight; NOTE(review): trailing 1 is presumably the boost — confirm against IndexSearcher#createWeight
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new UncheckedIOException(e); // Function#apply cannot throw checked exceptions, so wrap and keep the cause
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit) // picks the per-shard top-docs collector from the sort spec and the score flag
|
|
|
+ throws IOException {
|
|
|
+ Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
|
|
|
+ if (sortAndFormats.isEmpty()) { // buildSort produced no sort; TopN cannot run without one
|
|
|
+ throw new IllegalStateException("sorts must not be disabled in TopN");
|
|
|
+ }
|
|
|
+ if (needsScore == false) { // no scores requested: collect by the built sort only
|
|
|
+ return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
|
|
|
+ }
|
|
|
+ Sort sort = sortAndFormats.get().sort;
|
|
|
+ if (Sort.RELEVANCE.equals(sort)) {
|
|
|
+ // SORT _score DESC: the built sort equals Lucene's relevance sort, so use a score-based top-docs collector
|
|
|
+ return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector()); // NOTE(review): null/0 presumably mean "no search-after doc" and "total hits threshold 0" — confirm against Lucene
|
|
|
+ }
|
|
|
+
|
|
|
+ // SORT ..., _score, ...: keep the requested sort fields and append doc and score fields before building a field collector
|
|
|
+ var l = new ArrayList<>(Arrays.asList(sort.getSort())); // copy into a growable list; Arrays.asList alone is fixed-size
|
|
|
+ l.add(SortField.FIELD_DOC); // NOTE(review): presumably a doc-order tie-breaker — confirm
|
|
|
+ l.add(SortField.FIELD_SCORE); // NOTE(review): presumably so the score is tracked with each hit — confirm
|
|
|
+ sort = new Sort(l.toArray(SortField[]::new));
|
|
|
+ return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector()); // same limit/null/0 arguments as the relevance branch, plus the extended sort
|
|
|
+ }
|
|
|
}
|