@@ -0,0 +1,216 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.metrics.tophits;
+
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.search.FieldComparator;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopFieldDocs;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.Writeable.Reader;
+import org.elasticsearch.common.text.Text;
+import org.elasticsearch.search.SearchHitField;
+import org.elasticsearch.search.aggregations.InternalAggregationTestCase;
+import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
+import org.elasticsearch.search.internal.InternalSearchHit;
+import org.elasticsearch.search.internal.InternalSearchHits;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.util.Comparator.comparing;
+
+public class InternalTopHitsTests extends InternalAggregationTestCase<InternalTopHits> {
+    /**
+     * Should the test instances look like they are sorted by some fields (true) or sorted by score (false). Set here because these need
+     * to be the same across the entirety of {@link #testReduceRandom()}.
+     */
+    private final boolean testInstancesLookSortedByField = randomBoolean();
+    /**
+     * Fields shared by all instances created by {@link #createTestInstance(String, List, Map)}.
+     */
+    private final SortField[] testInstancesSortFields = testInstancesLookSortedByField ? randomSortFields() : new SortField[0];
+
+    @Override
+    protected InternalTopHits createTestInstance(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
+        int from = 0;
+        int requestedSize = between(1, 40);
+        int actualSize = between(0, requestedSize);
+
+        float maxScore = Float.MIN_VALUE;
+        ScoreDoc[] scoreDocs = new ScoreDoc[actualSize];
+        InternalSearchHit[] hits = new InternalSearchHit[actualSize];
+        Set<Integer> usedDocIds = new HashSet<>();
+        for (int i = 0; i < actualSize; i++) {
+            float score = randomFloat();
+            maxScore = max(maxScore, score);
+            int docId = randomValueOtherThanMany(usedDocIds::contains, () -> between(0, IndexWriter.MAX_DOCS));
+            usedDocIds.add(docId);
+
+            Map<String, SearchHitField> searchHitFields = new HashMap<>();
+            if (testInstancesLookSortedByField) {
+                Object[] fields = new Object[testInstancesSortFields.length];
+                for (int f = 0; f < testInstancesSortFields.length; f++) {
+                    fields[f] = randomOfType(testInstancesSortFields[f].getType());
+                }
+                scoreDocs[i] = new FieldDoc(docId, score, fields);
+            } else {
+                scoreDocs[i] = new ScoreDoc(docId, score);
+            }
+            hits[i] = new InternalSearchHit(docId, Integer.toString(i), new Text("test"), searchHitFields);
+            hits[i].score(score);
+        }
+        int totalHits = between(actualSize, 500000);
+        InternalSearchHits internalSearchHits = new InternalSearchHits(hits, totalHits, maxScore);
+
+        TopDocs topDocs;
+        Arrays.sort(scoreDocs, scoreDocComparator());
+        if (testInstancesLookSortedByField) {
+            topDocs = new TopFieldDocs(totalHits, scoreDocs, testInstancesSortFields, maxScore);
+        } else {
+            topDocs = new TopDocs(totalHits, scoreDocs, maxScore);
+        }
+
+        return new InternalTopHits(name, from, requestedSize, topDocs, internalSearchHits, pipelineAggregators, metaData);
+    }
+
+    private Object randomOfType(SortField.Type type) {
+        switch (type) {
+            case CUSTOM:
+                throw new UnsupportedOperationException();
+            case DOC:
+                return between(0, IndexWriter.MAX_DOCS);
+            case DOUBLE:
+                return randomDouble();
+            case FLOAT:
+                return randomFloat();
+            case INT:
+                return randomInt();
+            case LONG:
+                return randomLong();
+            case REWRITEABLE:
+                throw new UnsupportedOperationException();
+            case SCORE:
+                return randomFloat();
+            case STRING:
+                return new BytesRef(randomAsciiOfLength(5));
+            case STRING_VAL:
+                return new BytesRef(randomAsciiOfLength(5));
+            default:
+                throw new UnsupportedOperationException("Unknown SortField.Type: " + type);
+        }
+    }
+
+    @Override
+    protected void assertReduced(InternalTopHits reduced, List<InternalTopHits> inputs) {
+        InternalSearchHits actualHits = (InternalSearchHits) reduced.getHits();
+        List<Tuple<ScoreDoc, InternalSearchHit>> allHits = new ArrayList<>();
+        float maxScore = Float.MIN_VALUE;
+        long totalHits = 0;
+        for (int input = 0; input < inputs.size(); input++) {
+            InternalSearchHits internalHits = (InternalSearchHits) inputs.get(input).getHits();
+            totalHits += internalHits.totalHits();
+            maxScore = max(maxScore, internalHits.maxScore());
+            for (int i = 0; i < internalHits.internalHits().length; i++) {
+                ScoreDoc doc = inputs.get(input).getTopDocs().scoreDocs[i];
+                if (testInstancesLookSortedByField) {
+                    doc = new FieldDoc(doc.doc, doc.score, ((FieldDoc) doc).fields, input);
+                } else {
+                    doc = new ScoreDoc(doc.doc, doc.score, input);
+                }
+                allHits.add(new Tuple<>(doc, internalHits.internalHits()[i]));
+            }
+        }
+        allHits.sort(comparing(Tuple::v1, scoreDocComparator()));
+        InternalSearchHit[] expectedHitsHits = new InternalSearchHit[min(inputs.get(0).getSize(), allHits.size())];
+        for (int i = 0; i < expectedHitsHits.length; i++) {
+            expectedHitsHits[i] = allHits.get(i).v2();
+        }
+        InternalSearchHits expectedHits = new InternalSearchHits(expectedHitsHits, totalHits, maxScore);
+        assertEqualsWithErrorMessageFromXContent(expectedHits, actualHits);
+    }
+
+    @Override
+    protected Reader<InternalTopHits> instanceReader() {
+        return InternalTopHits::new;
+    }
+
+    private SortField[] randomSortFields() {
+        SortField[] sortFields = new SortField[between(1, 5)];
+        Set<String> usedSortFields = new HashSet<>();
+        for (int i = 0; i < sortFields.length; i++) {
+            String sortField = randomValueOtherThanMany(usedSortFields::contains, () -> randomAsciiOfLength(5));
+            usedSortFields.add(sortField);
+            SortField.Type type = randomValueOtherThanMany(t -> t == SortField.Type.CUSTOM || t == SortField.Type.REWRITEABLE,
+                    () -> randomFrom(SortField.Type.values()));
+            sortFields[i] = new SortField(sortField, type);
+        }
+        return sortFields;
+    }
+
+    private Comparator<ScoreDoc> scoreDocComparator() {
+        return innerScoreDocComparator().thenComparing(s -> s.shardIndex);
+    }
+
+    private Comparator<ScoreDoc> innerScoreDocComparator() {
+        if (testInstancesLookSortedByField) {
+            // Values passed to getComparator shouldn't matter
+            @SuppressWarnings("rawtypes")
+            FieldComparator[] comparators = new FieldComparator[testInstancesSortFields.length];
+            for (int i = 0; i < testInstancesSortFields.length; i++) {
+                try {
+                    comparators[i] = testInstancesSortFields[i].getComparator(0, 0);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            }
+            return (lhs, rhs) -> {
+                FieldDoc l = (FieldDoc) lhs;
+                FieldDoc r = (FieldDoc) rhs;
+                int i = 0;
+                while (i < l.fields.length) {
+                    @SuppressWarnings("unchecked")
+                    int c = comparators[i].compareValues(l.fields[i], r.fields[i]);
+                    if (c != 0) {
+                        return c;
+                    }
+                    i++;
+                }
+                return 0;
+            };
+        } else {
+            Comparator<ScoreDoc> comparator = comparing(d -> d.score);
+            return comparator.reversed();
+        }
+    }
+}