Browse Source

Add profiling information for knn vector queries (#90200)

This adds timers to the dfs phase to profile a knn vector query and provide a breakdown of several 
parts of the query.
Jack Conradson 3 years ago
parent
commit
94f05da248

+ 5 - 0
docs/changelog/90200.yaml

@@ -0,0 +1,5 @@
+pr: 90200
+summary: Add profiling information for knn vector queries
+area: Vector Search
+type: enhancement
+issues: []

+ 116 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml

@@ -161,3 +161,119 @@ disabling stored fields removes fetch sub phases:
   - match: { hits.hits.0._index: test }
   - match: { profile.shards.0.fetch.debug.stored_fields: [] }
   - is_false: profile.shards.0.fetch.children
+
+---
+dfs knn vector profiling:
+  - skip:
+      version: ' - 8.5.99'
+      reason: dfs profiling implemented in 8.6.0
+
+  - do:
+      indices.create:
+        index: images
+        body:
+          settings:
+            index.number_of_shards: 1
+          mappings:
+            properties:
+              image:
+                type: "dense_vector"
+                dims: 3
+                index: true
+                similarity: "l2_norm"
+
+  - do:
+      index:
+        index: images
+        id: "1"
+        refresh: true
+        body:
+          image: [1, 5, -20]
+
+  - do:
+      search:
+        index: images
+        body:
+          profile: true
+          knn:
+            field: "image"
+            query_vector: [-5, 9, -12]
+            k: 1
+            num_candidates: 100
+
+  - match: { hits.total.value: 1 }
+  - match: { profile.shards.0.dfs.knn.query.0.type: "DocAndScoreQuery" }
+  - match: { profile.shards.0.dfs.knn.query.0.description: "DocAndScore[100]" }
+  - gt: { profile.shards.0.dfs.knn.query.0.time_in_nanos: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score_count: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match_count: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance_count: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score_count: 0 }
+  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight: 0 }
+  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.rewrite_time: 0 }
+  - match: { profile.shards.0.dfs.knn.collector.0.name: "SimpleTopScoreDocCollector" }
+  - match: { profile.shards.0.dfs.knn.collector.0.reason: "search_top_hits" }
+  - gt: { profile.shards.0.dfs.knn.collector.0.time_in_nanos: 0 }
+
+---
+dfs without knn vector profiling:
+  - skip:
+      version: ' - 8.5.99'
+      reason: dfs profiling implemented in 8.6.0
+
+  - do:
+      indices.create:
+        index: keywords
+        body:
+          settings:
+            index.number_of_shards: 1
+          mappings:
+            properties:
+              keyword:
+                type: "keyword"
+  - do:
+      index:
+        index: keywords
+        id: "1"
+        refresh: true
+        body:
+          keyword: "a"
+
+  - do:
+      search:
+        index: keywords
+        search_type: dfs_query_then_fetch
+        body:
+          profile: true
+          query:
+            term:
+              keyword: "a"
+
+  - match: { hits.total.value: 1 }
+  - is_false: profile.shards.0.dfs
+
+  - do:
+      search:
+        index: keywords
+        search_type: query_then_fetch
+        body:
+          profile: true
+          query:
+            term:
+              keyword: "a"
+
+  - match: { hits.total.value: 1 }
+  - is_false: profile.shards.0.dfs

+ 1 - 0
server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

