浏览代码

Delay the creation of SubSearchContext to the FetchSubPhase (#46598)

This change delays the creation of the SubSearchContext for nested and parent/child inner_hits
to the fetch sub phase in order to ensure that a SearchContext can built entirely from a
QueryShardContext. This commit also adds a validation step to the inner hits builder that ensures that we fail the request early if the inner hits path is invalid.

Relates #46523
Jim Ferenczi 6 年之前
父节点
当前提交
328fe4472f

+ 30 - 21
modules/parent-join/src/main/java/org/elasticsearch/join/query/ParentChildInnerHitContextBuilder.java

@@ -67,20 +67,26 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
     }
     }
 
 
     @Override
     @Override
-    protected void doBuild(SearchContext context, InnerHitsContext innerHitsContext) throws IOException {
+    public void doValidate(QueryShardContext queryShardContext) {
+        if (ParentJoinFieldMapper.getMapper(queryShardContext.getMapperService()) == null
+                && innerHitBuilder.isIgnoreUnmapped() == false) {
+            throw new IllegalStateException("no join field has been configured");
+        }
+    }
+
+    @Override
+    public void build(SearchContext context, InnerHitsContext innerHitsContext) throws IOException {
         QueryShardContext queryShardContext = context.getQueryShardContext();
         QueryShardContext queryShardContext = context.getQueryShardContext();
         ParentJoinFieldMapper joinFieldMapper = ParentJoinFieldMapper.getMapper(context.mapperService());
         ParentJoinFieldMapper joinFieldMapper = ParentJoinFieldMapper.getMapper(context.mapperService());
-        if (joinFieldMapper != null) {
-            String name = innerHitBuilder.getName() != null ? innerHitBuilder.getName() : typeName;
-            JoinFieldInnerHitSubContext joinFieldInnerHits = new JoinFieldInnerHitSubContext(name, context, typeName,
-                fetchChildInnerHits, joinFieldMapper);
-            setupInnerHitsContext(queryShardContext, joinFieldInnerHits);
-            innerHitsContext.addInnerHitDefinition(joinFieldInnerHits);
-        } else {
-            if (innerHitBuilder.isIgnoreUnmapped() == false) {
-                throw new IllegalStateException("no join field has been configured");
-            }
+        if (joinFieldMapper == null) {
+            assert innerHitBuilder.isIgnoreUnmapped() : "should be validated first";
+            return;
         }
         }
+        String name = innerHitBuilder.getName() != null ? innerHitBuilder.getName() : typeName;
+        JoinFieldInnerHitSubContext joinFieldInnerHits =
+            new JoinFieldInnerHitSubContext(name, context, typeName, fetchChildInnerHits, joinFieldMapper);
+        setupInnerHitsContext(queryShardContext, joinFieldInnerHits);
+        innerHitsContext.addInnerHitDefinition(joinFieldInnerHits);
     }
     }
 
 
     static final class JoinFieldInnerHitSubContext extends InnerHitsContext.InnerHitSubContext {
     static final class JoinFieldInnerHitSubContext extends InnerHitsContext.InnerHitSubContext {
@@ -88,8 +94,11 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
         private final boolean fetchChildInnerHits;
         private final boolean fetchChildInnerHits;
         private final ParentJoinFieldMapper joinFieldMapper;
         private final ParentJoinFieldMapper joinFieldMapper;
 
 
-        JoinFieldInnerHitSubContext(String name, SearchContext context, String typeName, boolean fetchChildInnerHits,
-                                    ParentJoinFieldMapper joinFieldMapper) {
+        JoinFieldInnerHitSubContext(String name,
+                                        SearchContext context,
+                                        String typeName,
+                                        boolean fetchChildInnerHits,
+                                        ParentJoinFieldMapper joinFieldMapper) {
             super(name, context);
             super(name, context);
             this.typeName = typeName;
             this.typeName = typeName;
             this.fetchChildInnerHits = fetchChildInnerHits;
             this.fetchChildInnerHits = fetchChildInnerHits;
@@ -102,13 +111,13 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
             TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[hits.length];
             TopDocsAndMaxScore[] result = new TopDocsAndMaxScore[hits.length];
             for (int i = 0; i < hits.length; i++) {
             for (int i = 0; i < hits.length; i++) {
                 SearchHit hit = hits[i];
                 SearchHit hit = hits[i];
-                String joinName = getSortedDocValue(joinFieldMapper.name(), context, hit.docId());
+                String joinName = getSortedDocValue(joinFieldMapper.name(), this, hit.docId());
                 if (joinName == null) {
                 if (joinName == null) {
                     result[i] = new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN);
                     result[i] = new TopDocsAndMaxScore(Lucene.EMPTY_TOP_DOCS, Float.NaN);
                     continue;
                     continue;
                 }
                 }
 
 
-                QueryShardContext qsc = context.getQueryShardContext();
+                QueryShardContext qsc = getQueryShardContext();
                 ParentIdFieldMapper parentIdFieldMapper =
                 ParentIdFieldMapper parentIdFieldMapper =
                     joinFieldMapper.getParentIdFieldMapper(typeName, fetchChildInnerHits == false);
                     joinFieldMapper.getParentIdFieldMapper(typeName, fetchChildInnerHits == false);
                 if (parentIdFieldMapper == null) {
                 if (parentIdFieldMapper == null) {
@@ -126,14 +135,14 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
                         .add(joinFieldMapper.fieldType().termQuery(typeName, qsc), BooleanClause.Occur.FILTER)
                         .add(joinFieldMapper.fieldType().termQuery(typeName, qsc), BooleanClause.Occur.FILTER)
                         .build();
                         .build();
                 } else {
                 } else {
-                    String parentId = getSortedDocValue(parentIdFieldMapper.name(), context, hit.docId());
-                    q = context.mapperService().fullName(IdFieldMapper.NAME).termQuery(parentId, qsc);
+                    String parentId = getSortedDocValue(parentIdFieldMapper.name(), this, hit.docId());
+                    q = mapperService().fullName(IdFieldMapper.NAME).termQuery(parentId, qsc);
                 }
                 }
 
 
-                Weight weight = context.searcher().createWeight(context.searcher().rewrite(q), ScoreMode.COMPLETE_NO_SCORES, 1f);
+                Weight weight = searcher().createWeight(searcher().rewrite(q), ScoreMode.COMPLETE_NO_SCORES, 1f);
                 if (size() == 0) {
                 if (size() == 0) {
                     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
                     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
-                    for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
+                    for (LeafReaderContext ctx : searcher().getIndexReader().leaves()) {
                         intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
                         intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
                     }
                     }
                     result[i] = new TopDocsAndMaxScore(
                     result[i] = new TopDocsAndMaxScore(
@@ -142,7 +151,7 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
                             Lucene.EMPTY_SCORE_DOCS
                             Lucene.EMPTY_SCORE_DOCS
                         ), Float.NaN);
                         ), Float.NaN);
                 } else {
                 } else {
-                    int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
+                    int topN = Math.min(from() + size(), searcher().getIndexReader().maxDoc());
                     TopDocsCollector<?> topDocsCollector;
                     TopDocsCollector<?> topDocsCollector;
                     MaxScoreCollector maxScoreCollector = null;
                     MaxScoreCollector maxScoreCollector = null;
                     if (sort() != null) {
                     if (sort() != null) {
@@ -155,7 +164,7 @@ class ParentChildInnerHitContextBuilder extends InnerHitContextBuilder {
                         maxScoreCollector = new MaxScoreCollector();
                         maxScoreCollector = new MaxScoreCollector();
                     }
                     }
                     try {
                     try {
-                        for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
+                        for (LeafReaderContext ctx : searcher().getIndexReader().leaves()) {
                             intersect(weight, innerHitQueryWeight, MultiCollector.wrap(topDocsCollector, maxScoreCollector), ctx);
                             intersect(weight, innerHitQueryWeight, MultiCollector.wrap(topDocsCollector, maxScoreCollector), ctx);
                         }
                         }
                     } finally {
                     } finally {

+ 5 - 6
modules/parent-join/src/test/java/org/elasticsearch/join/query/HasChildQueryBuilderTests.java

@@ -177,14 +177,13 @@ public class HasChildQueryBuilderTests extends AbstractQueryTestCase<HasChildQue
             queryBuilder = (HasChildQueryBuilder) queryBuilder.rewrite(searchContext.getQueryShardContext());
             queryBuilder = (HasChildQueryBuilder) queryBuilder.rewrite(searchContext.getQueryShardContext());
             Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
             Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitBuilders);
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitBuilders);
+            final InnerHitsContext innerHitsContext = new InnerHitsContext();
             for (InnerHitContextBuilder builder : innerHitBuilders.values()) {
             for (InnerHitContextBuilder builder : innerHitBuilders.values()) {
-                builder.build(searchContext, searchContext.innerHits());
+                builder.build(searchContext, innerHitsContext);
             }
             }
-            assertNotNull(searchContext.innerHits());
-            assertEquals(1, searchContext.innerHits().getInnerHits().size());
-            assertTrue(searchContext.innerHits().getInnerHits().containsKey(queryBuilder.innerHit().getName()));
-            InnerHitsContext.InnerHitSubContext innerHits =
-                    searchContext.innerHits().getInnerHits().get(queryBuilder.innerHit().getName());
+            assertEquals(1, innerHitsContext.getInnerHits().size());
+            assertTrue(innerHitsContext.getInnerHits().containsKey(queryBuilder.innerHit().getName()));
+            InnerHitsContext.InnerHitSubContext innerHits = innerHitsContext.getInnerHits().get(queryBuilder.innerHit().getName());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), STRING_FIELD_NAME_2);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), STRING_FIELD_NAME_2);

