瀏覽代碼

Added support for shard_size in terms agg

Closes #4242
uboness 12 年之前
父節點
當前提交
fda6ca4869

+ 7 - 3
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTerms.java

@@ -111,11 +111,13 @@ public class DoubleTerms extends InternalTerms {
     public InternalTerms reduce(ReduceContext reduceContext) {
         List<InternalAggregation> aggregations = reduceContext.aggregations();
         if (aggregations.size() == 1) {
-            return (InternalTerms) aggregations.get(0);
+            InternalTerms terms = (InternalTerms) aggregations.get(0);
+            terms.trimExcessEntries();
+            return terms;
         }
         InternalTerms reduced = null;
 
-        Recycler.V<DoubleObjectOpenHashMap<List<Bucket>>> buckets = reduceContext.cacheRecycler().doubleObjectMap(-1);
+        Recycler.V<DoubleObjectOpenHashMap<List<Bucket>>> buckets = null;
         for (InternalAggregation aggregation : aggregations) {
             InternalTerms terms = (InternalTerms) aggregation;
             if (terms instanceof UnmappedTerms) {
@@ -124,8 +126,10 @@ public class DoubleTerms extends InternalTerms {
             if (reduced == null) {
                 reduced = terms;
             }
+            if (buckets == null) {
+                buckets = reduceContext.cacheRecycler().doubleObjectMap(terms.buckets.size());
+            }
             for (Terms.Bucket bucket : terms.buckets) {
-
                 List<Bucket> existingBuckets = buckets.v().get(((Bucket) bucket).term);
                 if (existingBuckets == null) {
                     existingBuckets = new ArrayList<Bucket>(aggregations.size());

+ 4 - 2
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/DoubleTermsAggregator.java

@@ -41,15 +41,17 @@ public class DoubleTermsAggregator extends BucketsAggregator {
 
     private final InternalOrder order;
     private final int requiredSize;
+    private final int shardSize;
     private final NumericValuesSource valuesSource;
     private final LongHash bucketOrds;
 
     public DoubleTermsAggregator(String name, AggregatorFactories factories, NumericValuesSource valuesSource,
-                               InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) {
+                               InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) {
         super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent);
         this.valuesSource = valuesSource;
         this.order = order;
         this.requiredSize = requiredSize;
+        this.shardSize = shardSize;
         bucketOrds = new LongHash(INITIAL_CAPACITY);
     }
 
@@ -89,7 +91,7 @@ public class DoubleTermsAggregator extends BucketsAggregator {
     @Override
     public DoubleTerms buildAggregation(long owningBucketOrdinal) {
         assert owningBucketOrdinal == 0;
-        final int size = (int) Math.min(bucketOrds.size(), requiredSize);
+        final int size = (int) Math.min(bucketOrds.size(), shardSize);
 
         BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator());
         OrdinalBucket spare = null;

+ 26 - 2
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java

@@ -129,13 +129,15 @@ public abstract class InternalTerms extends InternalAggregation implements Terms
     public InternalTerms reduce(ReduceContext reduceContext) {
         List<InternalAggregation> aggregations = reduceContext.aggregations();
         if (aggregations.size() == 1) {
-            return (InternalTerms) aggregations.get(0);
+            InternalTerms terms = (InternalTerms) aggregations.get(0);
+            terms.trimExcessEntries();
+            return terms;
         }
         InternalTerms reduced = null;
 
         // TODO: would it be better to use a hppc map and then directly work on the backing array instead of using a PQ?
 
-        Map<Text, List<InternalTerms.Bucket>> buckets = new HashMap<Text, List<InternalTerms.Bucket>>(requiredSize);
+        Map<Text, List<InternalTerms.Bucket>> buckets = null;
         for (InternalAggregation aggregation : aggregations) {
             InternalTerms terms = (InternalTerms) aggregation;
             if (terms instanceof UnmappedTerms) {
@@ -144,6 +146,9 @@ public abstract class InternalTerms extends InternalAggregation implements Terms
             if (reduced == null) {
                 reduced = terms;
             }
+            if (buckets == null) {
+                buckets = new HashMap<Text, List<Bucket>>(terms.buckets.size());
+            }
             for (Bucket bucket : terms.buckets) {
                 List<Bucket> existingBuckets = buckets.get(bucket.getKey());
                 if (existingBuckets == null) {
@@ -173,4 +178,23 @@ public abstract class InternalTerms extends InternalAggregation implements Terms
         return reduced;
     }
 
+    protected void trimExcessEntries() { // keeps at most requiredSize buckets and drops the rest
+        if (requiredSize >= buckets.size()) {
+            return; // already within the limit, nothing to drop
+        }
+        // Lists are cut with a subList view over the first requiredSize entries. NOTE(review): raw List cast; the view keeps the full backing list reachable - confirm that is acceptable.
+        if (buckets instanceof List) {
+            buckets = ((List) buckets).subList(0, requiredSize);
+            return;
+        }
+        // Any other collection: walk it in iteration order and remove every entry after the first requiredSize.
+        int i = 0;
+        for (Iterator<Bucket> iter  = buckets.iterator(); iter.hasNext();) {
+            iter.next();
+            if (i++ >= requiredSize) {
+                iter.remove();
+            }
+        }
+    }
+
 }

+ 7 - 2
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTerms.java

@@ -109,11 +109,13 @@ public class LongTerms extends InternalTerms {
     public InternalTerms reduce(ReduceContext reduceContext) {
         List<InternalAggregation> aggregations = reduceContext.aggregations();
         if (aggregations.size() == 1) {
-            return (InternalTerms) aggregations.get(0);
+            InternalTerms terms = (InternalTerms) aggregations.get(0);
+            terms.trimExcessEntries();
+            return terms;
         }
         InternalTerms reduced = null;
 
-        Recycler.V<LongObjectOpenHashMap<List<Bucket>>> buckets = reduceContext.cacheRecycler().longObjectMap(-1);
+        Recycler.V<LongObjectOpenHashMap<List<Bucket>>> buckets = null;
         for (InternalAggregation aggregation : aggregations) {
             InternalTerms terms = (InternalTerms) aggregation;
             if (terms instanceof UnmappedTerms) {
@@ -122,6 +124,9 @@ public class LongTerms extends InternalTerms {
             if (reduced == null) {
                 reduced = terms;
             }
+            if (buckets == null) {
+                buckets = reduceContext.cacheRecycler().longObjectMap(terms.buckets.size());
+            }
             for (Terms.Bucket bucket : terms.buckets) {
                 List<Bucket> existingBuckets = buckets.v().get(((Bucket) bucket).term);
                 if (existingBuckets == null) {

+ 4 - 2
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/LongTermsAggregator.java

@@ -41,15 +41,17 @@ public class LongTermsAggregator extends BucketsAggregator {
 
     private final InternalOrder order;
     private final int requiredSize;
+    private final int shardSize;
     private final NumericValuesSource valuesSource;
     private final LongHash bucketOrds;
 
     public LongTermsAggregator(String name, AggregatorFactories factories, NumericValuesSource valuesSource,
-                               InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) {
+                               InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) {
         super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent);
         this.valuesSource = valuesSource;
         this.order = order;
         this.requiredSize = requiredSize;
+        this.shardSize = shardSize;
         bucketOrds = new LongHash(INITIAL_CAPACITY);
     }
 
@@ -88,7 +90,7 @@ public class LongTermsAggregator extends BucketsAggregator {
     @Override
     public LongTerms buildAggregation(long owningBucketOrdinal) {
         assert owningBucketOrdinal == 0;
-        final int size = (int) Math.min(bucketOrds.size(), requiredSize);
+        final int size = (int) Math.min(bucketOrds.size(), shardSize);
 
         BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator());
         OrdinalBucket spare = null;

+ 4 - 2
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/StringTermsAggregator.java

@@ -43,15 +43,17 @@ public class StringTermsAggregator extends BucketsAggregator {
     private final ValuesSource valuesSource;
     private final InternalOrder order;
     private final int requiredSize;
+    private final int shardSize;
     private final BytesRefHash bucketOrds;
 
     public StringTermsAggregator(String name, AggregatorFactories factories, ValuesSource valuesSource,
-                                 InternalOrder order, int requiredSize, AggregationContext aggregationContext, Aggregator parent) {
+                                 InternalOrder order, int requiredSize, int shardSize, AggregationContext aggregationContext, Aggregator parent) {
 
         super(name, BucketAggregationMode.PER_BUCKET, factories, INITIAL_CAPACITY, aggregationContext, parent);
         this.valuesSource = valuesSource;
         this.order = order;
         this.requiredSize = requiredSize;
+        this.shardSize = shardSize;
         bucketOrds = new BytesRefHash();
     }
 
@@ -91,7 +93,7 @@ public class StringTermsAggregator extends BucketsAggregator {
     @Override
     public StringTerms buildAggregation(long owningBucketOrdinal) {
         assert owningBucketOrdinal == 0;
-        final int size = Math.min(bucketOrds.size(), requiredSize);
+        final int size = Math.min(bucketOrds.size(), shardSize);
 
         BucketPriorityQueue ordered = new BucketPriorityQueue(size, order.comparator());
         OrdinalBucket spare = null;

+ 7 - 5
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java

@@ -22,11 +22,11 @@ package org.elasticsearch.search.aggregations.bucket.terms;
 import org.elasticsearch.search.aggregations.AggregationExecutionException;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.support.AggregationContext;
+import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
 import org.elasticsearch.search.aggregations.support.bytes.BytesValuesSource;
 import org.elasticsearch.search.aggregations.support.numeric.NumericValuesSource;
-import org.elasticsearch.search.aggregations.support.ValueSourceAggregatorFactory;
 
 /**
  *
@@ -35,11 +35,13 @@ public class TermsAggregatorFactory extends ValueSourceAggregatorFactory {
 
     private final InternalOrder order;
     private final int requiredSize;
+    private final int shardSize;
 
-    public TermsAggregatorFactory(String name, ValuesSourceConfig valueSourceConfig, InternalOrder order, int requiredSize) {
+    public TermsAggregatorFactory(String name, ValuesSourceConfig valueSourceConfig, InternalOrder order, int requiredSize, int shardSize) {
         super(name, StringTerms.TYPE.name(), valueSourceConfig);
         this.order = order;
         this.requiredSize = requiredSize;
+        this.shardSize = shardSize;
     }
 
     @Override
@@ -50,14 +52,14 @@ public class TermsAggregatorFactory extends ValueSourceAggregatorFactory {
     @Override
     protected Aggregator create(ValuesSource valuesSource, long expectedBucketsCount, AggregationContext aggregationContext, Aggregator parent) {
         if (valuesSource instanceof BytesValuesSource) {
-            return new StringTermsAggregator(name, factories, valuesSource, order, requiredSize, aggregationContext, parent);
+            return new StringTermsAggregator(name, factories, valuesSource, order, requiredSize, shardSize, aggregationContext, parent);
         }
 
         if (valuesSource instanceof NumericValuesSource) {
             if (((NumericValuesSource) valuesSource).isFloatingPoint()) {
-                return new DoubleTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, aggregationContext, parent);
+                return new DoubleTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, shardSize, aggregationContext, parent);
             }
-            return new LongTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, aggregationContext, parent);
+            return new LongTermsAggregator(name, factories, (NumericValuesSource) valuesSource, order, requiredSize, shardSize, aggregationContext, parent);
         }
 
         throw new AggregationExecutionException("terms aggregation cannot be applied to field [" + valuesSourceConfig.fieldContext().field() +

+ 9 - 0
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsBuilder.java

@@ -12,6 +12,7 @@ import java.util.Locale;
 public class TermsBuilder extends ValuesSourceAggregationBuilder<TermsBuilder> {
 
     private int size = -1;
+    private int shardSize = -1;
     private Terms.ValueType valueType;
     private Terms.Order order;
 
@@ -24,6 +25,11 @@ public class TermsBuilder extends ValuesSourceAggregationBuilder<TermsBuilder> {
         return this;
     }
 
+    public TermsBuilder shardSize(int shardSize) { // stores the shard_size to request; negative values (the default is -1) are left out of the rendered request
+        this.shardSize = shardSize;
+        return this; // returns this builder so calls can be chained
+    }
+
     public TermsBuilder valueType(Terms.ValueType valueType) {
         this.valueType = valueType;
         return this;
@@ -39,6 +45,9 @@ public class TermsBuilder extends ValuesSourceAggregationBuilder<TermsBuilder> {
         if (size >=0) {
             builder.field("size", size);
         }
+        if (shardSize >= 0) {
+            builder.field("shard_size", shardSize);
+        }
         if (valueType != null) {
             builder.field("value_type", valueType.name().toLowerCase(Locale.ROOT));
         }

+ 12 - 4
src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsParser.java

@@ -27,6 +27,7 @@ import org.elasticsearch.index.mapper.core.DateFieldMapper;
 import org.elasticsearch.index.mapper.ip.IpFieldMapper;
 import org.elasticsearch.script.SearchScript;
 import org.elasticsearch.search.aggregations.Aggregator;
+import org.elasticsearch.search.aggregations.AggregatorFactory;
 import org.elasticsearch.search.aggregations.support.FieldContext;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
@@ -34,7 +35,6 @@ import org.elasticsearch.search.aggregations.support.bytes.BytesValuesSource;
 import org.elasticsearch.search.aggregations.support.numeric.NumericValuesSource;
 import org.elasticsearch.search.aggregations.support.numeric.ValueFormatter;
 import org.elasticsearch.search.aggregations.support.numeric.ValueParser;
-import org.elasticsearch.search.aggregations.AggregatorFactory;
 import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
@@ -62,6 +62,7 @@ public class TermsParser implements Aggregator.Parser {
         Map<String, Object> scriptParams = null;
         Terms.ValueType valueType = null;
         int requiredSize = 10;
+        int shardSize = -1;
         String orderKey = "_count";
         boolean orderAsc = false;
         String format = null;
@@ -92,6 +93,8 @@ public class TermsParser implements Aggregator.Parser {
             } else if (token == XContentParser.Token.VALUE_NUMBER) {
                 if ("size".equals(currentFieldName)) {
                     requiredSize = parser.intValue();
+                } else if ("shard_size".equals(currentFieldName) || "shardSize".equals(currentFieldName)) {
+                    shardSize = parser.intValue();
                 }
             } else if (token == XContentParser.Token.START_OBJECT) {
                 if ("params".equals(currentFieldName)) {
@@ -110,6 +113,11 @@ public class TermsParser implements Aggregator.Parser {
             }
         }
 
+        // shard_size cannot be smaller than size, as we need to fetch at least <size> entries from every shard in order to return <size>
+        if (shardSize < requiredSize) {
+            shardSize = requiredSize;
+        }
+
         InternalOrder order = resolveOrder(orderKey, orderAsc);
         SearchScript searchScript = null;
         if (script != null) {
@@ -131,14 +139,14 @@ public class TermsParser implements Aggregator.Parser {
             if (!assumeUnique) {
                 config.ensureUnique(true);
             }
-            return new TermsAggregatorFactory(aggregationName, config, order, requiredSize);
+            return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize);
         }
 
         FieldMapper<?> mapper = context.smartNameFieldMapper(field);
         if (mapper == null) {
             ValuesSourceConfig<?> config = new ValuesSourceConfig<BytesValuesSource>(BytesValuesSource.class);
             config.unmapped(true);
-            return new TermsAggregatorFactory(aggregationName, config, order, requiredSize);
+            return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize);
         }
         IndexFieldData<?> indexFieldData = context.fieldData().getForField(mapper);
 
@@ -180,7 +188,7 @@ public class TermsParser implements Aggregator.Parser {
             config.ensureUnique(true);
         }
 
-        return new TermsAggregatorFactory(aggregationName, config, order, requiredSize);
+        return new TermsAggregatorFactory(aggregationName, config, order, requiredSize, shardSize);
     }
 
     static InternalOrder resolveOrder(String key, boolean asc) {

+ 1 - 1
src/main/java/org/elasticsearch/search/facet/terms/TermsFacetParser.java

@@ -130,7 +130,7 @@ public class TermsFacetParser extends AbstractComponent implements FacetParser {
                     script = parser.text();
                 } else if ("size".equals(currentFieldName)) {
                     size = parser.intValue();
-                } else if ("shard_size".equals(currentFieldName)) {
+                } else if ("shard_size".equals(currentFieldName) || "shardSize".equals(currentFieldName)) {
                     shardSize = parser.intValue();
                 } else if ("all_terms".equals(currentFieldName) || "allTerms".equals(currentFieldName)) {
                     allTerms = parser.booleanValue();

+ 362 - 0
src/test/java/org/elasticsearch/search/aggregations/bucket/ShardSizeTermsTests.java

@@ -0,0 +1,362 @@
+/*
+ * Licensed to ElasticSearch and Shay Banon under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. ElasticSearch licenses this
+ * file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.bucket;
+
+import com.google.common.collect.ImmutableMap;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.common.settings.ImmutableSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.test.ElasticsearchIntegrationTest;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
+import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
+import static org.elasticsearch.test.ElasticsearchIntegrationTest.ClusterScope;
+import static org.elasticsearch.test.ElasticsearchIntegrationTest.Scope;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+/**
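+ * Integration tests checking how the terms aggregation's shard_size parameter affects the reduced bucket counts for string, long and double keys.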
+ */
+@ClusterScope(scope = Scope.TEST)
+public class ShardSizeTermsTests extends ElasticsearchIntegrationTest {
+
+    /**
+     * To properly test the effect/functionality of shard_size, we need to force having 2 shards and also
+     * control the routing such that certain documents will end up on each shard. Using "djb" routing hash + ignoring the
+     * doc type when hashing will ensure that docs with routing value "1" will end up in a different shard than docs with
+     * routing value "2".
+     */
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal) {
+        return ImmutableSettings.builder()
+                .put("index.number_of_shards", 2) // exactly two shards, one per routing value used by indexData
+                .put("index.number_of_replicas", 0)
+                .put("cluster.routing.operation.hash.type", "djb") // routing hash named in the Javadoc above
+                .put("cluster.routing.operation.use_type", "false") // do not include the doc type when hashing
+                .build();
+    }
+
+    @Test
+    public void noShardSize_string() throws Exception {
+        // Baseline on a not_analyzed string field: no shard_size is set, so the reduced count for key "2" comes out as 4.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=string,index=not_analyzed")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms  terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3));
+        Map<String, Long> expected = ImmutableMap.<String, Long>builder()
+                .put("1", 8l)
+                .put("3", 8l)
+                .put("2", 4l) // shard 2 does not report key "2" in its top 3, so only shard 1's 4 docs are counted (see the table in indexData)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // every returned bucket must carry the count expected for its key
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string())));
+        }
+    }
+
+    @Test
+    public void withShardSize_string() throws Exception {
+        // Same data and query as the string baseline, but with shard_size = 5 the count for key "2" becomes 5.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=string,index=not_analyzed")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param)
+        Map<String, Long> expected = ImmutableMap.<String, Long>builder()
+                .put("1", 8l)
+                .put("3", 8l)
+                .put("2", 5l) // <-- count is now correct: both shards report key "2" (4 + 1)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // every returned bucket must carry the count expected for its key
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string())));
+        }
+    }
+
+    @Test
+    public void withShardSize_string_singleShard() throws Exception {
+        // The search is routed with "1", so the expected counts are those of the routing-"1" docs only; the result must still be cut to 'size' entries.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=string,index=not_analyzed")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param)
+        Map<String, Long> expected = ImmutableMap.<String, Long>builder()
+                .put("1", 5l)
+                .put("2", 4l)
+                .put("3", 3l) // counts of the docs indexed with routing "1" only
+                .build();
+        for (Terms.Bucket bucket: buckets) { // every returned bucket must carry the count expected for its key
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKey().string())));
+        }
+    }
+
+    @Test
+    public void noShardSize_long() throws Exception {
+        // Baseline on a long field: no shard_size is set, so the reduced count for key 2 comes out as 4.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=long")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3));
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 8l)
+                .put(3, 8l)
+                .put(2, 4l) // shard 2 does not report key 2 in its top 3, so only shard 1's 4 docs are counted (see the table in indexData)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    @Test
+    public void withShardSize_long() throws Exception {
+        // Same data and query as the long baseline, but with shard_size = 5 the count for key 2 becomes 5.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=long")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param)
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 8l)
+                .put(3, 8l)
+                .put(2, 5l) // <-- count is now correct: both shards report key 2 (4 + 1)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    @Test
+    public void withShardSize_long_singleShard() throws Exception {
+        // The search is routed with "1", so the expected counts are those of the routing-"1" docs only; the result must still be cut to 'size' entries.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=long")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // we still only return 3 entries (based on the 'size' param)
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 5l)
+                .put(2, 4l)
+                .put(3, 3l) // counts of the docs indexed with routing "1" only
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    @Test
+    public void noShardSize_double() throws Exception {
+        // Baseline on a double field: no shard_size is set, so the reduced count for key 2 comes out as 4.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=double")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3));
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 8l)
+                .put(3, 8l)
+                .put(2, 4l) // shard 2 does not report key 2 in its top 3, so only shard 1's 4 docs are counted (see the table in indexData)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    @Test
+    public void withShardSize_double() throws Exception {
+        // Same data and query as the double baseline, but with shard_size = 5 the count for key 2 becomes 5.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=double")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // still cut to 'size' entries even though each shard may return up to 5
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 8l)
+                .put(3, 8l)
+                .put(2, 5l) // <-- count is now correct: both shards report key 2 (4 + 1)
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    @Test
+    public void withShardSize_double_singleShard() throws Exception {
+        // The search is routed with "1", so the expected counts are those of the routing-"1" docs only; the result must still be cut to 'size' entries.
+        client().admin().indices().prepareCreate("idx")
+                .addMapping("type", "key", "type=double")
+                .execute().actionGet();
+
+        indexData();
+
+        SearchResponse response = client().prepareSearch("idx").setTypes("type").setRouting("1")
+                .setQuery(matchAllQuery())
+                .addAggregation(terms("keys").field("key").size(3).shardSize(5).order(Terms.Order.COUNT_DESC))
+                .execute().actionGet();
+
+        Terms terms = response.getAggregations().get("keys");
+        Collection<Terms.Bucket> buckets = terms.buckets();
+        assertThat(buckets.size(), equalTo(3)); // still cut to 'size' entries even though shard_size is 5
+        Map<Integer, Long> expected = ImmutableMap.<Integer, Long>builder()
+                .put(1, 5l)
+                .put(2, 4l)
+                .put(3, 3l) // counts of the docs indexed with routing "1" only
+                .build();
+        for (Terms.Bucket bucket : buckets) { // buckets are looked up by their numeric key converted to int
+            assertThat(bucket.getDocCount(), equalTo(expected.get(bucket.getKeyAsNumber().intValue())));
+        }
+    }
+
+    private void indexData() throws Exception {
+        // Indexes the fixture shared by all tests: keys "1".."5" with the per-routing document counts shown in the table below.
+        /*
+
+
+        ||          ||           size = 3, shard_size = 5               ||           shard_size = size = 3               ||
+        ||==========||==================================================||===============================================||
+        || shard 1: ||  "1" - 5 | "2" - 4 | "3" - 3 | "4" - 2 | "5" - 1 || "1" - 5 | "3" - 3 | "2" - 4                   ||
+        ||----------||--------------------------------------------------||-----------------------------------------------||
+        || shard 2: ||  "1" - 3 | "2" - 1 | "3" - 5 | "4" - 2 | "5" - 1 || "1" - 3 | "3" - 5 | "4" - 2                   ||
+        ||----------||--------------------------------------------------||-----------------------------------------------||
+        || reduced: ||  "1" - 8 | "2" - 5 | "3" - 8 | "4" - 4 | "5" - 2 ||                                               ||
+        ||          ||                                                  || "1" - 8, "3" - 8, "2" - 4    <= WRONG         ||
+        ||          ||  "1" - 8 | "3" - 8 | "2" - 5     <= CORRECT      ||                                               ||
+
+
+        */
+
+        List<IndexRequestBuilder> indexOps = new ArrayList<IndexRequestBuilder>();
+
+        indexDoc("1", "1", 5, indexOps);
+        indexDoc("1", "2", 4, indexOps);
+        indexDoc("1", "3", 3, indexOps);
+        indexDoc("1", "4", 2, indexOps);
+        indexDoc("1", "5", 1, indexOps);
+
+        // total docs in shard "1" = 15
+
+        indexDoc("2", "1", 3, indexOps);
+        indexDoc("2", "2", 1, indexOps);
+        indexDoc("2", "3", 5, indexOps);
+        indexDoc("2", "4", 2, indexOps);
+        indexDoc("2", "5", 1, indexOps);
+
+        // total docs in shard "2"  = 12
+
+        indexRandom(true, indexOps); // hands the queued requests to the inherited helper; presumably the flag forces a refresh so the docs are searchable - confirm
+
+        long totalOnOne = client().prepareSearch("idx").setTypes("type").setRouting("1").setQuery(matchAllQuery()).execute().actionGet().getHits().getTotalHits(); // sanity check: a search routed with "1" must see exactly the 15 docs indexed with that routing
+        assertThat(totalOnOne, is(15l));
+        long totalOnTwo = client().prepareSearch("idx").setTypes("type").setRouting("2").setQuery(matchAllQuery()).execute().actionGet().getHits().getTotalHits(); // and a search routed with "2" must see exactly the other 12
+        assertThat(totalOnTwo, is(12l));
+    }
+
+    private void indexDoc(String shard, String key, int times, List<IndexRequestBuilder> indexOps) throws Exception { // queues 'times' create requests into indexOps; nothing is executed here
+        for (int i = 0; i < times; i++) {
+            indexOps.add(client().prepareIndex("idx", "type").setRouting(shard).setCreate(true).setSource(jsonBuilder() // 'shard' is used as the routing value
+                    .startObject()
+                    .field("key", key) // the only field of the document
+                    .endObject()));
+        }
+    }
+
+}