Преглед изворни кода

Add sort and collapse info to SearchHits transport serialization (#36555)

In order for CCS alternate execution mode (see #32125) to be able to do the final reduction step on the CCS coordinating node, we need to serialize additional info in the transport layer as part of the `SearchHits`, specifically:

- lucene `SortField[]` which contains info about the fields that sorting was performed on and their type, which depends on mappings (that the CCS node does not know about)
- collapse field (`String`) that field collapsing was executed on, if requested
- collapse values (`Object[]`) that field collapsing was based on, if requested

This info is needed to be able to reconstruct the `TopFieldDocs` or `CollapseFieldTopDocs` in the CCS coordinating node to feed the `mergeTopDocs` method and reduce multiple search responses received (one per cluster) into one.

This commit adds such information to the `SearchHits` class. It's nullable info that is not serialized through the REST layer. `SearchPhaseController` sets such info at the end of the hits reduction phase.
Luca Cavanna пре 6 година
родитељ
комит
7dc3d3b78b

+ 3 - 2
server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

@@ -109,8 +109,9 @@ final class FetchSearchPhase extends SearchPhase {
             // query AND fetch optimization
             finishPhase.run();
         } else {
-            final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, reducedQueryPhase.scoreDocs);
-            if (reducedQueryPhase.scoreDocs.length == 0) { // no docs to fetch -- sidestep everything and return
+            ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
+            final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(numShards, scoreDocs);
+            if (scoreDocs.length == 0) { // no docs to fetch -- sidestep everything and return
                 phaseResults.stream()
                     .map(SearchPhaseResult::queryResult)
                     .forEach(this::releaseIrrelevantSearchContext); // we have to release contexts here to free up resources

+ 47 - 42
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -211,18 +211,23 @@ public final class SearchPhaseController {
                     }
                 }
             }
-            final boolean isSortedByField;
-            final SortField[] sortFields;
+            boolean isSortedByField = false;
+            SortField[] sortFields = null;
+            String collapseField = null;
+            Object[] collapseValues = null;
             if (mergedTopDocs instanceof TopFieldDocs) {
                 TopFieldDocs fieldDocs = (TopFieldDocs) mergedTopDocs;
-                isSortedByField = (fieldDocs instanceof CollapseTopFieldDocs &&
-                    fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false;
                 sortFields = fieldDocs.fields;
-            } else {
-                isSortedByField = false;
-                sortFields = null;
+                if (fieldDocs instanceof CollapseTopFieldDocs) {
+                    isSortedByField = (fieldDocs.fields.length == 1 && fieldDocs.fields[0].getType() == SortField.Type.SCORE) == false;
+                    CollapseTopFieldDocs collapseTopFieldDocs = (CollapseTopFieldDocs) fieldDocs;
+                    collapseField = collapseTopFieldDocs.field;
+                    collapseValues = collapseTopFieldDocs.collapseValues;
+                } else {
+                    isSortedByField = true;
+                }
             }
-            return new SortedTopDocs(scoreDocs, isSortedByField, sortFields);
+            return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, collapseField, collapseValues);
         } else {
             // no relevant docs
             return SortedTopDocs.EMPTY;
@@ -266,7 +271,7 @@ public final class SearchPhaseController {
     public ScoreDoc[] getLastEmittedDocPerShard(ReducedQueryPhase reducedQueryPhase, int numShards) {
         final ScoreDoc[] lastEmittedDocPerShard = new ScoreDoc[numShards];
         if (reducedQueryPhase.isEmptyResult == false) {
-            final ScoreDoc[] sortedScoreDocs = reducedQueryPhase.scoreDocs;
+            final ScoreDoc[] sortedScoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
             // from is always zero as when we use scroll, we ignore from
             long size = Math.min(reducedQueryPhase.fetchHits, reducedQueryPhase.size);
             // with collapsing we can have more hits than sorted docs
@@ -307,7 +312,7 @@ public final class SearchPhaseController {
         if (reducedQueryPhase.isEmptyResult) {
             return InternalSearchResponse.empty();
         }
-        ScoreDoc[] sortedDocs = reducedQueryPhase.scoreDocs;
+        ScoreDoc[] sortedDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
         SearchHits hits = getHits(reducedQueryPhase, ignoreFrom, fetchResults, resultsLookup);
         if (reducedQueryPhase.suggest != null) {
             if (!fetchResults.isEmpty()) {
@@ -345,12 +350,12 @@ public final class SearchPhaseController {
 
     private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFrom,
                                Collection<? extends SearchPhaseResult> fetchResults, IntFunction<SearchPhaseResult> resultsLookup) {
-        final boolean sorted = reducedQueryPhase.isSortedByField;
-        ScoreDoc[] sortedDocs = reducedQueryPhase.scoreDocs;
+        SortedTopDocs sortedTopDocs = reducedQueryPhase.sortedTopDocs;
         int sortScoreIndex = -1;
-        if (sorted) {
-            for (int i = 0; i < reducedQueryPhase.sortField.length; i++) {
-                if (reducedQueryPhase.sortField[i].getType() == SortField.Type.SCORE) {
+        if (sortedTopDocs.isSortedByField) {
+            SortField[] sortFields = sortedTopDocs.sortFields;
+            for (int i = 0; i < sortFields.length; i++) {
+                if (sortFields[i].getType() == SortField.Type.SCORE) {
                     sortScoreIndex = i;
                 }
             }
@@ -362,12 +367,12 @@ public final class SearchPhaseController {
         int from = ignoreFrom ? 0 : reducedQueryPhase.from;
         int numSearchHits = (int) Math.min(reducedQueryPhase.fetchHits - from, reducedQueryPhase.size);
         // with collapsing we can have more fetch hits than sorted docs
-        numSearchHits = Math.min(sortedDocs.length, numSearchHits);
+        numSearchHits = Math.min(sortedTopDocs.scoreDocs.length, numSearchHits);
         // merge hits
         List<SearchHit> hits = new ArrayList<>();
         if (!fetchResults.isEmpty()) {
             for (int i = 0; i < numSearchHits; i++) {
-                ScoreDoc shardDoc = sortedDocs[i];
+                ScoreDoc shardDoc = sortedTopDocs.scoreDocs[i];
                 SearchPhaseResult fetchResultProvider = resultsLookup.apply(shardDoc.shardIndex);
                 if (fetchResultProvider == null) {
                     // this can happen if we are hitting a shard failure during the fetch phase
@@ -381,21 +386,21 @@ public final class SearchPhaseController {
                 assert index < fetchResult.hits().getHits().length : "not enough hits fetched. index [" + index + "] length: "
                     + fetchResult.hits().getHits().length;
                 SearchHit searchHit = fetchResult.hits().getHits()[index];
-                if (sorted == false) {
-                    searchHit.score(shardDoc.score);
-                }
                 searchHit.shard(fetchResult.getSearchShardTarget());
-                if (sorted) {
+                if (sortedTopDocs.isSortedByField) {
                     FieldDoc fieldDoc = (FieldDoc) shardDoc;
                     searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);
                     if (sortScoreIndex != -1) {
                         searchHit.score(((Number) fieldDoc.fields[sortScoreIndex]).floatValue());
                     }
+                } else {
+                    searchHit.score(shardDoc.score);
                 }
                 hits.add(searchHit);
             }
         }
-        return new SearchHits(hits.toArray(new SearchHit[0]), reducedQueryPhase.totalHits, reducedQueryPhase.maxScore);
+        return new SearchHits(hits.toArray(new SearchHit[0]), reducedQueryPhase.totalHits,
+            reducedQueryPhase.maxScore, sortedTopDocs.sortFields, sortedTopDocs.collapseField, sortedTopDocs.collapseValues);
     }
 
     /**
@@ -436,8 +441,7 @@ public final class SearchPhaseController {
         if (queryResults.isEmpty()) { // early terminate we have nothing to reduce
             final TotalHits totalHits = topDocsStats.getTotalHits();
             return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
-                timedOut, terminatedEarly, null, null, null, EMPTY_DOCS, null,
-                null, numReducePhases, false, 0, 0, true);
+                timedOut, terminatedEarly, null, null, null, SortedTopDocs.EMPTY, null, numReducePhases, 0, 0, true);
         }
         final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
         final boolean hasSuggest = firstResult.suggest() != null;
@@ -499,11 +503,11 @@ public final class SearchPhaseController {
         final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : reduceAggs(aggregationsList,
             firstResult.pipelineAggregators(), reduceContext);
         final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
-        final SortedTopDocs scoreDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
+        final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
         final TotalHits totalHits = topDocsStats.getTotalHits();
         return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
-            timedOut, terminatedEarly, suggest, aggregations, shardResults, scoreDocs.scoreDocs, scoreDocs.sortFields,
-            firstResult.sortValueFormats(), numReducePhases, scoreDocs.isSortedByField, size, from, false);
+            timedOut, terminatedEarly, suggest, aggregations, shardResults, sortedTopDocs,
+            firstResult.sortValueFormats(), numReducePhases, size, from, firstResult == null);
     }
 
     /**
@@ -551,12 +555,8 @@ public final class SearchPhaseController {
         final SearchProfileShardResults shardResults;
         // the number of reduces phases
         final int numReducePhases;
-        // the searches merged top docs
-        final ScoreDoc[] scoreDocs;
-        // the top docs sort fields used to sort the score docs, <code>null</code> if the results are not sorted
-        final SortField[] sortField;
-        // <code>true</code> iff the result score docs is sorted by a field (not score), this implies that <code>sortField</code> is set.
-        final boolean isSortedByField;
+        //encloses info about the merged top docs, the sort fields used to sort the score docs etc.
+        final SortedTopDocs sortedTopDocs;
         // the size of the top hits to return
         final int size;
         // <code>true</code> iff the query phase had no results. Otherwise <code>false</code>
@@ -567,9 +567,8 @@ public final class SearchPhaseController {
         final DocValueFormat[] sortValueFormats;
 
         ReducedQueryPhase(TotalHits totalHits, long fetchHits, float maxScore, boolean timedOut, Boolean terminatedEarly, Suggest suggest,
-                          InternalAggregations aggregations, SearchProfileShardResults shardResults, ScoreDoc[] scoreDocs,
-                          SortField[] sortFields, DocValueFormat[] sortValueFormats, int numReducePhases, boolean isSortedByField, int size,
-                          int from, boolean isEmptyResult) {
+                          InternalAggregations aggregations, SearchProfileShardResults shardResults, SortedTopDocs sortedTopDocs,
+                          DocValueFormat[] sortValueFormats, int numReducePhases, int size, int from, boolean isEmptyResult) {
             if (numReducePhases <= 0) {
                 throw new IllegalArgumentException("at least one reduce phase must have been applied but was: " + numReducePhases);
             }
@@ -586,9 +585,7 @@ public final class SearchPhaseController {
             this.aggregations = aggregations;
             this.shardResults = shardResults;
             this.numReducePhases = numReducePhases;
-            this.scoreDocs = scoreDocs;
-            this.sortField = sortFields;
-            this.isSortedByField = isSortedByField;
+            this.sortedTopDocs = sortedTopDocs;
             this.size = size;
             this.from = from;
             this.isEmptyResult = isEmptyResult;
@@ -728,7 +725,7 @@ public final class SearchPhaseController {
         }
         return new InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
             @Override
-            public ReducedQueryPhase reduce() {
+            ReducedQueryPhase reduce() {
                 return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits);
             }
         };
@@ -770,15 +767,23 @@ public final class SearchPhaseController {
     }
 
     static final class SortedTopDocs {
-        static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null);
+        static final SortedTopDocs EMPTY = new SortedTopDocs(EMPTY_DOCS, false, null, null, null);
+        // the searches merged top docs
         final ScoreDoc[] scoreDocs;
+        // <code>true</code> iff the result score docs is sorted by a field (not score), this implies that <code>sortField</code> is set.
         final boolean isSortedByField;
+        // the top docs sort fields used to sort the score docs, <code>null</code> if the results are not sorted
         final SortField[] sortFields;
+        final String collapseField;
+        final Object[] collapseValues;
 
-        SortedTopDocs(ScoreDoc[] scoreDocs, boolean isSortedByField, SortField[] sortFields) {
+        SortedTopDocs(ScoreDoc[] scoreDocs, boolean isSortedByField, SortField[] sortFields,
+                      String collapseField, Object[] collapseValues) {
             this.scoreDocs = scoreDocs;
             this.isSortedByField = isSortedByField;
             this.sortFields = sortFields;
+            this.collapseField = collapseField;
+            this.collapseValues = collapseValues;
         }
     }
 }

+ 4 - 5
server/src/main/java/org/elasticsearch/action/search/SearchScrollQueryThenFetchAsyncAction.java

@@ -35,7 +35,6 @@ import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.query.ScrollQuerySearchResult;
 import org.elasticsearch.transport.Transport;
 
-import java.io.IOException;
 import java.util.function.BiFunction;
 
 final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncAction<ScrollQuerySearchResult> {
@@ -68,16 +67,16 @@ final class SearchScrollQueryThenFetchAsyncAction extends SearchScrollAsyncActio
     protected SearchPhase moveToNextPhase(BiFunction<String, String, DiscoveryNode> clusterNodeLookup) {
         return new SearchPhase("fetch") {
             @Override
-            public void run() throws IOException {
+            public void run() {
                 final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = searchPhaseController.reducedScrollQueryPhase(
                     queryResults.asList());
-                if (reducedQueryPhase.scoreDocs.length == 0) {
+                ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs.scoreDocs;
+                if (scoreDocs.length == 0) {
                     sendResponse(reducedQueryPhase, fetchResults);
                     return;
                 }
 
-                final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(),
-                    reducedQueryPhase.scoreDocs);
+                final IntArrayList[] docIdsToLoad = searchPhaseController.fillDocIdsToLoad(queryResults.length(), scoreDocs);
                 final ScoreDoc[] lastEmittedDocPerShard = searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase,
                     queryResults.length());
                 final CountDown counter = new CountDown(docIdsToLoad.length);

+ 23 - 9
server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

@@ -785,22 +785,36 @@ public abstract class StreamOutput extends OutputStream {
         }
     }
 
-    public <T extends Writeable> void writeArray(T[] array) throws IOException {
-        writeVInt(array.length);
-        for (T value: array) {
-            value.writeTo(this);
-        }
-    }
-
-    public <T extends Writeable> void writeOptionalArray(@Nullable T[] array) throws IOException {
+    /**
+     * Same as {@link #writeArray(Writer, Object[])} but the provided array may be null. An additional boolean value is
+     * serialized to indicate whether the array was null or not.
+     */
+    public <T> void writeOptionalArray(final Writer<T> writer, final @Nullable T[] array) throws IOException {
         if (array == null) {
             writeBoolean(false);
         } else {
             writeBoolean(true);
-            writeArray(array);
+            writeArray(writer, array);
         }
     }
 
+    /**
+     * Writes the specified array of {@link Writeable}s. This method can be seen as
+     * writer version of {@link StreamInput#readArray(Writeable.Reader, IntFunction)}. The length of array encoded as a variable-length
+     * integer is first written to the stream, and then the elements of the array are written to the stream.
+     */
+    public <T extends Writeable> void writeArray(T[] array) throws IOException {
+        writeArray((out, value) -> value.writeTo(out), array);
+    }
+
+    /**
+     * Same as {@link #writeArray(Writeable[])} but the provided array may be null. An additional boolean value is
+     * serialized to indicate whether the array was null or not.
+     */
+    public <T extends Writeable> void writeOptionalArray(@Nullable T[] array) throws IOException {
+        writeOptionalArray((out, value) -> value.writeTo(out), array);
+    }
+
     /**
      * Serializes a potential null value.
      */

+ 14 - 31
server/src/main/java/org/elasticsearch/common/lucene/Lucene.java

@@ -128,6 +128,9 @@ public class Lucene {
 
     public static final TopDocs EMPTY_TOP_DOCS = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), EMPTY_SCORE_DOCS);
 
+    private Lucene() {
+    }
+
     public static Version parseVersion(@Nullable String version, Version defaultVersion, Logger logger) {
         if (version == null) {
             return defaultVersion;
@@ -201,7 +204,7 @@ public class Lucene {
         try (Lock writeLock = directory.obtainLock(IndexWriter.WRITE_LOCK_NAME)) {
             int foundSegmentFiles = 0;
             for (final String file : directory.listAll()) {
-                /**
+                /*
                  * we could also use a deletion policy here but in the case of snapshot and restore
                  * sometimes we restore an index and override files that were referenced by a "future"
                  * commit. If such a commit is opened by the IW it would likely throw a corrupted index exception
@@ -227,7 +230,7 @@ public class Lucene {
                 .setCommitOnClose(false)
                 .setMergePolicy(NoMergePolicy.INSTANCE)
                 .setOpenMode(IndexWriterConfig.OpenMode.APPEND))) {
-            // do nothing and close this will kick of IndexFileDeleter which will remove all pending files
+            // do nothing and close this will kick off IndexFileDeleter which will remove all pending files
         }
         return si;
     }
@@ -321,12 +324,7 @@ public class Lucene {
         } else if (type == 1) {
             TotalHits totalHits = readTotalHits(in);
             float maxScore = in.readFloat();
-
-            SortField[] fields = new SortField[in.readVInt()];
-            for (int i = 0; i < fields.length; i++) {
-                fields[i] = readSortField(in);
-            }
-
+            SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new);
             FieldDoc[] fieldDocs = new FieldDoc[in.readVInt()];
             for (int i = 0; i < fieldDocs.length; i++) {
                 fieldDocs[i] = readFieldDoc(in);
@@ -337,10 +335,7 @@ public class Lucene {
             float maxScore = in.readFloat();
 
             String field = in.readString();
-            SortField[] fields = new SortField[in.readVInt()];
-            for (int i = 0; i < fields.length; i++) {
-               fields[i] = readSortField(in);
-            }
+            SortField[] fields = in.readArray(Lucene::readSortField, SortField[]::new);
             int size = in.readVInt();
             Object[] collapseValues = new Object[size];
             FieldDoc[] fieldDocs = new FieldDoc[size];
@@ -385,7 +380,7 @@ public class Lucene {
         return new FieldDoc(in.readVInt(), in.readFloat(), cFields);
     }
 
-    private static Comparable readSortValue(StreamInput in) throws IOException {
+    public static Comparable readSortValue(StreamInput in) throws IOException {
         byte type = in.readByte();
         if (type == 0) {
             return null;
@@ -436,11 +431,7 @@ public class Lucene {
             out.writeFloat(topDocs.maxScore);
 
             out.writeString(collapseDocs.field);
-
-            out.writeVInt(collapseDocs.fields.length);
-            for (SortField sortField : collapseDocs.fields) {
-               writeSortField(out, sortField);
-            }
+            out.writeArray(Lucene::writeSortField, collapseDocs.fields);
 
             out.writeVInt(topDocs.topDocs.scoreDocs.length);
             for (int i = 0; i < topDocs.topDocs.scoreDocs.length; i++) {
@@ -455,10 +446,7 @@ public class Lucene {
             writeTotalHits(out, topDocs.topDocs.totalHits);
             out.writeFloat(topDocs.maxScore);
 
-            out.writeVInt(topFieldDocs.fields.length);
-            for (SortField sortField : topFieldDocs.fields) {
-              writeSortField(out, sortField);
-            }
+            out.writeArray(Lucene::writeSortField, topFieldDocs.fields);
 
             out.writeVInt(topDocs.topDocs.scoreDocs.length);
             for (ScoreDoc doc : topFieldDocs.scoreDocs) {
@@ -501,8 +489,7 @@ public class Lucene {
         }
     }
 
-
-    private static void writeSortValue(StreamOutput out, Object field) throws IOException {
+    public static void writeSortValue(StreamOutput out, Object field) throws IOException {
         if (field == null) {
             out.writeByte((byte) 0);
         } else {
@@ -687,11 +674,7 @@ public class Lucene {
         }
     }
 
-    private Lucene() {
-
-    }
-
-    public static final boolean indexExists(final Directory directory) throws IOException {
+    public static boolean indexExists(final Directory directory) throws IOException {
         return DirectoryReader.indexExists(directory);
     }
 
@@ -701,7 +684,7 @@ public class Lucene {
      *
      * Will retry the directory every second for at least {@code timeLimitMillis}
      */
-    public static final boolean waitForIndex(final Directory directory, final long timeLimitMillis)
+    public static boolean waitForIndex(final Directory directory, final long timeLimitMillis)
             throws IOException {
         final long DELAY = 1000;
         long waited = 0;
@@ -1070,7 +1053,7 @@ public class Lucene {
             }
 
             public LeafMetaData getMetaData() {
-                return new LeafMetaData(Version.LATEST.major, Version.LATEST, (Sort)null);
+                return new LeafMetaData(Version.LATEST.major, Version.LATEST, null);
             }
 
             public CacheHelper getCoreCacheHelper() {

+ 61 - 5
server/src/main/java/org/elasticsearch/search/SearchHits.java

@@ -19,8 +19,10 @@
 
 package org.elasticsearch.search;
 
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -56,14 +58,29 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
 
     private float maxScore;
 
+    @Nullable
+    private SortField[] sortFields;
+    @Nullable
+    private String collapseField;
+    @Nullable
+    private Object[] collapseValues;
+
     SearchHits() {
 
     }
 
     public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) {
+        this(hits, totalHits, maxScore, null, null, null);
+    }
+
+    public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore, @Nullable SortField[] sortFields,
+                      @Nullable String collapseField, @Nullable Object[] collapseValues) {
         this.hits = hits;
         this.totalHits = totalHits == null ? null : new Total(totalHits);
         this.maxScore = maxScore;
+        this.sortFields = sortFields;
+        this.collapseField = collapseField;
+        this.collapseValues = collapseValues;
     }
 
     /**
@@ -74,7 +91,6 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
         return totalHits == null ? null : totalHits.in;
     }
 
-
     /**
      * The maximum score of this query.
      */
@@ -96,6 +112,31 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
         return hits[position];
     }
 
+    /**
+     * In case documents were sorted by field(s), returns information about such field(s), null otherwise
+     * @see SortField
+     */
+    @Nullable
+    public SortField[] getSortFields() {
+        return sortFields;
+    }
+
+    /**
+     * In case field collapsing was performed, returns the field used for field collapsing, null otherwise
+     */
+    @Nullable
+    public String getCollapseField() {
+        return collapseField;
+    }
+
+    /**
+     * In case field collapsing was performed, returns the values of the field that field collapsing was performed on, null otherwise
+     */
+    @Nullable
+    public Object[] getCollapseValues() {
+        return collapseValues;
+    }
+
     @Override
     public Iterator<SearchHit> iterator() {
         return Arrays.stream(getHits()).iterator();
@@ -175,8 +216,7 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
                 }
             }
         }
-        SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[hits.size()]), totalHits, maxScore);
-        return searchHits;
+        return new SearchHits(hits.toArray(new SearchHit[0]), totalHits, maxScore);
     }
 
     public static SearchHits readSearchHits(StreamInput in) throws IOException {
@@ -203,6 +243,12 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
                 hits[i] = SearchHit.readSearchHit(in);
             }
         }
+        //TODO update version once backported
+        if (in.getVersion().onOrAfter(Version.V_7_0_0)) {
+            sortFields = in.readOptionalArray(Lucene::readSortField, SortField[]::new);
+            collapseField = in.readOptionalString();
+            collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new);
+        }
     }
 
     @Override
@@ -219,6 +265,12 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
                 hit.writeTo(out);
             }
         }
+        //TODO update version once backported
+        if (out.getVersion().onOrAfter(Version.V_7_0_0)) {
+            out.writeOptionalArray(Lucene::writeSortField, sortFields);
+            out.writeOptionalString(collapseField);
+            out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
+        }
     }
 
     @Override
@@ -229,12 +281,16 @@ public final class SearchHits implements Streamable, ToXContentFragment, Iterabl
         SearchHits other = (SearchHits) obj;
         return Objects.equals(totalHits, other.totalHits)
                 && Objects.equals(maxScore, other.maxScore)
-                && Arrays.equals(hits, other.hits);
+                && Arrays.equals(hits, other.hits)
+                && Arrays.equals(sortFields, other.sortFields)
+                && Objects.equals(collapseField, other.collapseField)
+                && Arrays.equals(collapseValues, other.collapseValues);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits));
+        return Objects.hash(totalHits, maxScore, Arrays.hashCode(hits),
+            Arrays.hashCode(sortFields), collapseField, Arrays.hashCode(collapseValues));
     }
 
     public static TotalHits parseTotalHitsFragment(XContentParser parser) throws IOException {

+ 123 - 34
server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java

@@ -20,10 +20,16 @@
 package org.elasticsearch.action.search;
 
 import com.carrotsearch.randomizedtesting.RandomizedContext;
+import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.TotalHits.Relation;
+import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.text.Text;
 import org.elasticsearch.common.util.BigArrays;
@@ -47,7 +53,6 @@ import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
 import org.elasticsearch.test.ESTestCase;
 import org.junit.Before;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -138,7 +143,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
             () -> generateQueryResults(nShards, suggestions, searchHitsSize, useConstantScore));
     }
 
-    public void testMerge() throws IOException {
+    public void testMerge() {
         List<CompletionSuggestion> suggestions = new ArrayList<>();
         int maxSuggestSize = 0;
         for (int i = 0; i < randomIntBetween(1, 5); i++) {
@@ -152,8 +157,8 @@ public class SearchPhaseControllerTests extends ESTestCase {
         for (boolean trackTotalHits : new boolean[] {true, false}) {
             SearchPhaseController.ReducedQueryPhase reducedQueryPhase =
                 searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits);
-            AtomicArray<SearchPhaseResult> searchPhaseResultAtomicArray = generateFetchResults(nShards, reducedQueryPhase.scoreDocs,
-                reducedQueryPhase.suggest);
+            AtomicArray<SearchPhaseResult> searchPhaseResultAtomicArray = generateFetchResults(nShards,
+                reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest);
             InternalSearchResponse mergedResponse = searchPhaseController.merge(false,
                 reducedQueryPhase,
                 searchPhaseResultAtomicArray.asList(), searchPhaseResultAtomicArray::get);
@@ -166,7 +171,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
                 suggestSize += stream.collect(Collectors.summingInt(e -> e.getOptions().size()));
             }
             assertThat(suggestSize, lessThanOrEqualTo(maxSuggestSize));
-            assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.scoreDocs.length - suggestSize));
+            assertThat(mergedResponse.hits().getHits().length, equalTo(reducedQueryPhase.sortedTopDocs.scoreDocs.length - suggestSize));
             Suggest suggestResult = mergedResponse.suggest();
             for (Suggest.Suggestion<?> suggestion : reducedQueryPhase.suggest) {
                 assertThat(suggestion, instanceOf(CompletionSuggestion.class));
@@ -183,24 +188,24 @@ public class SearchPhaseControllerTests extends ESTestCase {
         }
     }
 
-    private AtomicArray<SearchPhaseResult> generateQueryResults(int nShards,
+    private static AtomicArray<SearchPhaseResult> generateQueryResults(int nShards,
                                                                 List<CompletionSuggestion> suggestions,
                                                                 int searchHitsSize, boolean useConstantScore) {
         AtomicArray<SearchPhaseResult> queryResults = new AtomicArray<>(nShards);
         for (int shardIndex = 0; shardIndex < nShards; shardIndex++) {
             QuerySearchResult querySearchResult = new QuerySearchResult(shardIndex,
                 new SearchShardTarget("", new Index("", ""), shardIndex, null));
-            TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+            final TopDocs topDocs;
             float maxScore = 0;
-            if (searchHitsSize > 0) {
+            if (searchHitsSize == 0) {
+                topDocs = Lucene.EMPTY_TOP_DOCS;
+            } else {
                 int nDocs = randomIntBetween(0, searchHitsSize);
                 ScoreDoc[] scoreDocs = new ScoreDoc[nDocs];
                 for (int i = 0; i < nDocs; i++) {
                     float score = useConstantScore ? 1.0F : Math.abs(randomFloat());
                     scoreDocs[i] = new ScoreDoc(i, score);
-                    if (score > maxScore) {
-                        maxScore = score;
-                    }
+                    maxScore = Math.max(score, maxScore);
                 }
                 topDocs = new TopDocs(new TotalHits(scoreDocs.length, TotalHits.Relation.EQUAL_TO), scoreDocs);
             }
@@ -283,7 +288,7 @@ public class SearchPhaseControllerTests extends ESTestCase {
                     }
                 }
             }
-            SearchHit[] hits = searchHits.toArray(new SearchHit[searchHits.size()]);
+            SearchHit[] hits = searchHits.toArray(new SearchHit[0]);
             fetchSearchResult.hits(new SearchHits(hits, new TotalHits(hits.length, Relation.EQUAL_TO), maxScore));
             fetchResults.set(shardIndex, fetchSearchResult);
         }
@@ -336,6 +341,10 @@ public class SearchPhaseControllerTests extends ESTestCase {
         assertEquals(numTotalReducePhases, reduce.numReducePhases);
         InternalMax max = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(3.0D, max.getValue(), 0.0D);
+        assertFalse(reduce.sortedTopDocs.isSortedByField);
+        assertNull(reduce.sortedTopDocs.sortFields);
+        assertNull(reduce.sortedTopDocs.collapseField);
+        assertNull(reduce.sortedTopDocs.collapseValues);
     }
 
     public void testConsumerConcurrently() throws InterruptedException {
@@ -374,13 +383,17 @@ public class SearchPhaseControllerTests extends ESTestCase {
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
         InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(max.get(), internalMax.getValue(), 0.0D);
-        assertEquals(1, reduce.scoreDocs.length);
+        assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(max.get(), reduce.maxScore, 0.0f);
         assertEquals(expectedNumResults, reduce.totalHits.value);
-        assertEquals(max.get(), reduce.scoreDocs[0].score, 0.0f);
+        assertEquals(max.get(), reduce.sortedTopDocs.scoreDocs[0].score, 0.0f);
+        assertFalse(reduce.sortedTopDocs.isSortedByField);
+        assertNull(reduce.sortedTopDocs.sortFields);
+        assertNull(reduce.sortedTopDocs.collapseField);
+        assertNull(reduce.sortedTopDocs.collapseValues);
     }
 
-    public void testConsumerOnlyAggs() throws InterruptedException {
+    public void testConsumerOnlyAggs() {
         int expectedNumResults = randomIntBetween(1, 100);
         int bufferSize = randomIntBetween(2, 200);
         SearchRequest request = new SearchRequest();
@@ -390,29 +403,31 @@ public class SearchPhaseControllerTests extends ESTestCase {
             searchPhaseController.newSearchPhaseResults(request, expectedNumResults);
         AtomicInteger max = new AtomicInteger();
         for (int i = 0; i < expectedNumResults; i++) {
-            int id = i;
             int number = randomIntBetween(1, 1000);
             max.updateAndGet(prev -> Math.max(prev, number));
-            QuerySearchResult result = new QuerySearchResult(id, new SearchShardTarget("node", new Index("a", "b"), id, null));
+            QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
             result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number),
                     new DocValueFormat[0]);
             InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number,
                 DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap())));
             result.aggregations(aggs);
-            result.setShardIndex(id);
+            result.setShardIndex(i);
             result.size(1);
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
         InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
         assertEquals(max.get(), internalMax.getValue(), 0.0D);
-        assertEquals(0, reduce.scoreDocs.length);
+        assertEquals(0, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(max.get(), reduce.maxScore, 0.0f);
         assertEquals(expectedNumResults, reduce.totalHits.value);
+        assertFalse(reduce.sortedTopDocs.isSortedByField);
+        assertNull(reduce.sortedTopDocs.sortFields);
+        assertNull(reduce.sortedTopDocs.collapseField);
+        assertNull(reduce.sortedTopDocs.collapseValues);
     }
 
-
-    public void testConsumerOnlyHits() throws InterruptedException {
+    public void testConsumerOnlyHits() {
         int expectedNumResults = randomIntBetween(1, 100);
         int bufferSize = randomIntBetween(2, 200);
         SearchRequest request = new SearchRequest();
@@ -424,24 +439,26 @@ public class SearchPhaseControllerTests extends ESTestCase {
             searchPhaseController.newSearchPhaseResults(request, expectedNumResults);
         AtomicInteger max = new AtomicInteger();
         for (int i = 0; i < expectedNumResults; i++) {
-            int id = i;
             int number = randomIntBetween(1, 1000);
             max.updateAndGet(prev -> Math.max(prev, number));
-            QuerySearchResult result = new QuerySearchResult(id, new SearchShardTarget("node", new Index("a", "b"), id, null));
+            QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
             result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO),
                     new ScoreDoc[] {new ScoreDoc(0, number)}), number), new DocValueFormat[0]);
-            result.setShardIndex(id);
+            result.setShardIndex(i);
             result.size(1);
             consumer.consumeResult(result);
         }
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertEquals(1, reduce.scoreDocs.length);
+        assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
         assertEquals(max.get(), reduce.maxScore, 0.0f);
         assertEquals(expectedNumResults, reduce.totalHits.value);
-        assertEquals(max.get(), reduce.scoreDocs[0].score, 0.0f);
+        assertEquals(max.get(), reduce.sortedTopDocs.scoreDocs[0].score, 0.0f);
+        assertFalse(reduce.sortedTopDocs.isSortedByField);
+        assertNull(reduce.sortedTopDocs.sortFields);
+        assertNull(reduce.sortedTopDocs.collapseField);
+        assertNull(reduce.sortedTopDocs.collapseValues);
     }
 
-
     public void testNewSearchPhaseResults() {
         for (int i = 0; i < 10; i++) {
             int expectedNumResults = randomIntBetween(1, 10);
@@ -497,15 +514,87 @@ public class SearchPhaseControllerTests extends ESTestCase {
             consumer.consumeResult(result);
         }
         // 4*3 results = 12 we get result 5 to 10 here with from=5 and size=5
-
         SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
-        assertEquals(5, reduce.scoreDocs.length);
+        ScoreDoc[] scoreDocs = reduce.sortedTopDocs.scoreDocs;
+        assertEquals(5, scoreDocs.length);
         assertEquals(100.f, reduce.maxScore, 0.0f);
         assertEquals(12, reduce.totalHits.value);
-        assertEquals(95.0f, reduce.scoreDocs[0].score, 0.0f);
-        assertEquals(94.0f, reduce.scoreDocs[1].score, 0.0f);
-        assertEquals(93.0f, reduce.scoreDocs[2].score, 0.0f);
-        assertEquals(92.0f, reduce.scoreDocs[3].score, 0.0f);
-        assertEquals(91.0f, reduce.scoreDocs[4].score, 0.0f);
+        assertEquals(95.0f, scoreDocs[0].score, 0.0f);
+        assertEquals(94.0f, scoreDocs[1].score, 0.0f);
+        assertEquals(93.0f, scoreDocs[2].score, 0.0f);
+        assertEquals(92.0f, scoreDocs[3].score, 0.0f);
+        assertEquals(91.0f, scoreDocs[4].score, 0.0f);
+    }
+
+    public void testConsumerSortByField() {
+        int expectedNumResults = randomIntBetween(1, 100);
+        int bufferSize = randomIntBetween(2, 200);
+        SearchRequest request = new SearchRequest();
+        int size = randomIntBetween(1, 10);
+        request.setBatchedReduceSize(bufferSize);
+        InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
+            searchPhaseController.newSearchPhaseResults(request, expectedNumResults);
+        AtomicInteger max = new AtomicInteger();
+        SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)};
+        DocValueFormat[] docValueFormats = {DocValueFormat.RAW};
+        for (int i = 0; i < expectedNumResults; i++) {
+            int number = randomIntBetween(1, 1000);
+            max.updateAndGet(prev -> Math.max(prev, number));
+            FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, new Object[]{number})};
+            TopDocs topDocs = new TopFieldDocs(new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields);
+            QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
+            result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats);
+            result.setShardIndex(i);
+            result.size(size);
+            consumer.consumeResult(result);
+        }
+        SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
+        assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length);
+        assertEquals(expectedNumResults, reduce.totalHits.value);
+        assertEquals(max.get(), ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
+        assertTrue(reduce.sortedTopDocs.isSortedByField);
+        assertEquals(1, reduce.sortedTopDocs.sortFields.length);
+        assertEquals("field", reduce.sortedTopDocs.sortFields[0].getField());
+        assertEquals(SortField.Type.INT, reduce.sortedTopDocs.sortFields[0].getType());
+        assertNull(reduce.sortedTopDocs.collapseField);
+        assertNull(reduce.sortedTopDocs.collapseValues);
+    }
+
+    public void testConsumerFieldCollapsing() {
+        int expectedNumResults = randomIntBetween(30, 100);
+        int bufferSize = randomIntBetween(2, 200);
+        SearchRequest request = new SearchRequest();
+        int size = randomIntBetween(5, 10);
+        request.setBatchedReduceSize(bufferSize);
+        InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
+            searchPhaseController.newSearchPhaseResults(request, expectedNumResults);
+        SortField[] sortFields = {new SortField("field", SortField.Type.STRING)};
+        BytesRef a = new BytesRef("a");
+        BytesRef b = new BytesRef("b");
+        BytesRef c = new BytesRef("c");
+        Object[] collapseValues = new Object[]{a, b, c};
+        DocValueFormat[] docValueFormats = {DocValueFormat.RAW};
+        for (int i = 0; i < expectedNumResults; i++) {
+            Object[] values = {randomFrom(collapseValues)};
+            FieldDoc[] fieldDocs = {new FieldDoc(0, Float.NaN, values)};
+            TopDocs topDocs = new CollapseTopFieldDocs("field", new TotalHits(1, Relation.EQUAL_TO), fieldDocs, sortFields, values);
+            QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
+            result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), docValueFormats);
+            result.setShardIndex(i);
+            result.size(size);
+            consumer.consumeResult(result);
+        }
+        SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
+        assertEquals(3, reduce.sortedTopDocs.scoreDocs.length);
+        assertEquals(expectedNumResults, reduce.totalHits.value);
+        assertEquals(a, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
+        assertEquals(b, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[1]).fields[0]);
+        assertEquals(c, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[2]).fields[0]);
+        assertTrue(reduce.sortedTopDocs.isSortedByField);
+        assertEquals(1, reduce.sortedTopDocs.sortFields.length);
+        assertEquals("field", reduce.sortedTopDocs.sortFields[0].getField());
+        assertEquals(SortField.Type.STRING, reduce.sortedTopDocs.sortFields[0].getType());
+        assertEquals("field", reduce.sortedTopDocs.collapseField);
+        assertArrayEquals(collapseValues, reduce.sortedTopDocs.collapseValues);
     }
 }

