Browse Source

Save memory when parent and child are not on top (#57892)

Reworks the `parent` and `child` aggregation are not at the top level
using the optimization from #55873. Instead of wrapping all
non-top-level `parent` and `child` aggregators we now handle being a
child aggregator in the aggregator, specifically by adding recording
which global ordinals show up in the parent and then checking if they
match the child.
Nik Everett 5 years ago
parent
commit
3adbe5b106

+ 2 - 6
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ChildrenAggregatorFactory.java

@@ -79,12 +79,8 @@ public class ChildrenAggregatorFactory extends ValuesSourceAggregatorFactory {
         }
         WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
         long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
-        if (collectsFromSingleBucket) {
-            return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
-                parentFilter, valuesSource, maxOrd, metadata);
-        } else {
-            return asMultiBucketAggregator(this, searchContext, parent);
-        }
+        return new ParentToChildrenAggregator(name, factories, searchContext, parent, childFilter,
+            parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
     }
 
     @Override

+ 2 - 2
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ChildrenToParentAggregator.java

@@ -40,8 +40,8 @@ public class ChildrenToParentAggregator extends ParentJoinAggregator {
     public ChildrenToParentAggregator(String name, AggregatorFactories factories,
             SearchContext context, Aggregator parent, Query childFilter,
             Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
-            long maxOrd, Map<String, Object> metadata) throws IOException {
-        super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, metadata);
+            long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
+        super(name, factories, context, parent, childFilter, parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
     }
 
     @Override

+ 2 - 6
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ParentAggregatorFactory.java

@@ -80,12 +80,8 @@ public class ParentAggregatorFactory extends ValuesSourceAggregatorFactory {
         }
         WithOrdinals valuesSource = (WithOrdinals) rawValuesSource;
         long maxOrd = valuesSource.globalMaxOrd(searchContext.searcher());
-        if (collectsFromSingleBucket) {
-            return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
-                parentFilter, valuesSource, maxOrd, metadata);
-        } else {
-            return asMultiBucketAggregator(this, searchContext, children);
-        }
+        return new ChildrenToParentAggregator(name, factories, searchContext, children, childFilter,
+            parentFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
     }
 
     @Override

+ 36 - 23
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ParentJoinAggregator.java

@@ -33,12 +33,12 @@ import org.elasticsearch.common.lease.Releasables;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.BitArray;
-import org.elasticsearch.common.util.LongHash;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
 import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
 import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
+import org.elasticsearch.search.aggregations.bucket.terms.LongKeyedBucketOrds;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 import org.elasticsearch.search.internal.SearchContext;
 
@@ -68,6 +68,7 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
                                     Query outFilter,
                                     ValuesSource.Bytes.WithOrdinals valuesSource,
                                     long maxOrd,
+                                    boolean collectsFromSingleBucket,
                                     Map<String, Object> metadata) throws IOException {
         super(name, factories, context, parent, metadata);
 
@@ -81,8 +82,9 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
         this.outFilter = context.searcher().createWeight(context.searcher().rewrite(outFilter), ScoreMode.COMPLETE_NO_SCORES, 1f);
         this.valuesSource = valuesSource;
         boolean singleAggregator = parent == null;
-        collectionStrategy = singleAggregator ?
-                new DenseCollectionStrategy(maxOrd, context.bigArrays()) : new SparseCollectionStrategy(context.bigArrays());
+        collectionStrategy = singleAggregator && collectsFromSingleBucket
+            ? new DenseCollectionStrategy(maxOrd, context.bigArrays())
+            : new SparseCollectionStrategy(context.bigArrays(), collectsFromSingleBucket);
     }
 
     @Override
@@ -95,19 +97,18 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
         final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), inFilter.scorerSupplier(ctx));
         return new LeafBucketCollector() {
             @Override
-            public void collect(int docId, long bucket) throws IOException {
-                assert bucket == 0;
+            public void collect(int docId, long owningBucketOrd) throws IOException {
                 if (parentDocs.get(docId) && globalOrdinals.advanceExact(docId)) {
                     int globalOrdinal = (int) globalOrdinals.nextOrd();
                     assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
-                    collectionStrategy.addGlobalOrdinal(globalOrdinal);
+                    collectionStrategy.add(owningBucketOrd, globalOrdinal);
                 }
             }
         };
     }
 
     @Override
