1
0
Эх сурвалжийг харах

Aggregations Refactor: Refactor Filters Aggregation

Colin Goodheart-Smithe 9 жил өмнө
parent
commit
6df27fe0e0

+ 187 - 10
core/src/main/java/org/elasticsearch/search/aggregations/bucket/filters/FiltersAggregator.java

@@ -23,7 +23,14 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.lucene.Lucene;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -39,21 +46,72 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  *
  */
 public class FiltersAggregator extends BucketsAggregator {
 
-    static class KeyedFilter {
+    public static final ParseField FILTERS_FIELD = new ParseField("filters");
+    public static final ParseField OTHER_BUCKET_FIELD = new ParseField("other_bucket");
+    public static final ParseField OTHER_BUCKET_KEY_FIELD = new ParseField("other_bucket_key");
 
-        final String key;
-        final Query filter;
+    public static class KeyedFilter implements Writeable<KeyedFilter>, ToXContent {
 
-        KeyedFilter(String key, Query filter) {
+        static final KeyedFilter PROTOTYPE = new KeyedFilter(null, null);
+        private final String key;
+        private final QueryBuilder<?> filter;
+
+        public KeyedFilter(String key, QueryBuilder<?> filter) {
             this.key = key;
             this.filter = filter;
         }
+
+        public String key() {
+            return key;
+        }
+
+        public QueryBuilder<?> filter() {
+            return filter;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(key, filter);
+            return builder;
+        }
+
+        @Override
+        public KeyedFilter readFrom(StreamInput in) throws IOException {
+            String key = in.readString();
+            QueryBuilder<?> filter = in.readQuery();
+            return new KeyedFilter(key, filter);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(key);
+            out.writeQuery(filter);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(key, filter);
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (obj == null) {
+                return false;
+            }
+            if (getClass() != obj.getClass()) {
+                return false;
+            }
+            KeyedFilter other = (KeyedFilter) obj;
+            return Objects.equals(key, other.key)
+                    && Objects.equals(filter, other.filter);
+        }
     }
 
     private final String[] keys;
@@ -81,7 +139,8 @@ public class FiltersAggregator extends BucketsAggregator {
         for (int i = 0; i < filters.size(); ++i) {
             KeyedFilter keyedFilter = filters.get(i);
             this.keys[i] = keyedFilter.key;
-            this.filters[i] = aggregationContext.searchContext().searcher().createNormalizedWeight(keyedFilter.filter, false);
+            Query filter = keyedFilter.filter.toFilter(context.searchContext().indexShard().getQueryShardContext());
+            this.filters[i] = aggregationContext.searchContext().searcher().createNormalizedWeight(filter, false);
         }
     }
 
@@ -146,20 +205,138 @@ public class FiltersAggregator extends BucketsAggregator {
     public static class Factory extends AggregatorFactory {
 
         private final List<KeyedFilter> filters;
-        private boolean keyed;
-        private String otherBucketKey;
+        private final boolean keyed;
+        private boolean otherBucket = false;
+        private String otherBucketKey = "_other_";
 
-        public Factory(String name, List<KeyedFilter> filters, boolean keyed, String otherBucketKey) {
+        public Factory(String name, List<KeyedFilter> filters) {
             super(name, InternalFilters.TYPE);
             this.filters = filters;
-            this.keyed = keyed;
+            this.keyed = true;
+        }
+
+        public Factory(String name, QueryBuilder<?>... filters) {
+            super(name, InternalFilters.TYPE);
+            List<KeyedFilter> keyedFilters = new ArrayList<>(filters.length);
+            for (int i = 0; i < filters.length; i++) {
+                keyedFilters.add(new KeyedFilter(String.valueOf(i), filters[i]));
+            }
+            this.filters = keyedFilters;
+            this.keyed = false;
+        }
+
+        /**
+         * Set whether to include a bucket for documents not matching any filter
+         */
+        public void otherBucket(boolean otherBucket) {
+            this.otherBucket = otherBucket;
+        }
+
+        /**
+         * Get whether to include a bucket for documents not matching any filter
+         */
+        public boolean otherBucket() {
+            return otherBucket;
+        }
+
+        /**
+         * Set the key to use for the bucket for documents not matching any
+         * filter.
+         */
+        public void otherBucketKey(String otherBucketKey) {
             this.otherBucketKey = otherBucketKey;
         }
 
+        /**
+         * Get the key to use for the bucket for documents not matching any
+         * filter.
+         */
+        public String otherBucketKey() {
+            return otherBucketKey;
+        }
+
         @Override
         public Aggregator createInternal(AggregationContext context, Aggregator parent, boolean collectsFromSingleBucket,
                 List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
-            return new FiltersAggregator(name, factories, filters, keyed, otherBucketKey, context, parent, pipelineAggregators, metaData);
+            return new FiltersAggregator(name, factories, filters, keyed, otherBucket ? otherBucketKey : null, context, parent,
+                    pipelineAggregators, metaData);
+        }
+
+        @Override
+        protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            if (keyed) {
+                builder.startObject(FILTERS_FIELD.getPreferredName());
+                for (KeyedFilter keyedFilter : filters) {
+                    builder.field(keyedFilter.key(), keyedFilter.filter());
+                }
+                builder.endObject();
+            } else {
+                builder.startArray(FILTERS_FIELD.getPreferredName());
+                for (KeyedFilter keyedFilter : filters) {
+                    builder.value(keyedFilter.filter());
+                }
+                builder.endArray();
+            }
+            builder.field(OTHER_BUCKET_FIELD.getPreferredName(), otherBucket);
+            builder.field(OTHER_BUCKET_KEY_FIELD.getPreferredName(), otherBucketKey);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        protected AggregatorFactory doReadFrom(String name, StreamInput in) throws IOException {
+            Factory factory;
+            if (in.readBoolean()) {
+                int size = in.readVInt();
+                List<KeyedFilter> filters = new ArrayList<>(size);
+                for (int i = 0; i < size; i++) {
+                    filters.add(KeyedFilter.PROTOTYPE.readFrom(in));
+                }
+                factory = new Factory(name, filters);
+            } else {
+                int size = in.readVInt();
+                QueryBuilder<?>[] filters = new QueryBuilder<?>[size];
+                for (int i = 0; i < size; i++) {
+                    filters[i] = in.readQuery();
+                }
+                factory = new Factory(name, filters);
+            }
+            factory.otherBucket = in.readBoolean();
+            factory.otherBucketKey = in.readString();
+            return factory;
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) throws IOException {
+            out.writeBoolean(keyed);
+            if (keyed) {
+                out.writeVInt(filters.size());
+                for (KeyedFilter keyedFilter : filters) {
+                    keyedFilter.writeTo(out);
+                }
+            } else {
+                out.writeVInt(filters.size());
+                for (KeyedFilter keyedFilter : filters) {
+                    out.writeQuery(keyedFilter.filter());
+                }
+            }
+            out.writeBoolean(otherBucket);
+            out.writeString(otherBucketKey);
+        }
+
+        @Override
+        protected int doHashCode() {
+            return Objects.hash(filters, keyed, otherBucket, otherBucketKey);
+        }
+
+        @Override
+        protected boolean doEquals(Object obj) {
+            Factory other = (Factory) obj;
+            return Objects.equals(filters, other.filters)
+                    && Objects.equals(keyed, other.keyed)
+                    && Objects.equals(otherBucket, other.otherBucket)
+                    && Objects.equals(otherBucketKey, other.otherBucketKey);
         }
     }
 

+ 42 - 18
core/src/main/java/org/elasticsearch/search/aggregations/bucket/filters/FiltersParser.java

@@ -20,9 +20,12 @@
 package org.elasticsearch.search.aggregations.bucket.filters;
 
 import org.elasticsearch.common.ParseField;
-import org.elasticsearch.common.lucene.search.Queries;
+import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.xcontent.XContentParser;
-import org.elasticsearch.index.query.ParsedQuery;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.QueryParseContext;
+import org.elasticsearch.indices.query.IndicesQueriesRegistry;
 import org.elasticsearch.search.SearchParseException;
 import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,6 +33,7 @@ import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 /**
@@ -40,6 +44,12 @@ public class FiltersParser implements Aggregator.Parser {
     public static final ParseField FILTERS_FIELD = new ParseField("filters");
     public static final ParseField OTHER_BUCKET_FIELD = new ParseField("other_bucket");
     public static final ParseField OTHER_BUCKET_KEY_FIELD = new ParseField("other_bucket_key");
+    private final IndicesQueriesRegistry queriesRegistry;
+
+    @Inject
+    public FiltersParser(IndicesQueriesRegistry queriesRegistry) {
+        this.queriesRegistry = queriesRegistry;
+    }
 
     @Override
     public String type() {
@@ -49,13 +59,13 @@ public class FiltersParser implements Aggregator.Parser {
     @Override
     public AggregatorFactory parse(String aggregationName, XContentParser parser, SearchContext context) throws IOException {
 
-        List<FiltersAggregator.KeyedFilter> filters = new ArrayList<>();
+        List<FiltersAggregator.KeyedFilter> keyedFilters = null;
+        List<QueryBuilder<?>> nonKeyedFilters = null;
 
         XContentParser.Token token = null;
         String currentFieldName = null;
-        Boolean keyed = null;
         String otherBucketKey = null;
-        boolean otherBucket = false;
+        Boolean otherBucket = false;
         while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME) {
                 currentFieldName = parser.currentName();
@@ -69,21 +79,24 @@ public class FiltersParser implements Aggregator.Parser {
             } else if (token == XContentParser.Token.VALUE_STRING) {
                 if (context.parseFieldMatcher().match(currentFieldName, OTHER_BUCKET_KEY_FIELD)) {
                     otherBucketKey = parser.text();
-                    otherBucket = true;
                 } else {
                     throw new SearchParseException(context, "Unknown key for a " + token + " in [" + aggregationName + "]: ["
                             + currentFieldName + "].", parser.getTokenLocation());
                 }
             } else if (token == XContentParser.Token.START_OBJECT) {
                 if (context.parseFieldMatcher().match(currentFieldName, FILTERS_FIELD)) {
-                    keyed = true;
+                    keyedFilters = new ArrayList<>();
                     String key = null;
                     while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                         if (token == XContentParser.Token.FIELD_NAME) {
                             key = parser.currentName();
                         } else {
-                            ParsedQuery filter = context.indexShard().getQueryShardContext().parseInnerFilter(parser);
-                            filters.add(new FiltersAggregator.KeyedFilter(key, filter == null ? Queries.newMatchAllQuery() : filter.query()));
+                            QueryParseContext queryParseContext = new QueryParseContext(queriesRegistry);
+                            queryParseContext.reset(parser);
+                            queryParseContext.parseFieldMatcher(context.parseFieldMatcher());
+                            QueryBuilder<?> filter = queryParseContext.parseInnerQueryBuilder();
+                            keyedFilters
+                                    .add(new FiltersAggregator.KeyedFilter(key, filter == null ? QueryBuilders.matchAllQuery() : filter));
                         }
                     }
                 } else {
@@ -92,13 +105,13 @@ public class FiltersParser implements Aggregator.Parser {
                 }
             } else if (token == XContentParser.Token.START_ARRAY) {
                 if (context.parseFieldMatcher().match(currentFieldName, FILTERS_FIELD)) {
-                    keyed = false;
-                    int idx = 0;
+                    nonKeyedFilters = new ArrayList<>();
                     while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
-                        ParsedQuery filter = context.indexShard().getQueryShardContext().parseInnerFilter(parser);
-                        filters.add(new FiltersAggregator.KeyedFilter(String.valueOf(idx), filter == null ? Queries.newMatchAllQuery()
-                                : filter.query()));
-                        idx++;
+                        QueryParseContext queryParseContext = new QueryParseContext(queriesRegistry);
+                        queryParseContext.reset(parser);
+                        queryParseContext.parseFieldMatcher(context.parseFieldMatcher());
+                        QueryBuilder<?> filter = queryParseContext.parseInnerQueryBuilder();
+                        nonKeyedFilters.add(filter == null ? QueryBuilders.matchAllQuery() : filter);
                     }
                 } else {
                     throw new SearchParseException(context, "Unknown key for a " + token + " in [" + aggregationName + "]: ["
@@ -114,13 +127,24 @@ public class FiltersParser implements Aggregator.Parser {
             otherBucketKey = "_other_";
         }
 
-        return new FiltersAggregator.Factory(aggregationName, filters, keyed, otherBucketKey);
+        FiltersAggregator.Factory factory;
+        if (keyedFilters != null) {
+            factory = new FiltersAggregator.Factory(aggregationName, keyedFilters);
+        } else {
+            factory = new FiltersAggregator.Factory(aggregationName, nonKeyedFilters.toArray(new QueryBuilder<?>[nonKeyedFilters.size()]));
+        }
+        if (otherBucket != null) {
+            factory.otherBucket(otherBucket);
+        }
+        if (otherBucketKey != null) {
+            factory.otherBucketKey(otherBucketKey);
+        }
+        return factory;
     }
 
-    // NORELEASE implement this method when refactoring this aggregation
     @Override
     public AggregatorFactory[] getFactoryPrototypes() {
-        return null;
+        return new AggregatorFactory[] { new FiltersAggregator.Factory(null, Collections.emptyList()) };
     }
 
 }

+ 1 - 1
core/src/test/java/org/elasticsearch/search/aggregations/bucket/FiltersIT.java

@@ -308,7 +308,7 @@ public class FiltersIT extends ESIntegTestCase {
         SearchResponse response = client()
                 .prepareSearch("idx")
                 .addAggregation(
-                        filters("tags").otherBucketKey("foobar")
+                        filters("tags").otherBucket(true).otherBucketKey("foobar")
                         .filter("tag1", termQuery("tag", "tag1"))
                         .filter("tag2", termQuery("tag", "tag2")))
                 .execute().actionGet();

+ 68 - 0
core/src/test/java/org/elasticsearch/search/aggregations/metrics/FiltersTests.java

@@ -0,0 +1,68 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.metrics;
+
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
+import org.elasticsearch.search.aggregations.bucket.filters.FiltersAggregator;
+import org.elasticsearch.search.aggregations.bucket.filters.FiltersAggregator.Factory;
+import org.elasticsearch.search.aggregations.bucket.filters.FiltersAggregator.KeyedFilter;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class FiltersTests extends BaseAggregationTestCase<FiltersAggregator.Factory> {
+
+    @Override
+    protected Factory createTestAggregatorFactory() {
+
+        int size = randomIntBetween(1, 20);
+        Factory factory;
+        if (randomBoolean()) {
+            List<KeyedFilter> filters = new ArrayList<>(size);
+            for (int i = 0; i < size; i++) {
+                // NORELEASE make RandomQueryBuilder work outside of the
+                // AbstractQueryTestCase
+                // builder.query(RandomQueryBuilder.createQuery(getRandom()));
+                filters.add(new KeyedFilter(randomAsciiOfLengthBetween(1, 20),
+                        QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), randomAsciiOfLengthBetween(5, 20))));
+            }
+            factory = new Factory(randomAsciiOfLengthBetween(1, 20), filters);
+        } else {
+            QueryBuilder<?>[] filters = new QueryBuilder<?>[size];
+            for (int i = 0; i < size; i++) {
+                // NORELEASE make RandomQueryBuilder work outside of the
+                // AbstractQueryTestCase
+                // builder.query(RandomQueryBuilder.createQuery(getRandom()));
+                filters[i] = QueryBuilders.termQuery(randomAsciiOfLengthBetween(5, 20), randomAsciiOfLengthBetween(5, 20));
+            }
+            factory = new Factory(randomAsciiOfLengthBetween(1, 20), filters);
+        }
+        if (randomBoolean()) {
+            factory.otherBucket(randomBoolean());
+        }
+        if (randomBoolean()) {
+            factory.otherBucketKey(randomAsciiOfLengthBetween(1, 20));
+        }
+        return factory;
+    }
+
+}