+ 5 - 7
modules/parent-join/src/test/java/org/elasticsearch/join/query/HasParentQueryBuilderTests.java

@@ -148,17 +148,15 @@ public class HasParentQueryBuilderTests extends AbstractQueryTestCase<HasParentQ
             // doCreateTestQueryBuilder)
             // doCreateTestQueryBuilder)
             queryBuilder = (HasParentQueryBuilder) queryBuilder.rewrite(searchContext.getQueryShardContext());
             queryBuilder = (HasParentQueryBuilder) queryBuilder.rewrite(searchContext.getQueryShardContext());
 
 
-            assertNotNull(searchContext);
             Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
             Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitBuilders);
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitBuilders);
+            final InnerHitsContext innerHitsContext = new InnerHitsContext();
             for (InnerHitContextBuilder builder : innerHitBuilders.values()) {
             for (InnerHitContextBuilder builder : innerHitBuilders.values()) {
-                builder.build(searchContext, searchContext.innerHits());
+                builder.build(searchContext, innerHitsContext);
             }
             }
-            assertNotNull(searchContext.innerHits());
-            assertEquals(1, searchContext.innerHits().getInnerHits().size());
-            assertTrue(searchContext.innerHits().getInnerHits().containsKey(queryBuilder.innerHit().getName()));
-            InnerHitsContext.InnerHitSubContext innerHits = searchContext.innerHits()
-                    .getInnerHits().get(queryBuilder.innerHit().getName());
+            assertEquals(1, innerHitsContext.getInnerHits().size());
+            assertTrue(innerHitsContext.getInnerHits().containsKey(queryBuilder.innerHit().getName()));
+            InnerHitsContext.InnerHitSubContext innerHits = innerHitsContext.getInnerHits().get(queryBuilder.innerHit().getName());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), STRING_FIELD_NAME_2);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), STRING_FIELD_NAME_2);