@@ -98,6 +98,7 @@ final class DfsQueryPhase extends SearchPhase {
                     @Override
                     protected void innerOnResponse(QuerySearchResult response) {
                         try {
+                            response.setSearchProfileDfsPhaseResult(dfsResult.searchProfileDfsPhaseResult());
                             counter.onResult(response);
                         } catch (Exception e) {
                             context.onPhaseFailure(DfsQueryPhase.this, "", e);

+ 84 - 56
server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

@@ -10,15 +10,17 @@ package org.elasticsearch.search.dfs;
 
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TermStatistics;
-import org.apache.lucene.search.TopDocs;
-import org.elasticsearch.index.query.ParsedQuery;
+import org.apache.lucene.search.TopScoreDocCollector;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.profile.query.CollectorResult;
+import org.elasticsearch.search.profile.query.InternalProfileCollector;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
@@ -26,6 +28,7 @@ import org.elasticsearch.tasks.TaskCancelledException;
 
 import java.io.IOException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -39,71 +42,96 @@ public class DfsPhase {
 
     public void execute(SearchContext context) {
         try {
-            Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
-            Map<Term, TermStatistics> stats = new HashMap<>();
-            IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
-                @Override
-                public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
-                    if (context.isCancelled()) {
-                        throw new TaskCancelledException("cancelled");
-                    }
-                    TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
-                    if (ts != null) {
-                        stats.put(term, ts);
-                    }
-                    return ts;
-                }
+            collectStatistics(context);
+            executeKnnVectorQuery(context);
+        } catch (Exception e) {
+            throw new DfsPhaseExecutionException(context.shardTarget(), "Exception during dfs phase", e);
+        }
+    }
+
+    private void collectStatistics(SearchContext context) throws IOException {
+        Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
+        Map<Term, TermStatistics> stats = new HashMap<>();
 
-                @Override
-                public CollectionStatistics collectionStatistics(String field) throws IOException {
-                    if (context.isCancelled()) {
-                        throw new TaskCancelledException("cancelled");
-                    }
-                    CollectionStatistics cs = super.collectionStatistics(field);
-                    if (cs != null) {
-                        fieldStatistics.put(field, cs);
-                    }
-                    return cs;
+        IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
+            @Override
+            public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
+                if (context.isCancelled()) {
+                    throw new TaskCancelledException("cancelled");
+                }
+                TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
+                if (ts != null) {
+                    stats.put(term, ts);
                 }
-            };
+                return ts;
+            }
 
-            searcher.createWeight(context.rewrittenQuery(), ScoreMode.COMPLETE, 1);
-            for (RescoreContext rescoreContext : context.rescore()) {
-                for (Query query : rescoreContext.getQueries()) {
-                    searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
+            @Override
+            public CollectionStatistics collectionStatistics(String field) throws IOException {
+                if (context.isCancelled()) {
+                    throw new TaskCancelledException("cancelled");
+                }
+                CollectionStatistics cs = super.collectionStatistics(field);
+                if (cs != null) {
+                    fieldStatistics.put(field, cs);
                 }
+                return cs;
             }
+        };
 
-            Term[] terms = stats.keySet().toArray(new Term[0]);
-            TermStatistics[] termStatistics = new TermStatistics[terms.length];
-            for (int i = 0; i < terms.length; i++) {
-                termStatistics[i] = stats.get(terms[i]);
+        searcher.createWeight(context.rewrittenQuery(), ScoreMode.COMPLETE, 1);
+        for (RescoreContext rescoreContext : context.rescore()) {
+            for (Query query : rescoreContext.getQueries()) {
+                searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
             }
+        }
 
-            context.dfsResult()
-                .termsStatistics(terms, termStatistics)
-                .fieldStatistics(fieldStatistics)
-                .maxDoc(context.searcher().getIndexReader().maxDoc());
+        Term[] terms = stats.keySet().toArray(new Term[0]);
+        TermStatistics[] termStatistics = new TermStatistics[terms.length];
+        for (int i = 0; i < terms.length; i++) {
+            termStatistics[i] = stats.get(terms[i]);
+        }
 
-            // If kNN search is requested, perform kNN query and gather top docs
-            SearchSourceBuilder source = context.request().source();
-            if (source != null && source.knnSearch() != null) {
-                SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
-                KnnSearchBuilder knnSearch = source.knnSearch();
+        context.dfsResult()
+            .termsStatistics(terms, termStatistics)
+            .fieldStatistics(fieldStatistics)
+            .maxDoc(context.searcher().getIndexReader().maxDoc());
+    }
 
-                KnnVectorQueryBuilder knnVectorQueryBuilder = knnSearch.toQueryBuilder();
-                if (context.request().getAliasFilter().getQueryBuilder() != null) {
-                    knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
-                }
-                ParsedQuery query = searchExecutionContext.toQuery(knnVectorQueryBuilder);
+    private void executeKnnVectorQuery(SearchContext context) throws IOException {
+        SearchSourceBuilder source = context.request().source();
+        if (source == null || source.knnSearch() == null) {
+            return;
+        }
 
-                TopDocs topDocs = searcher.search(query.query(), knnSearch.k());
-                DfsKnnResults knnResults = new DfsKnnResults(topDocs.scoreDocs);
-                context.dfsResult().knnResults(knnResults);
-            }
-        } catch (Exception e) {
-            throw new DfsPhaseExecutionException(context.shardTarget(), "Exception during dfs phase", e);
+        SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
+        KnnSearchBuilder knnSearch = context.request().source().knnSearch();
+        KnnVectorQueryBuilder knnVectorQueryBuilder = knnSearch.toQueryBuilder();
+
+        if (context.request().getAliasFilter().getQueryBuilder() != null) {
+            knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
         }
-    }
 
+        Query query = searchExecutionContext.toQuery(knnVectorQueryBuilder).query();
+        TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.k(), Integer.MAX_VALUE);
+        Collector collector = topScoreDocCollector;
+
+        if (context.getProfilers() != null) {
+            InternalProfileCollector ipc = new InternalProfileCollector(
+                topScoreDocCollector,
+                CollectorResult.REASON_SEARCH_TOP_HITS,
+                List.of()
+            );
+            context.getProfilers().getCurrentQueryProfiler().setCollector(ipc);
+            collector = ipc;
+        }
+
+        context.searcher().search(query, collector);
+        DfsKnnResults knnResults = new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs);
+        context.dfsResult().knnResults(knnResults);
+
+        if (context.getProfilers() != null) {
+            context.dfsResult().profileResult(context.getProfilers().buildDfsPhaseResults());
+        }
+    }
 }