+ 22 - 1
server/src/test/java/org/elasticsearch/common/io/stream/StreamTests.java

@@ -213,7 +213,6 @@ public class StreamTests extends ESTestCase {
     }
 
     public void testWritableArrays() throws IOException {
-
         final String[] strings = generateRandomStringArray(10, 10, false, true);
         WriteableString[] sourceArray = Arrays.stream(strings).<WriteableString>map(WriteableString::new).toArray(WriteableString[]::new);
         WriteableString[] targetArray;
@@ -233,6 +232,28 @@ public class StreamTests extends ESTestCase {
         assertThat(targetArray, equalTo(sourceArray));
     }
 
+    public void testArrays() throws IOException {
+        final String[] strings;
+        final String[] deserialized;
+        Writeable.Writer<String> writer = StreamOutput::writeString;
+        Writeable.Reader<String> reader = StreamInput::readString;
+        BytesStreamOutput out = new BytesStreamOutput();
+        if (randomBoolean()) {
+            if (randomBoolean()) {
+                strings = null;
+            } else {
+                strings = generateRandomStringArray(10, 10, false, true);
+            }
+            out.writeOptionalArray(writer, strings);
+            deserialized = out.bytes().streamInput().readOptionalArray(reader, String[]::new);
+        } else {
+            strings = generateRandomStringArray(10, 10, false, true);
+            out.writeArray(writer, strings);
+            deserialized = out.bytes().streamInput().readArray(reader, String[]::new);
+        }
+        assertThat(deserialized, equalTo(strings));
+    }
+
     public void testSetOfLongs() throws IOException {
         final int size = randomIntBetween(0, 6);
         final Set<Long> sourceSet = new HashSet<>(size);

+ 162 - 1
server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java

@@ -23,6 +23,7 @@ import org.apache.lucene.analysis.core.KeywordAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.Field.Store;
+import org.apache.lucene.document.LatLonDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.document.TextField;
 import org.apache.lucene.index.DirectoryReader;
@@ -37,8 +38,12 @@ import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
-import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.SortField;
+import org.apache.lucene.search.SortedNumericSortField;
+import org.apache.lucene.search.SortedSetSelector;
+import org.apache.lucene.search.SortedSetSortField;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.Weight;
@@ -46,8 +51,18 @@ import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.MMapDirectory;
 import org.apache.lucene.store.MockDirectoryWrapper;
 import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.core.internal.io.IOUtils;
+import org.elasticsearch.index.fielddata.IndexFieldData;
+import org.elasticsearch.index.fielddata.fieldcomparator.BytesRefFieldComparatorSource;
+import org.elasticsearch.index.fielddata.fieldcomparator.DoubleValuesComparatorSource;
+import org.elasticsearch.index.fielddata.fieldcomparator.FloatValuesComparatorSource;
+import org.elasticsearch.index.fielddata.fieldcomparator.LongValuesComparatorSource;
+import org.elasticsearch.search.MultiValueMode;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.VersionUtils;
 
 import java.io.IOException;
 import java.io.StringReader;
@@ -62,6 +77,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import static org.hamcrest.Matchers.equalTo;
 
 public class LuceneTests extends ESTestCase {
+    private static final NamedWriteableRegistry EMPTY_REGISTRY = new NamedWriteableRegistry(Collections.emptyList());
+
     public void testWaitForIndex() throws Exception {
         final MockDirectoryWrapper dir = newMockDirectory();
 
@@ -498,4 +515,148 @@ public class LuceneTests extends ESTestCase {
         }
         IOUtils.close(writer, dir);
     }
+
+    public void testSortFieldSerialization() throws IOException {
+        Tuple<SortField, SortField> sortFieldTuple = randomSortField();
+        SortField deserialized = copyInstance(sortFieldTuple.v1(), EMPTY_REGISTRY, Lucene::writeSortField, Lucene::readSortField,
+            VersionUtils.randomVersion(random()));
+        assertEquals(sortFieldTuple.v2(), deserialized);
+    }
+
+    public void testSortValueSerialization() throws IOException {
+        Object sortValue = randomSortValue();
+        Object deserialized = copyInstance(sortValue, EMPTY_REGISTRY, Lucene::writeSortValue, Lucene::readSortValue,
+            VersionUtils.randomVersion(random()));
+        assertEquals(sortValue, deserialized);
+    }
+
+    public static Object randomSortValue() {
+        switch(randomIntBetween(0, 8)) {
+            case 0:
+                return randomAlphaOfLengthBetween(3, 10);
+            case 1:
+                return randomInt();
+            case 2:
+                return randomLong();
+            case 3:
+                return randomFloat();
+            case 4:
+                return randomDouble();
+            case 5:
+                return randomByte();
+            case 6:
+                return randomShort();
+            case 7:
+                return randomBoolean();
+            case 8:
+                return new BytesRef(randomAlphaOfLengthBetween(3, 10));
+            default:
+                throw new UnsupportedOperationException();
+        }
+    }
+
+    public static Tuple<SortField, SortField> randomSortField() {
+        switch(randomIntBetween(0, 2)) {
+            case 0:
+                return randomSortFieldCustomComparatorSource();
+            case 1:
+                return randomCustomSortField();
+            case 2:
+                String field = randomAlphaOfLengthBetween(3, 10);
+                SortField.Type type = randomFrom(SortField.Type.values());
+                if ((type == SortField.Type.SCORE || type == SortField.Type.DOC) && randomBoolean()) {
+                    field = null;
+                }
+                SortField sortField = new SortField(field, type, randomBoolean());
+                Object missingValue = randomMissingValue(sortField.getType());
+                if (missingValue != null) {
+                    sortField.setMissingValue(missingValue);
+                }
+                return Tuple.tuple(sortField, sortField);
+            default:
+                throw new UnsupportedOperationException();
+        }
+    }
+
+    private static Tuple<SortField, SortField> randomSortFieldCustomComparatorSource() {
+        String field = randomAlphaOfLengthBetween(3, 10);
+        IndexFieldData.XFieldComparatorSource comparatorSource;
+        boolean reverse = randomBoolean();
+        Object missingValue = null;
+        switch(randomIntBetween(0, 3)) {
+            case 0:
+                comparatorSource = new LongValuesComparatorSource(null, randomBoolean() ? randomLong() : null,
+                    randomFrom(MultiValueMode.values()), null);
+                break;
+            case 1:
+                comparatorSource = new DoubleValuesComparatorSource(null, randomBoolean() ? randomDouble() : null,
+                    randomFrom(MultiValueMode.values()), null);
+                break;
+            case 2:
+                comparatorSource = new FloatValuesComparatorSource(null, randomBoolean() ? randomFloat() : null,
+                    randomFrom(MultiValueMode.values()), null);
+                break;
+            case 3:
+                comparatorSource = new BytesRefFieldComparatorSource(null,
+                    randomBoolean() ? "_first" : "_last", randomFrom(MultiValueMode.values()), null);
+                missingValue = comparatorSource.missingValue(reverse);
+                break;
+            default:
+                throw new UnsupportedOperationException();
+        }
+        SortField sortField = new SortField(field, comparatorSource, reverse);
+        SortField expected = new SortField(field, comparatorSource.reducedType(), reverse);
+        expected.setMissingValue(missingValue);
+        return Tuple.tuple(sortField, expected);
+    }
+
+    private static Tuple<SortField, SortField> randomCustomSortField() {
+        String field = randomAlphaOfLengthBetween(3, 10);
+        switch(randomIntBetween(0, 2)) {
+            case 0: {
+                SortField sortField = LatLonDocValuesField.newDistanceSort(field, 0, 0);
+                SortField expected = new SortField(field, SortField.Type.DOUBLE);
+                expected.setMissingValue(Double.POSITIVE_INFINITY);
+                return Tuple.tuple(sortField, expected);
+            }
+            case 1: {
+                SortedSetSortField sortField = new SortedSetSortField(field, randomBoolean(), randomFrom(SortedSetSelector.Type.values()));
+                SortField expected = new SortField(sortField.getField(), SortField.Type.STRING, sortField.getReverse());
+                Object missingValue = randomMissingValue(SortField.Type.STRING);
+                sortField.setMissingValue(missingValue);
+                expected.setMissingValue(missingValue);
+                return Tuple.tuple(sortField, expected);
+            }
+            case 2: {
+                SortField.Type type = randomFrom(SortField.Type.DOUBLE, SortField.Type.INT, SortField.Type.FLOAT, SortField.Type.LONG);
+                SortedNumericSortField sortField = new SortedNumericSortField(field, type, randomBoolean());
+                SortField expected = new SortField(sortField.getField(), sortField.getNumericType(), sortField.getReverse());
+                Object missingValue = randomMissingValue(type);
+                if (missingValue != null) {
+                    sortField.setMissingValue(missingValue);
+                    expected.setMissingValue(missingValue);
+                }
+                return Tuple.tuple(sortField, expected);
+            }
+            default:
+                throw new UnsupportedOperationException();
+        }
+    }
+
+    private static Object randomMissingValue(SortField.Type type) {
+        switch(type) {
+            case INT:
+                return randomInt();
+            case FLOAT:
+                return randomFloat();
+            case DOUBLE:
+                return randomDouble();
+            case LONG:
+                return randomLong();
+            case STRING:
+                return randomBoolean() ? SortField.STRING_FIRST : SortField.STRING_LAST;
+            default:
+                return null;
+        }
+    }
 }

+ 12 - 12
server/src/test/java/org/elasticsearch/search/SearchHitTests.java

@@ -19,15 +19,6 @@
 
 package org.elasticsearch.search;
 
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Predicate;
-
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.TotalHits;
 import org.elasticsearch.action.OriginalIndices;
@@ -52,6 +43,15 @@ import org.elasticsearch.search.fetch.subphase.highlight.HighlightFieldTests;
 import org.elasticsearch.test.AbstractStreamableTestCase;
 import org.elasticsearch.test.RandomObjects;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Predicate;
+
 import static org.elasticsearch.common.xcontent.XContentHelper.toXContent;
 import static org.elasticsearch.test.XContentTestUtils.insertRandomFields;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
@@ -65,7 +65,7 @@ public class SearchHitTests extends AbstractStreamableTestCase<SearchHit> {
         return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget);
     }
 
-    public static SearchHit createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean withShardTarget) {
+    public static SearchHit createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization) {
         int internalId = randomInt();
         String uid = randomAlphaOfLength(10);
         Text type = new Text(randomAlphaOfLengthBetween(5, 10));
@@ -120,12 +120,12 @@ public class SearchHitTests extends AbstractStreamableTestCase<SearchHit> {
                 Map<String, SearchHits> innerHits = new HashMap<>(innerHitsSize);
                 for (int i = 0; i < innerHitsSize; i++) {
                     innerHits.put(randomAlphaOfLength(5),
-                        SearchHitsTests.createTestItem(xContentType, false, withShardTarget));
+                        SearchHitsTests.createTestItem(xContentType, false, transportSerialization));
                 }
                 hit.setInnerHits(innerHits);
             }
         }