+ 7 - 23
server/src/main/java/org/elasticsearch/index/query/InnerHitContextBuilder.java

@@ -29,7 +29,6 @@ import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.search.sort.SortBuilder;
 
 
 import java.io.IOException;
 import java.io.IOException;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Optional;
 
 
@@ -47,9 +46,9 @@ public abstract class InnerHitContextBuilder {
         this.query = query;
         this.query = query;
     }
     }
 
 
-    public final void build(SearchContext parentSearchContext, InnerHitsContext innerHitsContext) throws IOException {
+    public final void validate(QueryShardContext queryShardContext) {
         long innerResultWindow = innerHitBuilder.getFrom() + innerHitBuilder.getSize();
         long innerResultWindow = innerHitBuilder.getFrom() + innerHitBuilder.getSize();
-        int maxInnerResultWindow = parentSearchContext.mapperService().getIndexSettings().getMaxInnerResultWindow();
+        int maxInnerResultWindow = queryShardContext.getIndexSettings().getMaxInnerResultWindow();
         if (innerResultWindow > maxInnerResultWindow) {
         if (innerResultWindow > maxInnerResultWindow) {
             throw new IllegalArgumentException(
             throw new IllegalArgumentException(
                 "Inner result window is too large, the inner hit definition's [" + innerHitBuilder.getName() +
                 "Inner result window is too large, the inner hit definition's [" + innerHitBuilder.getName() +
@@ -58,10 +57,12 @@ public abstract class InnerHitContextBuilder {
                     "] index level setting."
                     "] index level setting."
             );
             );
         }
         }
-        doBuild(parentSearchContext, innerHitsContext);
+        doValidate(queryShardContext);
     }
     }
 
 
-    protected abstract void doBuild(SearchContext parentSearchContext, InnerHitsContext innerHitsContext) throws IOException;
+    protected abstract void doValidate(QueryShardContext queryShardContext);
+
+    public abstract void build(SearchContext parentSearchContext, InnerHitsContext innerHitsContext) throws IOException;
 
 
     public static void extractInnerHits(QueryBuilder query, Map<String, InnerHitContextBuilder> innerHitBuilders) {
     public static void extractInnerHits(QueryBuilder query, Map<String, InnerHitContextBuilder> innerHitBuilders) {
         if (query instanceof AbstractQueryBuilder) {
         if (query instanceof AbstractQueryBuilder) {
@@ -109,23 +110,6 @@ public abstract class InnerHitContextBuilder {
         }
         }
         ParsedQuery parsedQuery = new ParsedQuery(query.toQuery(queryShardContext), queryShardContext.copyNamedQueries());
         ParsedQuery parsedQuery = new ParsedQuery(query.toQuery(queryShardContext), queryShardContext.copyNamedQueries());
         innerHitsContext.parsedQuery(parsedQuery);
         innerHitsContext.parsedQuery(parsedQuery);
-        Map<String, InnerHitsContext.InnerHitSubContext> baseChildren =
-            buildChildInnerHits(innerHitsContext.parentSearchContext(), children);
-        innerHitsContext.setChildInnerHits(baseChildren);
-    }
-
-    private static Map<String, InnerHitsContext.InnerHitSubContext> buildChildInnerHits(SearchContext parentSearchContext,
-                                Map<String, InnerHitContextBuilder> children) throws IOException {
-
-        Map<String, InnerHitsContext.InnerHitSubContext> childrenInnerHits = new HashMap<>();
-        for (Map.Entry<String, InnerHitContextBuilder> entry : children.entrySet()) {
-            InnerHitsContext childInnerHitsContext = new InnerHitsContext();
-            entry.getValue().build(
-                parentSearchContext, childInnerHitsContext);
-            if (childInnerHitsContext.getInnerHits() != null) {
-                childrenInnerHits.putAll(childInnerHitsContext.getInnerHits());
-            }
-        }
-        return childrenInnerHits;
+        innerHitsContext.innerHits(children);
     }
     }
 }
 }