+ 17 - 0
server/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java

@@ -19,6 +19,7 @@ import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
 
 import java.io.IOException;
 import java.util.HashMap;
@@ -33,6 +34,7 @@ public class DfsSearchResult extends SearchPhaseResult {
     private Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
     private DfsKnnResults knnResults;
     private int maxDoc;
+    private SearchProfileDfsPhaseResult searchProfileDfsPhaseResult;
 
     public DfsSearchResult(StreamInput in) throws IOException {
         super(in);
@@ -56,6 +58,9 @@ public class DfsSearchResult extends SearchPhaseResult {
         if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
             knnResults = in.readOptionalWriteable(DfsKnnResults::new);
         }
+        if (in.getVersion().onOrAfter(Version.V_8_6_0)) {
+            searchProfileDfsPhaseResult = in.readOptionalWriteable(SearchProfileDfsPhaseResult::new);
+        }
     }
 
     public DfsSearchResult(ShardSearchContextId contextId, SearchShardTarget shardTarget, ShardSearchRequest shardSearchRequest) {
@@ -89,6 +94,11 @@ public class DfsSearchResult extends SearchPhaseResult {
         return this;
     }
 
+    public DfsSearchResult profileResult(SearchProfileDfsPhaseResult searchProfileDfsPhaseResult) {
+        this.searchProfileDfsPhaseResult = searchProfileDfsPhaseResult;
+        return this;
+    }
+
     public Term[] terms() {
         return terms;
     }
@@ -105,6 +115,10 @@ public class DfsSearchResult extends SearchPhaseResult {
         return knnResults;
     }
 
+    public SearchProfileDfsPhaseResult searchProfileDfsPhaseResult() {
+        return searchProfileDfsPhaseResult;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         contextId.writeTo(out);
@@ -121,6 +135,9 @@ public class DfsSearchResult extends SearchPhaseResult {
         if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
             out.writeOptionalWriteable(knnResults);
         }
+        if (out.getVersion().onOrAfter(Version.V_8_6_0)) {
+            out.writeOptionalWriteable(searchProfileDfsPhaseResult);
+        }
     }
 
     public static void writeFieldStats(StreamOutput out, Map<String, CollectionStatistics> fieldStatistics) throws IOException {

+ 13 - 0
server/src/main/java/org/elasticsearch/search/profile/Profilers.java

@@ -66,6 +66,19 @@ public final class Profilers {
         return new FetchProfiler();
     }
 
+    /**
+     * Build the results for the dfs phase.
+     */
+    public SearchProfileDfsPhaseResult buildDfsPhaseResults() {
+        QueryProfiler queryProfiler = getCurrentQueryProfiler();
+        QueryProfileShardResult queryProfileShardResult = new QueryProfileShardResult(
+            queryProfiler.getTree(),
+            queryProfiler.getRewriteTime(),
+            queryProfiler.getCollector()
+        );
+        return new SearchProfileDfsPhaseResult(queryProfileShardResult);
+    }
+
     /**
      * Build the results for the query phase.
      */

+ 86 - 0
server/src/main/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResult.java

@@ -0,0 +1,86 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.profile;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.profile.query.QueryProfileShardResult;
+import org.elasticsearch.xcontent.InstantiatingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ParserConstructor;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+public class SearchProfileDfsPhaseResult implements Writeable, ToXContentObject {
+
+    private final QueryProfileShardResult queryProfileShardResult;
+
+    @ParserConstructor
+    public SearchProfileDfsPhaseResult(@Nullable QueryProfileShardResult queryProfileShardResult) {
+        this.queryProfileShardResult = queryProfileShardResult;
+    }
+
+    public SearchProfileDfsPhaseResult(StreamInput in) throws IOException {
+        queryProfileShardResult = in.readOptionalWriteable(QueryProfileShardResult::new);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalWriteable(queryProfileShardResult);
+    }
+
+    private static final ParseField KNN = new ParseField("knn");
+    private static final InstantiatingObjectParser<SearchProfileDfsPhaseResult, Void> PARSER;
+
+    static {
+        InstantiatingObjectParser.Builder<SearchProfileDfsPhaseResult, Void> parser = InstantiatingObjectParser.builder(
+            "search_profile_dfs_phase_result",
+            true,
+            SearchProfileDfsPhaseResult.class
+        );
+        parser.declareObject(optionalConstructorArg(), (p, c) -> QueryProfileShardResult.fromXContent(p), KNN);
+        PARSER = parser.build();
+    }
+
+    public static SearchProfileDfsPhaseResult fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (queryProfileShardResult != null) {
+            builder.startObject();
+            builder.field(KNN.getPreferredName());
+            queryProfileShardResult.toXContent(builder, params);
+            builder.endObject();
+        }
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SearchProfileDfsPhaseResult that = (SearchProfileDfsPhaseResult) o;
+        return Objects.equals(queryProfileShardResult, that.queryProfileShardResult);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(queryProfileShardResult);
+    }
+}

+ 26 - 7
server/src/main/java/org/elasticsearch/search/profile/SearchProfileQueryPhaseResult.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.search.profile;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -25,6 +26,8 @@ import java.util.Objects;
  */
 public class SearchProfileQueryPhaseResult implements Writeable {
 
+    private SearchProfileDfsPhaseResult searchProfileDfsPhaseResult;
+
     private final List<QueryProfileShardResult> queryProfileResults;
 
     private final AggregationProfileShardResult aggProfileShardResult;
@@ -33,11 +36,15 @@ public class SearchProfileQueryPhaseResult implements Writeable {
         List<QueryProfileShardResult> queryProfileResults,
         AggregationProfileShardResult aggProfileShardResult
     ) {
+        this.searchProfileDfsPhaseResult = null;
         this.aggProfileShardResult = aggProfileShardResult;
         this.queryProfileResults = Collections.unmodifiableList(queryProfileResults);
     }
 
     public SearchProfileQueryPhaseResult(StreamInput in) throws IOException {
+        if (in.getVersion().onOrAfter(Version.V_8_6_0)) {
+            searchProfileDfsPhaseResult = in.readOptionalWriteable(SearchProfileDfsPhaseResult::new);
+        }
         int profileSize = in.readVInt();
         List<QueryProfileShardResult> queryProfileResults = new ArrayList<>(profileSize);
         for (int i = 0; i < profileSize; i++) {
@@ -50,6 +57,9 @@ public class SearchProfileQueryPhaseResult implements Writeable {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        if (out.getVersion().onOrAfter(Version.V_8_6_0)) {
+            out.writeOptionalWriteable(searchProfileDfsPhaseResult);
+        }
         out.writeVInt(queryProfileResults.size());
         for (QueryProfileShardResult queryShardResult : queryProfileResults) {
             queryShardResult.writeTo(out);
@@ -57,6 +67,14 @@ public class SearchProfileQueryPhaseResult implements Writeable {
         aggProfileShardResult.writeTo(out);
     }
 
+    public void setSearchProfileDfsPhaseResult(SearchProfileDfsPhaseResult searchProfileDfsPhaseResult) {
+        this.searchProfileDfsPhaseResult = searchProfileDfsPhaseResult;
+    }
+
+    public SearchProfileDfsPhaseResult getSearchProfileDfsPhaseResult() {
+        return searchProfileDfsPhaseResult;
+    }
+
     public List<QueryProfileShardResult> getQueryProfileResults() {
         return queryProfileResults;
     }
@@ -66,16 +84,17 @@ public class SearchProfileQueryPhaseResult implements Writeable {
     }
 
     @Override
-    public boolean equals(Object obj) {
-        if (obj == null || getClass() != obj.getClass()) {
-            return false;
-        }
-        SearchProfileQueryPhaseResult other = (SearchProfileQueryPhaseResult) obj;
-        return queryProfileResults.equals(other.queryProfileResults) && aggProfileShardResult.equals(other.aggProfileShardResult);
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SearchProfileQueryPhaseResult that = (SearchProfileQueryPhaseResult) o;
+        return Objects.equals(searchProfileDfsPhaseResult, that.searchProfileDfsPhaseResult)
+            && Objects.equals(queryProfileResults, that.queryProfileResults)
+            && Objects.equals(aggProfileShardResult, that.aggProfileShardResult);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(queryProfileResults, aggProfileShardResult);
+        return Objects.hash(searchProfileDfsPhaseResult, queryProfileResults, aggProfileShardResult);
     }
 }

+ 10 - 2
server/src/main/java/org/elasticsearch/search/profile/SearchProfileResults.java

@@ -129,6 +129,7 @@ public final class SearchProfileResults implements Writeable, ToXContentFragment
         throws IOException {
         XContentParser.Token token = parser.currentToken();
         ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
+        SearchProfileDfsPhaseResult searchProfileDfsPhaseResult = null;
         List<QueryProfileShardResult> queryProfileResults = new ArrayList<>();
         AggregationProfileShardResult aggProfileShardResult = null;
         ProfileResult fetchResult = null;
@@ -145,7 +146,7 @@ public final class SearchProfileResults implements Writeable, ToXContentFragment
                 }
             } else if (token == XContentParser.Token.START_ARRAY) {
                 if ("searches".equals(currentFieldName)) {
-                    while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
+                    while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
                         queryProfileResults.add(QueryProfileShardResult.fromXContent(parser));
                     }
                 } else if (AggregationProfileShardResult.AGGREGATIONS.equals(currentFieldName)) {
@@ -154,7 +155,13 @@ public final class SearchProfileResults implements Writeable, ToXContentFragment
                     parser.skipChildren();
                 }
             } else if (token == XContentParser.Token.START_OBJECT) {
-                fetchResult = ProfileResult.fromXContent(parser);
+                if ("dfs".equals(currentFieldName)) {
+                    searchProfileDfsPhaseResult = SearchProfileDfsPhaseResult.fromXContent(parser);
+                } else if ("fetch".equals(currentFieldName)) {
+                    fetchResult = ProfileResult.fromXContent(parser);
+                } else {
+                    parser.skipChildren();
+                }
             } else {
                 parser.skipChildren();
             }
@@ -163,6 +170,7 @@ public final class SearchProfileResults implements Writeable, ToXContentFragment
             new SearchProfileQueryPhaseResult(queryProfileResults, aggProfileShardResult),
             fetchResult
         );
+        result.getQueryPhase().setSearchProfileDfsPhaseResult(searchProfileDfsPhaseResult);
         searchProfileResults.put(id, result);
     }
 }

+ 14 - 7
server/src/main/java/org/elasticsearch/search/profile/SearchProfileShardResult.java

@@ -26,8 +26,8 @@ import java.util.Objects;
  * Profile results from a particular shard for all search phases.
  */
 public class SearchProfileShardResult implements Writeable, ToXContentFragment {
-    private final SearchProfileQueryPhaseResult queryPhase;
 
+    private final SearchProfileQueryPhaseResult queryPhase;
     private final ProfileResult fetchPhase;
 
     public SearchProfileShardResult(SearchProfileQueryPhaseResult queryPhase, @Nullable ProfileResult fetch) {
@@ -46,6 +46,10 @@ public class SearchProfileShardResult implements Writeable, ToXContentFragment {
         out.writeOptionalWriteable(fetchPhase);
     }
 
+    public SearchProfileDfsPhaseResult getSearchProfileDfsPhaseResult() {
+        return queryPhase.getSearchProfileDfsPhaseResult();
+    }
+
     public SearchProfileQueryPhaseResult getQueryPhase() {
         return queryPhase;
     }
@@ -64,6 +68,10 @@ public class SearchProfileShardResult implements Writeable, ToXContentFragment {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (getSearchProfileDfsPhaseResult() != null) {
+            builder.field("dfs");
+            getSearchProfileDfsPhaseResult().toXContent(builder, params);
+        }
         builder.startArray("searches");
         for (QueryProfileShardResult result : queryPhase.getQueryProfileResults()) {
             result.toXContent(builder, params);
@@ -78,12 +86,11 @@ public class SearchProfileShardResult implements Writeable, ToXContentFragment {
     }
 
     @Override
-    public boolean equals(Object obj) {
-        if (obj == null || getClass() != obj.getClass()) {
-            return false;
-        }
-        SearchProfileShardResult other = (SearchProfileShardResult) obj;
-        return queryPhase.equals(other.queryPhase) && Objects.equals(fetchPhase, other.fetchPhase);
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SearchProfileShardResult that = (SearchProfileShardResult) o;
+        return Objects.equals(queryPhase, that.queryPhase) && Objects.equals(fetchPhase, that.fetchPhase);
     }
 
     @Override

+ 8 - 0
server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java

@@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
 import org.elasticsearch.search.profile.SearchProfileQueryPhaseResult;
 import org.elasticsearch.search.suggest.Suggest;
 
@@ -229,6 +230,13 @@ public final class QuerySearchResult extends SearchPhaseResult {
         return aggregations;
     }
 
+    public void setSearchProfileDfsPhaseResult(SearchProfileDfsPhaseResult searchProfileDfsPhaseResult) {
+        if (profileShardResults == null) {
+            return;
+        }
+        profileShardResults.setSearchProfileDfsPhaseResult(searchProfileDfsPhaseResult);
+    }
+
     /**
      * Returns and nulls out the profiled results for this search, or potentially null if result was empty.
      * This allows to free up memory once the profiled result is consumed.

+ 1 - 0
server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java

@@ -56,6 +56,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         float[] queryVector = randomVector();
         KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
         SearchResponse response = client().prepareSearch("index")
+            .setProfile(true)
             .setKnnSearch(knnSearch)
             .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
             .addFetchField("*")

+ 38 - 0
server/src/test/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResultTests.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.profile;
+
+import org.elasticsearch.common.io.stream.Writeable.Reader;
+import org.elasticsearch.search.profile.query.QueryProfileShardResultTests;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class SearchProfileDfsPhaseResultTests extends AbstractSerializingTestCase<SearchProfileDfsPhaseResult> {
+
+    static SearchProfileDfsPhaseResult createTestItem() {
+        return new SearchProfileDfsPhaseResult(rarely() ? null : QueryProfileShardResultTests.createTestItem());
+    }
+
+    @Override
+    protected SearchProfileDfsPhaseResult createTestInstance() {
+        return createTestItem();
+    }
+
+    @Override
+    protected Reader<SearchProfileDfsPhaseResult> instanceReader() {
+        return SearchProfileDfsPhaseResult::new;
+    }
+
+    @Override
+    protected SearchProfileDfsPhaseResult doParseInstance(XContentParser parser) throws IOException {
+        return SearchProfileDfsPhaseResult.fromXContent(parser);
+    }
+}

+ 8 - 1
server/src/test/java/org/elasticsearch/search/profile/SearchProfileQueryPhaseResultTests.java

@@ -26,7 +26,14 @@ public class SearchProfileQueryPhaseResultTests extends AbstractWireSerializingT
             queryProfileResults.add(QueryProfileShardResultTests.createTestItem());
         }
         AggregationProfileShardResult aggProfileShardResult = AggregationProfileShardResultTests.createTestItem(1);
-        return new SearchProfileQueryPhaseResult(queryProfileResults, aggProfileShardResult);
+        SearchProfileQueryPhaseResult searchProfileQueryPhaseResult = new SearchProfileQueryPhaseResult(
+            queryProfileResults,
+            aggProfileShardResult
+        );
+        if (randomBoolean()) {
+            searchProfileQueryPhaseResult.setSearchProfileDfsPhaseResult(SearchProfileDfsPhaseResultTests.createTestItem());
+        }
+        return searchProfileQueryPhaseResult;
     }
 
     @Override