-        if (withShardTarget && randomBoolean()) {
+        if (transportSerialization && randomBoolean()) {
             String index = randomAlphaOfLengthBetween(5, 10);
             String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(5, 10);
             hit.shard(new SearchShardTarget(randomAlphaOfLengthBetween(5, 10),

+ 111 - 10
server/src/test/java/org/elasticsearch/search/SearchHitsTests.java

@@ -19,11 +19,15 @@
 
 package org.elasticsearch.search;
 
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.util.TestUtil;
+import org.elasticsearch.Version;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.lucene.LuceneTests;
 import org.elasticsearch.common.text.Text;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.common.xcontent.ToXContent;
@@ -34,41 +38,75 @@ import org.elasticsearch.common.xcontent.json.JsonXContent;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.test.AbstractStreamableXContentTestCase;
+import org.elasticsearch.test.VersionUtils;
 
 import java.io.IOException;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.function.Predicate;
 
 public class SearchHitsTests extends AbstractStreamableXContentTestCase<SearchHits> {
+
     public static SearchHits createTestItem(boolean withOptionalInnerHits, boolean withShardTarget) {
         return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget);
     }
 
     private static SearchHit[] createSearchHitArray(int size, XContentType xContentType, boolean withOptionalInnerHits,
-                                                    boolean withShardTarget) {
+                                                    boolean transportSerialization) {
         SearchHit[] hits = new SearchHit[size];
         for (int i = 0; i < hits.length; i++) {
-            hits[i] = SearchHitTests.createTestItem(xContentType, withOptionalInnerHits, withShardTarget);
+            hits[i] = SearchHitTests.createTestItem(xContentType, withOptionalInnerHits, transportSerialization);
         }
         return hits;
     }
 
-    private static TotalHits randomTotalHits() {
+    private static TotalHits randomTotalHits(TotalHits.Relation relation) {
         long totalHits = TestUtil.nextLong(random(), 0, Long.MAX_VALUE);
-        TotalHits.Relation relation = randomFrom(TotalHits.Relation.values());
         return new TotalHits(totalHits, relation);
     }
 
-    public static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean withShardTarget) {
+    public static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization) {
+        return createTestItem(xContentType, withOptionalInnerHits, transportSerialization, randomFrom(TotalHits.Relation.values()));
+    }
+
+    private static SearchHits createTestItem(XContentType xContentType, boolean withOptionalInnerHits, boolean transportSerialization,
+                                             TotalHits.Relation totalHitsRelation) {
         int searchHits = randomIntBetween(0, 5);
-        SearchHit[] hits = createSearchHitArray(searchHits, xContentType, withOptionalInnerHits, withShardTarget);
+        SearchHit[] hits = createSearchHitArray(searchHits, xContentType, withOptionalInnerHits, transportSerialization);
+        TotalHits totalHits = frequently() ? randomTotalHits(totalHitsRelation) : null;
         float maxScore = frequently() ? randomFloat() : Float.NaN;
-        return new SearchHits(hits, frequently() ? randomTotalHits() : null, maxScore);
+        SortField[] sortFields = null;
+        String collapseField = null;
+        Object[] collapseValues = null;
+        if (transportSerialization) {
+            sortFields = randomBoolean() ? createSortFields(randomIntBetween(1, 5)) : null;
+            collapseField = randomAlphaOfLengthBetween(5, 10);
+            collapseValues = randomBoolean() ? createCollapseValues(randomIntBetween(1, 10)) : null;
+        }
+        return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues);
+    }
+
+    private static SortField[] createSortFields(int size) {
+        SortField[] sortFields = new SortField[size];
+        for (int i = 0; i < sortFields.length; i++) {
+            //sort fields are simplified before serialization, we write directly the simplified version
+            //otherwise equality comparisons become complicated
+            sortFields[i] = LuceneTests.randomSortField().v2();
+        }
+        return sortFields;
+    }
+
+    private static Object[] createCollapseValues(int size) {
+        Object[] collapseValues = new Object[size];
+        for (int i = 0; i < collapseValues.length; i++) {
+            collapseValues[i] = LuceneTests.randomSortValue();
+        }
+        return collapseValues;
     }
 
     @Override
     protected SearchHits mutateInstance(SearchHits instance) {
-        switch (randomIntBetween(0, 2)) {
+        switch (randomIntBetween(0, 5)) {
             case 0:
                 return new SearchHits(createSearchHitArray(instance.getHits().length + 1,
                     randomFrom(XContentType.values()), false, randomBoolean()),
@@ -76,7 +114,7 @@ public class SearchHitsTests extends AbstractStreamableXContentTestCase<SearchHi
             case 1:
                 final TotalHits totalHits;
                 if (instance.getTotalHits() == null) {
-                    totalHits = randomTotalHits();
+                    totalHits = randomTotalHits(randomFrom(TotalHits.Relation.values()));
                 } else {
                     totalHits = null;
                 }
@@ -89,6 +127,33 @@ public class SearchHitsTests extends AbstractStreamableXContentTestCase<SearchHi
                     maxScore = Float.NaN;
                 }
                 return new SearchHits(instance.getHits(), instance.getTotalHits(), maxScore);
+            case 3:
+                SortField[] sortFields;
+                if (instance.getSortFields() == null) {
+                    sortFields = createSortFields(randomIntBetween(1, 5));
+                } else {
+                    sortFields = randomBoolean() ? createSortFields(instance.getSortFields().length + 1) : null;
+                }
+                return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(),
+                    sortFields, instance.getCollapseField(), instance.getCollapseValues());
+            case 4:
+                String collapseField;
+                if (instance.getCollapseField() == null) {
+                    collapseField = randomAlphaOfLengthBetween(5, 10);
+                } else {
+                    collapseField = randomBoolean() ? instance.getCollapseField() + randomAlphaOfLengthBetween(2, 5) : null;
+                }
+                return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(),
+                    instance.getSortFields(), collapseField, instance.getCollapseValues());
+            case 5:
+                Object[] collapseValues;
+                if (instance.getCollapseValues() == null) {
+                    collapseValues = createCollapseValues(randomIntBetween(1, 5));
+                } else {
+                    collapseValues = randomBoolean() ? createCollapseValues(instance.getCollapseValues().length) : null;
+                }
+                return new SearchHits(instance.getHits(), instance.getTotalHits(), instance.getMaxScore(),
+                    instance.getSortFields(), instance.getCollapseField(), collapseValues);
             default:
                 throw new UnsupportedOperationException();
         }
@@ -125,7 +190,7 @@ public class SearchHitsTests extends AbstractStreamableXContentTestCase<SearchHi
         // deserialized hit cannot be equal to the original instance.
         // There is another test (#testFromXContentWithShards) that checks the
         // rest serialization with shard targets.
-        return createTestItem(xContentType,true, false);
+        return createTestItem(xContentType, true, false);
     }
 
     @Override
@@ -205,4 +270,40 @@ public class SearchHitsTests extends AbstractStreamableXContentTestCase<SearchHi
 
         }
     }
