|
|
@@ -55,37 +55,76 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
|
import java.util.Set;
|
|
|
+import java.util.function.BiFunction;
|
|
|
+import java.util.function.Function;
|
|
|
+import java.util.function.Supplier;
|
|
|
+import java.util.stream.Stream;
|
|
|
|
|
|
import static java.lang.Math.max;
|
|
|
import static java.lang.Math.min;
|
|
|
import static java.util.Comparator.comparing;
|
|
|
+import static java.util.stream.Collectors.toList;
|
|
|
|
|
|
public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTopHits> {
|
|
|
- /**
|
|
|
- * Should the test instances look like they are sorted by some fields (true) or sorted by score (false). Set here because these need
|
|
|
- * to be the same across the entirety of {@link #testReduceRandom()}.
|
|
|
- */
|
|
|
- private boolean testInstancesLookSortedByField;
|
|
|
- /**
|
|
|
- * Fields shared by all instances created by {@link #createTestInstance(String, Map)}.
|
|
|
- */
|
|
|
- private SortField[] testInstancesSortFields;
|
|
|
-
|
|
|
- /**
|
|
|
- * Collects all generated scores and fields to ensure that all scores are unique. That is necessary for deterministic results
|
|
|
- */
|
|
|
- private Set<Float> usedScores = new HashSet<>();
|
|
|
- private Set<Object> usedFields = new HashSet<>();
|
|
|
-
|
|
|
@Override
|
|
|
- public void setUp() throws Exception {
|
|
|
- super.setUp();
|
|
|
- testInstancesLookSortedByField = randomBoolean();
|
|
|
- testInstancesSortFields = testInstancesLookSortedByField ? randomSortFields() : new SortField[0];
|
|
|
+ protected InternalTopHits createTestInstance(String name, Map<String, Object> metadata) {
|
|
|
+ if (randomBoolean()) {
|
|
|
+ return createTestInstanceSortedByFields(name, metadata, ESTestCase::randomFloat,
|
|
|
+ randomSortFields(), InternalTopHitsTests::randomOfType);
|
|
|
+ }
|
|
|
+ return createTestInstanceSortedScore(name, metadata, ESTestCase::randomFloat);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- protected InternalTopHits createTestInstance(String name, Map<String, Object> metadata) {
|
|
|
+ protected List<InternalTopHits> randomResultsToReduce(String name, int size) {
|
|
|
+ /*
|
|
|
+ * Make sure all scores are unique so we can get
|
|
|
+ * deterministic test results.
|
|
|
+ */
|
|
|
+ Set<Float> usedScores = new HashSet<>();
|
|
|
+ Supplier<Float> scoreSupplier = () -> {
|
|
|
+ float score = randomValueOtherThanMany(usedScores::contains, ESTestCase::randomFloat);
|
|
|
+ usedScores.add(score);
|
|
|
+ return score;
|
|
|
+ };
|
|
|
+ Supplier<InternalTopHits> supplier;
|
|
|
+ if (randomBoolean()) {
|
|
|
+ SortField[] sortFields = randomSortFields();
|
|
|
+ Set<Object> usedSortFieldValues = new HashSet<>();
|
|
|
+ Function<SortField.Type, Object> sortFieldValueSuppier = t -> {
|
|
|
+ Object value = randomValueOtherThanMany(usedSortFieldValues::contains, () -> randomOfType(t));
|
|
|
+ usedSortFieldValues.add(value);
|
|
|
+ return value;
|
|
|
+ };
|
|
|
+ supplier = () -> createTestInstanceSortedByFields(name, null, scoreSupplier, sortFields, sortFieldValueSuppier);
|
|
|
+ } else {
|
|
|
+ supplier = () -> createTestInstanceSortedScore(name, null, scoreSupplier);
|
|
|
+ }
|
|
|
+ return Stream.generate(supplier).limit(size).collect(toList());
|
|
|
+ }
|
|
|
+
|
|
|
+ private InternalTopHits createTestInstanceSortedByFields(String name, Map<String, Object> metadata,
|
|
|
+ Supplier<Float> scoreSupplier, SortField[] sortFields, Function<SortField.Type, Object> sortFieldValueSupplier) {
|
|
|
+ return createTestInstance(name, metadata, scoreSupplier,
|
|
|
+ (docId, score) -> {
|
|
|
+ Object[] fields = new Object[sortFields.length];
|
|
|
+ for (int f = 0; f < sortFields.length; f++) {
|
|
|
+ final int ff = f;
|
|
|
+ fields[f] = sortFieldValueSupplier.apply(sortFields[ff].getType());
|
|
|
+ }
|
|
|
+ return new FieldDoc(docId, score, fields);
|
|
|
+ },
|
|
|
+ (totalHits, scoreDocs) -> new TopFieldDocs(totalHits, scoreDocs, sortFields),
|
|
|
+ sortFieldsComparator(sortFields));
|
|
|
+ }
|
|
|
+
|
|
|
+ private InternalTopHits createTestInstanceSortedScore(String name, Map<String, Object> metadata, Supplier<Float> scoreSupplier) {
|
|
|
+ return createTestInstance(name, metadata, scoreSupplier, ScoreDoc::new, TopDocs::new, scoreComparator());
|
|
|
+ }
|
|
|
+
|
|
|
+ private InternalTopHits createTestInstance(String name, Map<String, Object> metadata, Supplier<Float> scoreSupplier,
|
|
|
+ BiFunction<Integer, Float, ScoreDoc> docBuilder,
|
|
|
+ BiFunction<TotalHits, ScoreDoc[], TopDocs> topDocsBuilder, Comparator<ScoreDoc> comparator) {
|
|
|
int from = 0;
|
|
|
int requestedSize = between(1, 40);
|
|
|
int actualSize = between(0, requestedSize);
|
|
|
@@ -95,37 +134,21 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
SearchHit[] hits = new SearchHit[actualSize];
|
|
|
Set<Integer> usedDocIds = new HashSet<>();
|
|
|
for (int i = 0; i < actualSize; i++) {
|
|
|
- float score = randomValueOtherThanMany(usedScores::contains, ESTestCase::randomFloat);
|
|
|
- usedScores.add(score);
|
|
|
+ float score = scoreSupplier.get();
|
|
|
maxScore = max(maxScore, score);
|
|
|
int docId = randomValueOtherThanMany(usedDocIds::contains, () -> between(0, IndexWriter.MAX_DOCS));
|
|
|
usedDocIds.add(docId);
|
|
|
|
|
|
Map<String, DocumentField> searchHitFields = new HashMap<>();
|
|
|
- if (testInstancesLookSortedByField) {
|
|
|
- Object[] fields = new Object[testInstancesSortFields.length];
|
|
|
- for (int f = 0; f < testInstancesSortFields.length; f++) {
|
|
|
- final int ff = f;
|
|
|
- fields[f] = randomValueOtherThanMany(usedFields::contains, () -> randomOfType(testInstancesSortFields[ff].getType()));
|
|
|
- usedFields.add(fields[f]);
|
|
|
- }
|
|
|
- scoreDocs[i] = new FieldDoc(docId, score, fields);
|
|
|
- } else {
|
|
|
- scoreDocs[i] = new ScoreDoc(docId, score);
|
|
|
- }
|
|
|
+ scoreDocs[i] = docBuilder.apply(docId, score);
|
|
|
hits[i] = new SearchHit(docId, Integer.toString(i), searchHitFields, Collections.emptyMap());
|
|
|
hits[i].score(score);
|
|
|
}
|
|
|
int totalHits = between(actualSize, 500000);
|
|
|
- sort(hits, scoreDocs, scoreDocComparator());
|
|
|
+ sort(hits, scoreDocs, comparator);
|
|
|
SearchHits searchHits = new SearchHits(hits, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), maxScore);
|
|
|
|
|
|
- TopDocs topDocs;
|
|
|
- if (testInstancesLookSortedByField) {
|
|
|
- topDocs = new TopFieldDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs, testInstancesSortFields);
|
|
|
- } else {
|
|
|
- topDocs = new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs);
|
|
|
- }
|
|
|
+ TopDocs topDocs = topDocsBuilder.apply(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs);
|
|
|
// Lucene's TopDocs initializes the maxScore to Float.NaN, if there is no maxScore
|
|
|
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore == Float.NEGATIVE_INFINITY ? Float.NaN : maxScore);
|
|
|
|
|
|
@@ -176,7 +199,7 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private Object randomOfType(SortField.Type type) {
|
|
|
+ private static Object randomOfType(SortField.Type type) {
|
|
|
switch (type) {
|
|
|
case CUSTOM:
|
|
|
throw new UnsupportedOperationException();
|
|
|
@@ -205,6 +228,14 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
|
|
|
@Override
|
|
|
protected void assertReduced(InternalTopHits reduced, List<InternalTopHits> inputs) {
|
|
|
+ boolean sortedByFields = inputs.get(0).getTopDocs().topDocs instanceof TopFieldDocs;
|
|
|
+ Comparator<ScoreDoc> dataNodeComparator;
|
|
|
+ if (sortedByFields) {
|
|
|
+ dataNodeComparator = sortFieldsComparator(((TopFieldDocs) inputs.get(0).getTopDocs().topDocs).fields);
|
|
|
+ } else {
|
|
|
+ dataNodeComparator = scoreComparator();
|
|
|
+ }
|
|
|
+ Comparator<ScoreDoc> reducedComparator = dataNodeComparator.thenComparing(s -> s.shardIndex);
|
|
|
SearchHits actualHits = reduced.getHits();
|
|
|
List<Tuple<ScoreDoc, SearchHit>> allHits = new ArrayList<>();
|
|
|
float maxScore = Float.NEGATIVE_INFINITY;
|
|
|
@@ -219,7 +250,7 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
maxScore = max(maxScore, internalHits.getMaxScore());
|
|
|
for (int i = 0; i < internalHits.getHits().length; i++) {
|
|
|
ScoreDoc doc = inputs.get(input).getTopDocs().topDocs.scoreDocs[i];
|
|
|
- if (testInstancesLookSortedByField) {
|
|
|
+ if (sortedByFields) {
|
|
|
doc = new FieldDoc(doc.doc, doc.score, ((FieldDoc) doc).fields, input);
|
|
|
} else {
|
|
|
doc = new ScoreDoc(doc.doc, doc.score, input);
|
|
|
@@ -227,7 +258,7 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
allHits.add(new Tuple<>(doc, internalHits.getHits()[i]));
|
|
|
}
|
|
|
}
|
|
|
- allHits.sort(comparing(Tuple::v1, scoreDocComparator()));
|
|
|
+ allHits.sort(comparing(Tuple::v1, reducedComparator));
|
|
|
SearchHit[] expectedHitsHits = new SearchHit[min(inputs.get(0).getSize(), allHits.size())];
|
|
|
for (int i = 0; i < expectedHitsHits.length; i++) {
|
|
|
expectedHitsHits[i] = allHits.get(i).v2();
|
|
|
@@ -289,36 +320,32 @@ public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTo
|
|
|
return sortFields;
|
|
|
}
|
|
|
|
|
|
- private Comparator<ScoreDoc> scoreDocComparator() {
|
|
|
- return innerScoreDocComparator().thenComparing(s -> s.shardIndex);
|
|
|
- }
|
|
|
-
|
|
|
- private Comparator<ScoreDoc> innerScoreDocComparator() {
|
|
|
- if (testInstancesLookSortedByField) {
|
|
|
+ private Comparator<ScoreDoc> sortFieldsComparator(SortField[] sortFields) {
|
|
|
+ @SuppressWarnings("rawtypes")
|
|
|
+ FieldComparator[] comparators = new FieldComparator[sortFields.length];
|
|
|
+ for (int i = 0; i < sortFields.length; i++) {
|
|
|
// Values passed to getComparator shouldn't matter
|
|
|
- @SuppressWarnings("rawtypes")
|
|
|
- FieldComparator[] comparators = new FieldComparator[testInstancesSortFields.length];
|
|
|
- for (int i = 0; i < testInstancesSortFields.length; i++) {
|
|
|
- comparators[i] = testInstancesSortFields[i].getComparator(0, 0);
|
|
|
- }
|
|
|
- return (lhs, rhs) -> {
|
|
|
- FieldDoc l = (FieldDoc) lhs;
|
|
|
- FieldDoc r = (FieldDoc) rhs;
|
|
|
- int i = 0;
|
|
|
- while (i < l.fields.length) {
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- int c = comparators[i].compareValues(l.fields[i], r.fields[i]);
|
|
|
- if (c != 0) {
|
|
|
- return c;
|
|
|
- }
|
|
|
- i++;
|
|
|
- }
|
|
|
- return 0;
|
|
|
- };
|
|
|
- } else {
|
|
|
- Comparator<ScoreDoc> comparator = comparing(d -> d.score);
|
|
|
- return comparator.reversed();
|
|
|
+ comparators[i] = sortFields[i].getComparator(0, 0);
|
|
|
}
|
|
|
+ return (lhs, rhs) -> {
|
|
|
+ FieldDoc l = (FieldDoc) lhs;
|
|
|
+ FieldDoc r = (FieldDoc) rhs;
|
|
|
+ int i = 0;
|
|
|
+ while (i < l.fields.length) {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ int c = comparators[i].compareValues(l.fields[i], r.fields[i]);
|
|
|
+ if (c != 0) {
|
|
|
+ return c;
|
|
|
+ }
|
|
|
+ i++;
|
|
|
+ }
|
|
|
+ return 0;
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private Comparator<ScoreDoc> scoreComparator() {
|
|
|
+ Comparator<ScoreDoc> comparator = comparing(d -> d.score);
|
|
|
+ return comparator.reversed();
|
|
|
}
|
|
|
|
|
|
@Override
|