+ 20 - 14
server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java

@@ -332,28 +332,34 @@ public class NestedQueryBuilder extends AbstractQueryBuilder<NestedQueryBuilder>
     static class NestedInnerHitContextBuilder extends InnerHitContextBuilder {
     static class NestedInnerHitContextBuilder extends InnerHitContextBuilder {
         private final String path;
         private final String path;
 
 
-        NestedInnerHitContextBuilder(String path, QueryBuilder query, InnerHitBuilder innerHitBuilder,
-                                     Map<String, InnerHitContextBuilder> children) {
+        NestedInnerHitContextBuilder(String path,
+                                        QueryBuilder query,
+                                        InnerHitBuilder innerHitBuilder,
+                                        Map<String, InnerHitContextBuilder> children) {
             super(query, innerHitBuilder, children);
             super(query, innerHitBuilder, children);
             this.path = path;
             this.path = path;
         }
         }
 
 
         @Override
         @Override
-        protected void doBuild(SearchContext parentSearchContext,
-                          InnerHitsContext innerHitsContext) throws IOException {
-            QueryShardContext queryShardContext = parentSearchContext.getQueryShardContext();
+        public void doValidate(QueryShardContext queryShardContext) {
+            if (queryShardContext.getObjectMapper(path) == null
+                    && innerHitBuilder.isIgnoreUnmapped() == false) {
+                throw new IllegalStateException("[" + query.getName() + "] no mapping found for type [" + path + "]");
+            }
+        }
+
+        @Override
+        public void build(SearchContext searchContext, InnerHitsContext innerHitsContext) throws IOException {
+            QueryShardContext queryShardContext = searchContext.getQueryShardContext();
             ObjectMapper nestedObjectMapper = queryShardContext.getObjectMapper(path);
             ObjectMapper nestedObjectMapper = queryShardContext.getObjectMapper(path);
             if (nestedObjectMapper == null) {
             if (nestedObjectMapper == null) {
-                if (innerHitBuilder.isIgnoreUnmapped() == false) {
-                    throw new IllegalStateException("[" + query.getName() + "] no mapping found for type [" + path + "]");
-                } else {
-                    return;
-                }
+                assert innerHitBuilder.isIgnoreUnmapped() : "should be validated first";
+                return;
             }
             }
             String name =  innerHitBuilder.getName() != null ? innerHitBuilder.getName() : nestedObjectMapper.fullPath();
             String name =  innerHitBuilder.getName() != null ? innerHitBuilder.getName() : nestedObjectMapper.fullPath();
             ObjectMapper parentObjectMapper = queryShardContext.nestedScope().nextLevel(nestedObjectMapper);
             ObjectMapper parentObjectMapper = queryShardContext.nestedScope().nextLevel(nestedObjectMapper);
             NestedInnerHitSubContext nestedInnerHits = new NestedInnerHitSubContext(
             NestedInnerHitSubContext nestedInnerHits = new NestedInnerHitSubContext(
-                name, parentSearchContext, parentObjectMapper, nestedObjectMapper
+                name, searchContext, parentObjectMapper, nestedObjectMapper
             );
             );
             setupInnerHitsContext(queryShardContext, nestedInnerHits);
             setupInnerHitsContext(queryShardContext, nestedInnerHits);
             queryShardContext.nestedScope().previousLevel();
             queryShardContext.nestedScope().previousLevel();
@@ -399,9 +405,9 @@ public class NestedQueryBuilder extends AbstractQueryBuilder<NestedQueryBuilder>
                 LeafReaderContext ctx = searcher().getIndexReader().leaves().get(readerIndex);
                 LeafReaderContext ctx = searcher().getIndexReader().leaves().get(readerIndex);
 
 
                 Query childFilter = childObjectMapper.nestedTypeFilter();
                 Query childFilter = childObjectMapper.nestedTypeFilter();
-                BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
+                BitSetProducer parentFilter = bitsetFilterCache().getBitSetProducer(rawParentFilter);
                 Query q = new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId);
                 Query q = new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId);
-                Weight weight = context.searcher().createWeight(context.searcher().rewrite(q),
+                Weight weight = searcher().createWeight(searcher().rewrite(q),
                         org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f);
                         org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f);
                 if (size() == 0) {
                 if (size() == 0) {
                     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
                     TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
@@ -409,7 +415,7 @@ public class NestedQueryBuilder extends AbstractQueryBuilder<NestedQueryBuilder>
                     result[i] = new TopDocsAndMaxScore(new TopDocs(new TotalHits(totalHitCountCollector.getTotalHits(),
                     result[i] = new TopDocsAndMaxScore(new TopDocs(new TotalHits(totalHitCountCollector.getTotalHits(),
                         TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN);
                         TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN);
                 } else {
                 } else {
-                    int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
+                    int topN = Math.min(from() + size(), searcher().getIndexReader().maxDoc());
                     TopDocsCollector<?> topDocsCollector;
                     TopDocsCollector<?> topDocsCollector;
                     MaxScoreCollector maxScoreCollector = null;
                     MaxScoreCollector maxScoreCollector = null;
                     if (sort() != null) {
                     if (sort() != null) {

+ 12 - 0
server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java

@@ -44,6 +44,7 @@ import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.query.QueryShardContext;
@@ -110,6 +111,7 @@ final class DefaultSearchContext extends SearchContext {
     private ScriptFieldsContext scriptFields;
     private ScriptFieldsContext scriptFields;
     private FetchSourceContext fetchSourceContext;
     private FetchSourceContext fetchSourceContext;
     private DocValueFieldsContext docValueFieldsContext;
     private DocValueFieldsContext docValueFieldsContext;
+    private Map<String, InnerHitContextBuilder> innerHits = Collections.emptyMap();
     private int from = -1;
     private int from = -1;
     private int size = -1;
     private int size = -1;
     private SortAndFormats sort;
     private SortAndFormats sort;
@@ -377,6 +379,16 @@ final class DefaultSearchContext extends SearchContext {
         this.highlight = highlight;
         this.highlight = highlight;
     }
     }
 
 
+    @Override
+    public void innerHits(Map<String, InnerHitContextBuilder> innerHits) {
+        this.innerHits = innerHits;
+    }
+
+    @Override
+    public Map<String, InnerHitContextBuilder> innerHits() {
+        return innerHits;
+    }
+
     @Override
     @Override
     public SuggestionSearchContext suggest() {
     public SuggestionSearchContext suggest() {
         return suggest;
         return suggest;

+ 2 - 5
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -739,6 +739,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         context.from(source.from());
         context.from(source.from());
         context.size(source.size());
         context.size(source.size());
         Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
         Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
+        context.innerHits(innerHitBuilders);
         if (source.query() != null) {
         if (source.query() != null) {
             InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders);
             InnerHitContextBuilder.extractInnerHits(source.query(), innerHitBuilders);
             context.parsedQuery(queryShardContext.toQuery(source.query()));
             context.parsedQuery(queryShardContext.toQuery(source.query()));
@@ -749,11 +750,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         }
         }
         if (innerHitBuilders.size() > 0) {
         if (innerHitBuilders.size() > 0) {
             for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
             for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
-                try {
-                    entry.getValue().build(context, context.innerHits());
-                } catch (IOException e) {
-                    throw new SearchContextException(context, "failed to build inner_hits", e);
-                }
+                entry.getValue().validate(queryShardContext);
             }
             }
         }
         }
         if (source.sorts() != null) {
         if (source.sorts() != null) {

+ 1 - 22
server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsContext.java

@@ -41,7 +41,6 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
-import java.util.Objects;
 
 
 /**
 /**
  * Context used for inner hits retrieval
  * Context used for inner hits retrieval
@@ -53,10 +52,6 @@ public final class InnerHitsContext {
         this.innerHits = new HashMap<>();
         this.innerHits = new HashMap<>();
     }
     }
 
 
-    InnerHitsContext(Map<String, InnerHitSubContext> innerHits) {
-        this.innerHits = Objects.requireNonNull(innerHits);
-    }
-
     public Map<String, InnerHitSubContext> getInnerHits() {
     public Map<String, InnerHitSubContext> getInnerHits() {
         return innerHits;
         return innerHits;
     }
     }
@@ -77,8 +72,6 @@ public final class InnerHitsContext {
     public abstract static class InnerHitSubContext extends SubSearchContext {
     public abstract static class InnerHitSubContext extends SubSearchContext {
 
 
         private final String name;
         private final String name;
-        protected final SearchContext context;
-        private InnerHitsContext childInnerHits;
 
 
         // TODO: when types are complete removed just use String instead for the id:
         // TODO: when types are complete removed just use String instead for the id:
         private Uid uid;
         private Uid uid;
@@ -86,7 +79,6 @@ public final class InnerHitsContext {
         protected InnerHitSubContext(String name, SearchContext context) {
         protected InnerHitSubContext(String name, SearchContext context) {
             super(context);
             super(context);
             this.name = name;
             this.name = name;
-            this.context = context;
         }
         }
 
 
         public abstract TopDocsAndMaxScore[] topDocs(SearchHit[] hits) throws IOException;
         public abstract TopDocsAndMaxScore[] topDocs(SearchHit[] hits) throws IOException;
@@ -95,25 +87,12 @@ public final class InnerHitsContext {
             return name;
             return name;
         }
         }
 
 
-        @Override
-        public InnerHitsContext innerHits() {
-            return childInnerHits;
-        }
-
-        public void setChildInnerHits(Map<String, InnerHitSubContext> childInnerHits) {
-            this.childInnerHits = new InnerHitsContext(childInnerHits);
-        }
-
         protected Weight createInnerHitQueryWeight() throws IOException {
         protected Weight createInnerHitQueryWeight() throws IOException {
             final boolean needsScores = size() != 0 && (sort() == null || sort().sort.needsScores());
             final boolean needsScores = size() != 0 && (sort() == null || sort().sort.needsScores());
-            return context.searcher().createWeight(context.searcher().rewrite(query()),
+            return searcher().createWeight(searcher().rewrite(query()),
                     needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1f);
                     needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1f);
         }
         }
 
 
-        public SearchContext parentSearchContext() {
-            return context;
-        }
-
         public Uid getUid() {
         public Uid getUid() {
             return uid;
             return uid;
         }
         }

+ 8 - 2
server/src/main/java/org/elasticsearch/search/fetch/subphase/InnerHitsFetchSubPhase.java

@@ -23,6 +23,7 @@ import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.index.mapper.Uid;
 import org.elasticsearch.index.mapper.Uid;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.fetch.FetchPhase;
@@ -44,11 +45,16 @@ public final class InnerHitsFetchSubPhase implements FetchSubPhase {
 
 
     @Override
     @Override
     public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
     public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
-        if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) {
+        if (context.innerHits().isEmpty()) {
             return;
             return;
         }
         }
 
 
-        for (Map.Entry<String, InnerHitsContext.InnerHitSubContext> entry : context.innerHits().getInnerHits().entrySet()) {
+        final InnerHitsContext innerHitsContext = new InnerHitsContext();
+        for (Map.Entry<String, InnerHitContextBuilder> entry : context.innerHits().entrySet()) {
+            entry.getValue().build(context, innerHitsContext);
+        }
+
+        for (Map.Entry<String, InnerHitsContext.InnerHitSubContext> entry : innerHitsContext.getInnerHits().entrySet()) {
             InnerHitsContext.InnerHitSubContext innerHits = entry.getValue();
             InnerHitsContext.InnerHitSubContext innerHits = entry.getValue();
             TopDocsAndMaxScore[] topDocs = innerHits.topDocs(hits);
             TopDocsAndMaxScore[] topDocs = innerHits.topDocs(hits);
             for (int i = 0; i < hits.length; i++) {
             for (int i = 0; i < hits.length; i++) {

+ 7 - 2
server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java

@@ -31,6 +31,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.ObjectMapper;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShard;
@@ -44,7 +45,6 @@ import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
-import org.elasticsearch.search.fetch.subphase.InnerHitsContext;
 import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext;
 import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext;
 import org.elasticsearch.search.fetch.subphase.highlight.SearchContextHighlight;
 import org.elasticsearch.search.fetch.subphase.highlight.SearchContextHighlight;
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.search.lookup.SearchLookup;
@@ -176,10 +176,15 @@ public abstract class FilteredSearchContext extends SearchContext {
     }
     }
 
 
     @Override
     @Override
-    public InnerHitsContext innerHits() {
+    public Map<String, InnerHitContextBuilder> innerHits() {
         return in.innerHits();
         return in.innerHits();
     }
     }
 
 
+    @Override
+    public void innerHits(Map<String, InnerHitContextBuilder> innerHits) {
+        in.innerHits(innerHits);
+    }
+
     @Override
     @Override
     public SuggestionSearchContext suggest() {
     public SuggestionSearchContext suggest() {
         return in.suggest();
         return in.suggest();

+ 4 - 9
server/src/main/java/org/elasticsearch/search/internal/SearchContext.java

@@ -18,7 +18,6 @@
  */
  */
 package org.elasticsearch.search.internal;
 package org.elasticsearch.search.internal;
 
 
-
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Query;
@@ -37,6 +36,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.ObjectMapper;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShard;
@@ -51,7 +51,6 @@ import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
-import org.elasticsearch.search.fetch.subphase.InnerHitsContext;
 import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext;
 import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext;
 import org.elasticsearch.search.fetch.subphase.highlight.SearchContextHighlight;
 import org.elasticsearch.search.fetch.subphase.highlight.SearchContextHighlight;
 import org.elasticsearch.search.lookup.SearchLookup;
 import org.elasticsearch.search.lookup.SearchLookup;
@@ -87,7 +86,6 @@ public abstract class SearchContext extends AbstractRefCounted implements Releas
 
 
     private Map<Lifetime, List<Releasable>> clearables = null;
     private Map<Lifetime, List<Releasable>> clearables = null;
     private final AtomicBoolean closed = new AtomicBoolean(false);
     private final AtomicBoolean closed = new AtomicBoolean(false);
-    private InnerHitsContext innerHitsContext;
 
 
     protected SearchContext() {
     protected SearchContext() {
         super("search_context");
         super("search_context");
@@ -164,12 +162,9 @@ public abstract class SearchContext extends AbstractRefCounted implements Releas
 
 
     public abstract void highlight(SearchContextHighlight highlight);
     public abstract void highlight(SearchContextHighlight highlight);
 
 
-    public InnerHitsContext innerHits() {
-        if (innerHitsContext == null) {
-            innerHitsContext = new InnerHitsContext();
-        }
-        return innerHitsContext;
-    }
+    public abstract void innerHits(Map<String, InnerHitContextBuilder> innerHits);
+
+    public abstract Map<String, InnerHitContextBuilder> innerHits();
 
 
     public abstract SuggestionSearchContext suggest();
     public abstract SuggestionSearchContext suggest();
 
 

+ 24 - 0
server/src/main/java/org/elasticsearch/search/internal/SubSearchContext.java

@@ -20,7 +20,9 @@ package org.elasticsearch.search.internal;
 
 
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Query;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.ParsedQuery;
+import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.search.aggregations.SearchContextAggregations;
 import org.elasticsearch.search.aggregations.SearchContextAggregations;
 import org.elasticsearch.search.collapse.CollapseContext;
 import org.elasticsearch.search.collapse.CollapseContext;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.FetchSearchResult;
@@ -34,7 +36,9 @@ import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
 
 
+import java.util.Collections;
 import java.util.List;
 import java.util.List;
+import java.util.Map;
 
 
 public class SubSearchContext extends FilteredSearchContext {
 public class SubSearchContext extends FilteredSearchContext {
 
 
@@ -42,6 +46,8 @@ public class SubSearchContext extends FilteredSearchContext {
     // the to hits are returned per bucket.
     // the to hits are returned per bucket.
     private static final int DEFAULT_SIZE = 3;
     private static final int DEFAULT_SIZE = 3;
 
 
+    private final QueryShardContext queryShardContext;
+
     private int from;
     private int from;
     private int size = DEFAULT_SIZE;
     private int size = DEFAULT_SIZE;
     private SortAndFormats sort;
     private SortAndFormats sort;
@@ -60,6 +66,7 @@ public class SubSearchContext extends FilteredSearchContext {
     private FetchSourceContext fetchSourceContext;
     private FetchSourceContext fetchSourceContext;
     private DocValueFieldsContext docValueFieldsContext;
     private DocValueFieldsContext docValueFieldsContext;
     private SearchContextHighlight highlight;
     private SearchContextHighlight highlight;
+    private Map<String, InnerHitContextBuilder> innerHits = Collections.emptyMap();
 
 
     private boolean explain;
     private boolean explain;
     private boolean trackScores;
     private boolean trackScores;
@@ -70,6 +77,9 @@ public class SubSearchContext extends FilteredSearchContext {
         super(context);
         super(context);
         this.fetchSearchResult = new FetchSearchResult();
         this.fetchSearchResult = new FetchSearchResult();
         this.querySearchResult = new QuerySearchResult();
         this.querySearchResult = new QuerySearchResult();
+        // we clone the query shard context in the sub context because the original one
+        // might be frozen at this point.
+        this.queryShardContext = new QueryShardContext(context.getQueryShardContext());
     }
     }
 
 
     @Override
     @Override
@@ -80,6 +90,11 @@ public class SubSearchContext extends FilteredSearchContext {
     public void preProcess(boolean rewrite) {
     public void preProcess(boolean rewrite) {
     }
     }
 
 
+    @Override
+    public QueryShardContext getQueryShardContext() {
+        return queryShardContext;
+    }
+
     @Override
     @Override
     public Query buildFilteredQuery(Query query) {
     public Query buildFilteredQuery(Query query) {
         throw new UnsupportedOperationException("this context should be read only");
         throw new UnsupportedOperationException("this context should be read only");
@@ -357,4 +372,13 @@ public class SubSearchContext extends FilteredSearchContext {
         throw new UnsupportedOperationException("Not supported");
         throw new UnsupportedOperationException("Not supported");
     }
     }
 
 
+    @Override
+    public Map<String, InnerHitContextBuilder> innerHits() {
+        return innerHits;
+    }
+
+    @Override
+    public void innerHits(Map<String, InnerHitContextBuilder> innerHits) {
+        this.innerHits = innerHits;
+    }
 }
 }

+ 7 - 7
server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java

@@ -106,13 +106,13 @@ public class NestedQueryBuilderTests extends AbstractQueryTestCase<NestedQueryBu
             assertNotNull(searchContext);
             assertNotNull(searchContext);
             Map<String, InnerHitContextBuilder> innerHitInternals = new HashMap<>();
             Map<String, InnerHitContextBuilder> innerHitInternals = new HashMap<>();
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitInternals);
             InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitInternals);
+            InnerHitsContext innerHitsContext = new InnerHitsContext();
             for (InnerHitContextBuilder builder : innerHitInternals.values()) {
             for (InnerHitContextBuilder builder : innerHitInternals.values()) {
-                builder.build(searchContext, searchContext.innerHits());
+                builder.build(searchContext, innerHitsContext);
             }
             }
-            assertNotNull(searchContext.innerHits());
-            assertEquals(1, searchContext.innerHits().getInnerHits().size());
-            assertTrue(searchContext.innerHits().getInnerHits().containsKey(queryBuilder.innerHit().getName()));
-            InnerHitsContext.InnerHitSubContext innerHits = searchContext.innerHits().getInnerHits().get(queryBuilder.innerHit().getName());
+            assertEquals(1, innerHitsContext.getInnerHits().size());
+            assertTrue(innerHitsContext.getInnerHits().containsKey(queryBuilder.innerHit().getName()));
+            InnerHitsContext.InnerHitSubContext innerHits = innerHitsContext.getInnerHits().get(queryBuilder.innerHit().getName());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.size(), queryBuilder.innerHit().getSize());
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort().length, 1);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), INT_FIELD_NAME);
             assertEquals(innerHits.sort().sort.getSort()[0].getField(), INT_FIELD_NAME);
@@ -328,7 +328,7 @@ public class NestedQueryBuilderTests extends AbstractQueryTestCase<NestedQueryBu
 
 
         MapperService mapperService = mock(MapperService.class);
         MapperService mapperService = mock(MapperService.class);
         IndexSettings settings = new IndexSettings(newIndexMeta("index", Settings.EMPTY), Settings.EMPTY);
         IndexSettings settings = new IndexSettings(newIndexMeta("index", Settings.EMPTY), Settings.EMPTY);
-        when(mapperService.getIndexSettings()).thenReturn(settings);
+        when(queryShardContext.getIndexSettings()).thenReturn(settings);
         when(searchContext.mapperService()).thenReturn(mapperService);
         when(searchContext.mapperService()).thenReturn(mapperService);
 
 
         InnerHitBuilder leafInnerHits = randomNestedInnerHits();
         InnerHitBuilder leafInnerHits = randomNestedInnerHits();
@@ -340,7 +340,7 @@ public class NestedQueryBuilderTests extends AbstractQueryTestCase<NestedQueryBu
             query1.extractInnerHitBuilders(innerHitBuilders);
             query1.extractInnerHitBuilders(innerHitBuilders);
             assertThat(innerHitBuilders.size(), Matchers.equalTo(1));
             assertThat(innerHitBuilders.size(), Matchers.equalTo(1));
             assertTrue(innerHitBuilders.containsKey(leafInnerHits.getName()));
             assertTrue(innerHitBuilders.containsKey(leafInnerHits.getName()));
-            innerHitBuilders.get(leafInnerHits.getName()).build(searchContext, innerHitsContext);
+            innerHitBuilders.get(leafInnerHits.getName()).validate(searchContext.getQueryShardContext());
         });
         });
         innerHitBuilders.clear();
         innerHitBuilders.clear();
         NestedQueryBuilder query2 = new NestedQueryBuilder("path", new MatchAllQueryBuilder(), ScoreMode.None);
         NestedQueryBuilder query2 = new NestedQueryBuilder("path", new MatchAllQueryBuilder(), ScoreMode.None);

+ 9 - 0
test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java

@@ -31,6 +31,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.ObjectMapper;
+import org.elasticsearch.index.query.InnerHitContextBuilder;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.ParsedQuery;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.query.QueryShardContext;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.IndexShard;
@@ -199,6 +200,14 @@ public class TestSearchContext extends SearchContext {
     public void highlight(SearchContextHighlight highlight) {
     public void highlight(SearchContextHighlight highlight) {
     }
     }
 
 
+    @Override
+    public void innerHits(Map<String, InnerHitContextBuilder> innerHits) {}
+
+    @Override
+    public Map<String, InnerHitContextBuilder> innerHits() {
+        return null;
+    }
+
     @Override
     @Override
     public SuggestionSearchContext suggest() {
     public SuggestionSearchContext suggest() {
         return null;
         return null;