-    protected final void doPostCollection() throws IOException {
+    protected void beforeBuildingBuckets(long[] ordsToCollect) throws IOException {
         IndexReader indexReader = context().searcher().getIndexReader();
         for (LeafReaderContext ctx : indexReader.leaves()) {
             Scorer childDocsScorer = outFilter.scorer(ctx);
@@ -137,11 +138,21 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
                 if (liveDocs != null && liveDocs.get(docId) == false) {
                     continue;
                 }
-                if (globalOrdinals.advanceExact(docId)) {
-                    int globalOrdinal = (int) globalOrdinals.nextOrd();
-                    assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
-                    if (collectionStrategy.existsGlobalOrdinal(globalOrdinal)) {
-                        collectBucket(sub, docId, 0);
+                if (false == globalOrdinals.advanceExact(docId)) {
+                    continue;
+                }
+                int globalOrdinal = (int) globalOrdinals.nextOrd();
+                assert globalOrdinal != -1 && globalOrdinals.nextOrd() == SortedSetDocValues.NO_MORE_ORDS;
+                /*
+                 * Check if we contain every ordinal. It's almost certainly be
+                 * faster to replay all the matching ordinals and filter them down
+                 * to just those listed in ordsToCollect, but we don't have a data
+                 * structure that maps a primitive long to a list of primitive
+                 * longs. 
+                 */
+                for (long owningBucketOrd: ordsToCollect) {
+                    if (collectionStrategy.exists(owningBucketOrd, globalOrdinal)) {
+                        collectBucket(sub, docId, owningBucketOrd);
                     }
                 }
             }
@@ -160,8 +171,8 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
      * {@code ParentJoinAggregator#outFilter} also have the ordinal.
      */
     protected interface CollectionStrategy extends Releasable {
-        void addGlobalOrdinal(int globalOrdinal);
-        boolean existsGlobalOrdinal(int globalOrdinal);
+        void add(long owningBucketOrd, int globalOrdinal);
+        boolean exists(long owningBucketOrd, int globalOrdinal);
     }
 
     /**
@@ -178,12 +189,14 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
         }
 
         @Override
-        public void addGlobalOrdinal(int globalOrdinal) {
+        public void add(long owningBucketOrd, int globalOrdinal) {
+            assert owningBucketOrd == 0;
             ordsBits.set(globalOrdinal);
         }
 
         @Override
-        public boolean existsGlobalOrdinal(int globalOrdinal) {
+        public boolean exists(long owningBucketOrd, int globalOrdinal) {
+            assert owningBucketOrd == 0;
             return ordsBits.get(globalOrdinal);
         }
 
@@ -200,20 +213,20 @@ public abstract class ParentJoinAggregator extends BucketsAggregator implements
      * when only some docs might match.
      */
     protected class SparseCollectionStrategy implements CollectionStrategy {
-        private final LongHash ordsHash;
+        private final LongKeyedBucketOrds ordsHash;
 
-        public SparseCollectionStrategy(BigArrays bigArrays) {
-            ordsHash = new LongHash(1, bigArrays);
+        public SparseCollectionStrategy(BigArrays bigArrays, boolean collectsFromSingleBucket) {
+            ordsHash = LongKeyedBucketOrds.build(bigArrays, collectsFromSingleBucket);
         }
 
         @Override
-        public void addGlobalOrdinal(int globalOrdinal) {
-            ordsHash.add(globalOrdinal);
+        public void add(long owningBucketOrd, int globalOrdinal) {
+            ordsHash.add(owningBucketOrd, globalOrdinal);
         }
 
         @Override
-        public boolean existsGlobalOrdinal(int globalOrdinal) {
-            return ordsHash.find(globalOrdinal) >= 0;
+        public boolean exists(long owningBucketOrd, int globalOrdinal) {
+            return ordsHash.find(owningBucketOrd, globalOrdinal) >= 0;
         }
 
         @Override

+ 2 - 2
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ParentToChildrenAggregator.java

@@ -36,8 +36,8 @@ public class ParentToChildrenAggregator extends ParentJoinAggregator {
     public ParentToChildrenAggregator(String name, AggregatorFactories factories,
             SearchContext context, Aggregator parent, Query childFilter,
             Query parentFilter, ValuesSource.Bytes.WithOrdinals valuesSource,
-            long maxOrd, Map<String, Object> metadata) throws IOException {
-        super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, metadata);
+            long maxOrd, boolean collectsFromSingleBucket, Map<String, Object> metadata) throws IOException {
+        super(name, factories, context, parent, parentFilter, childFilter, valuesSource, maxOrd, collectsFromSingleBucket, metadata);
     }
 
     @Override

+ 3 - 5
modules/parent-join/src/test/java/org/elasticsearch/join/aggregations/ChildrenToParentAggregatorTests.java

@@ -59,7 +59,6 @@ import org.elasticsearch.search.aggregations.bucket.terms.LongTerms;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.InternalMin;
 import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
-import org.elasticsearch.search.aggregations.support.ValueType;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -313,8 +312,7 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
             throws IOException {
 
         ParentAggregationBuilder aggregationBuilder = new ParentAggregationBuilder("_name", CHILD_TYPE);
-        aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG)
-            .field("number"));
+        aggregationBuilder.subAggregation(new TermsAggregationBuilder("value_terms").field("number"));
 
         MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
         fieldType.setName("number");
@@ -326,9 +324,9 @@ public class ChildrenToParentAggregatorTests extends AggregatorTestCase {
     private void testCaseTermsParentTerms(Query query, IndexSearcher indexSearcher, Consumer<LongTerms> verify)
             throws IOException {
         AggregationBuilder aggregationBuilder =
-            new TermsAggregationBuilder("subvalue_terms").userValueTypeHint(ValueType.LONG).field("subNumber").
+            new TermsAggregationBuilder("subvalue_terms").field("subNumber").
                 subAggregation(new ParentAggregationBuilder("to_parent", CHILD_TYPE).
-                    subAggregation(new TermsAggregationBuilder("value_terms").userValueTypeHint(ValueType.LONG).field("number")));
+                    subAggregation(new TermsAggregationBuilder("value_terms").field("number")));
 
         MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.LONG);
         fieldType.setName("number");

+ 64 - 2
modules/parent-join/src/test/java/org/elasticsearch/join/aggregations/ParentToChildrenAggregatorTests.java

@@ -22,6 +22,7 @@ package org.elasticsearch.join.aggregations;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.SortedDocValuesField;
 import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.document.SortedSetDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
@@ -52,7 +53,10 @@ import org.elasticsearch.join.ParentJoinPlugin;
 import org.elasticsearch.join.mapper.MetaJoinFieldMapper;
 import org.elasticsearch.join.mapper.ParentJoinFieldMapper;
 import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
+import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.InternalMin;
 import org.elasticsearch.search.aggregations.metrics.MinAggregationBuilder;
 
@@ -64,6 +68,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Consumer;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -124,12 +129,68 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
         directory.close();
     }
 
+    public void testParentChildAsSubAgg() throws IOException {
+        try (Directory directory = newDirectory()) {
+            RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
+
+            final Map<String, Tuple<Integer, Integer>> expectedParentChildRelations = setupIndex(indexWriter);
+            indexWriter.close();
+
+            try (
+                IndexReader indexReader = ElasticsearchDirectoryReader.wrap(
+                    DirectoryReader.open(directory),
+                    new ShardId(new Index("foo", "_na_"), 1)
+                )
+            ) {
+                IndexSearcher indexSearcher = newSearcher(indexReader, false, true);
+
+                AggregationBuilder request = new TermsAggregationBuilder("t").field("kwd")
+                    .subAggregation(
+                        new ChildrenAggregationBuilder("children", CHILD_TYPE).subAggregation(
+                            new MinAggregationBuilder("min").field("number")
+                        )
+                    );
+
+                long expectedEvenChildCount = 0;
+                double expectedEvenMin = Double.MAX_VALUE;
+                long expectedOddChildCount = 0;
+                double expectedOddMin = Double.MAX_VALUE;
+                for (Map.Entry<String, Tuple<Integer, Integer>> e : expectedParentChildRelations.entrySet()) {
+                    if (Integer.valueOf(e.getKey().substring("parent".length())) % 2 == 0) {
+                        expectedEvenChildCount += e.getValue().v1();
+                        expectedEvenMin = Math.min(expectedEvenMin, e.getValue().v2());
+                    } else {
+                        expectedOddChildCount += e.getValue().v1();
+                        expectedOddMin = Math.min(expectedOddMin, e.getValue().v2());
+                    }
+                }
+                StringTerms result = search(indexSearcher, new MatchAllDocsQuery(), request, longField("number"), keywordField("kwd"));
+
+                StringTerms.Bucket evenBucket = result.getBucketByKey("even");
+                InternalChildren evenChildren = evenBucket.getAggregations().get("children");
+                InternalMin evenMin = evenChildren.getAggregations().get("min");
+                assertThat(evenChildren.getDocCount(), equalTo(expectedEvenChildCount));
+                assertThat(evenMin.getValue(), equalTo(expectedEvenMin));
+
+                if (expectedOddChildCount > 0) {
+                    StringTerms.Bucket oddBucket = result.getBucketByKey("odd");
+                    InternalChildren oddChildren = oddBucket.getAggregations().get("children");
+                    InternalMin oddMin = oddChildren.getAggregations().get("min");
+                    assertThat(oddChildren.getDocCount(), equalTo(expectedOddChildCount));
+                    assertThat(oddMin.getValue(), equalTo(expectedOddMin));
+                } else {
+                    assertNull(result.getBucketByKey("odd"));
+                }
+            }
+        }
+    }
+
     private static Map<String, Tuple<Integer, Integer>> setupIndex(RandomIndexWriter iw) throws IOException {
         Map<String, Tuple<Integer, Integer>> expectedValues = new HashMap<>();
         int numParents = randomIntBetween(1, 10);
         for (int i = 0; i < numParents; i++) {
             String parent = "parent" + i;
-            iw.addDocument(createParentDocument(parent));
+            iw.addDocument(createParentDocument(parent, i % 2 == 0 ? "even" : "odd"));
             int numChildren = randomIntBetween(1, 10);
             int minValue = Integer.MAX_VALUE;
             for (int c = 0; c < numChildren; c++) {
@@ -142,9 +203,10 @@ public class ParentToChildrenAggregatorTests extends AggregatorTestCase {
         return expectedValues;
     }
 
-    private static List<Field> createParentDocument(String id) {
+    private static List<Field> createParentDocument(String id, String kwd) {
         return Arrays.asList(
                 new StringField(IdFieldMapper.NAME, Uid.encodeId(id), Field.Store.NO),
+                new SortedSetDocValuesField("kwd", new BytesRef(kwd)),
                 new StringField("join_field", PARENT_TYPE, Field.Store.NO),
                 createJoinField(PARENT_TYPE, id)
         );