+
+    //TODO rename method and adapt versions after backport
+    public void testReadFromPre70() throws IOException {
+        try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode("AQC/gAAAAAA="))) {
+            in.setVersion(VersionUtils.randomVersionBetween(random(), Version.V_6_0_0, VersionUtils.getPreviousVersion(Version.V_7_0_0)));
+            SearchHits searchHits = new SearchHits();
+            searchHits.readFrom(in);
+            assertEquals(0, searchHits.getHits().length);
+            assertNotNull(searchHits.getTotalHits());
+            assertEquals(0L, searchHits.getTotalHits().value);
+            assertEquals(TotalHits.Relation.EQUAL_TO, searchHits.getTotalHits().relation);
+            assertEquals(-1F, searchHits.getMaxScore(), 0F);
+            assertNull(searchHits.getSortFields());
+            assertNull(searchHits.getCollapseField());
+            assertNull(searchHits.getCollapseValues());
+        }
+    }
+
+    //TODO rename method and adapt versions after backport
+    public void testSerializationPre70() throws IOException {
+        Version version = VersionUtils.randomVersionBetween(random(), Version.V_6_0_0, VersionUtils.getPreviousVersion(Version.V_7_0_0));
+        SearchHits original = createTestItem(randomFrom(XContentType.values()), false, true, TotalHits.Relation.EQUAL_TO);
+        SearchHits deserialized = copyInstance(original, version);
+        assertArrayEquals(original.getHits(), deserialized.getHits());
+        assertEquals(original.getMaxScore(), deserialized.getMaxScore(), 0F);
+        if (original.getTotalHits() == null) {
+            assertNull(deserialized.getTotalHits());
+        } else {
+            assertNotNull(deserialized.getTotalHits());
+            assertEquals(original.getTotalHits().value, deserialized.getTotalHits().value);
+            assertEquals(original.getTotalHits().relation, deserialized.getTotalHits().relation);
+        }
+        assertNull(deserialized.getSortFields());
+        assertNull(deserialized.getCollapseField());
+        assertNull(deserialized.getCollapseValues());
+    }
 }

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java

@@ -1151,7 +1151,7 @@ public abstract class ESTestCase extends LuceneTestCase {
                 Streamable.newWriteableReader(supplier), version);
     }
 
-    private static <T> T copyInstance(T original, NamedWriteableRegistry namedWriteableRegistry, Writeable.Writer<T> writer,
+    protected static <T> T copyInstance(T original, NamedWriteableRegistry namedWriteableRegistry, Writeable.Writer<T> writer,
                                       Writeable.Reader<T> reader, Version version) throws IOException {
         try (BytesStreamOutput output = new BytesStreamOutput()) {
             output.setVersion(version);