Bladeren bron

Wire up DiversifiedAggregation (#56145)

Mark Tozzi 5 jaren geleden
bovenliggende
commit
0df2c77141

+ 1 - 1
server/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -385,7 +385,7 @@ public class SearchModule {
                     .addResultReader(UnmappedSampler.NAME, UnmappedSampler::new),
             builder);
         registerAggregation(new AggregationSpec(DiversifiedAggregationBuilder.NAME, DiversifiedAggregationBuilder::new,
-                DiversifiedAggregationBuilder.PARSER)
+                DiversifiedAggregationBuilder.PARSER).setAggregatorRegistrar(DiversifiedAggregationBuilder::registerAggregators)
                     /* Reuses result readers from SamplerAggregator*/, builder);
         registerAggregation(new AggregationSpec(TermsAggregationBuilder.NAME, TermsAggregationBuilder::new,
                 TermsAggregationBuilder.PARSER)

+ 5 - 0
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregationBuilder.java

@@ -31,6 +31,7 @@ import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
 import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
 import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
+import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
 import org.elasticsearch.search.aggregations.support.ValuesSourceType;
 
 import java.io.IOException;
@@ -51,6 +52,10 @@ public class DiversifiedAggregationBuilder extends ValuesSourceAggregationBuilde
         PARSER.declareString(DiversifiedAggregationBuilder::executionHint, SamplerAggregator.EXECUTION_HINT_FIELD);
     }
 
+    public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
+        DiversifiedAggregatorFactory.registerAggregators(builder);
+    }
+
     private int shardSize = SamplerAggregationBuilder.DEFAULT_SHARD_SAMPLE_SIZE;
     private int maxDocsPerValue = MAX_DOCS_PER_VALUE_DEFAULT;
     private String executionHint = null;

+ 42 - 32
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregatorFactory.java

@@ -27,19 +27,51 @@ import org.elasticsearch.search.aggregations.AggregatorFactory;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.NonCollectingAggregator;
 import org.elasticsearch.search.aggregations.bucket.sampler.SamplerAggregator.ExecutionMode;
+import org.elasticsearch.search.aggregations.support.AggregatorSupplier;
+import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
-import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
 import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
 import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
+import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
 import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Map;
 
-import static org.elasticsearch.search.aggregations.support.AggregationUsageService.OTHER_SUBTYPE;
-
 public class DiversifiedAggregatorFactory extends ValuesSourceAggregatorFactory {
 
+    public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
+        builder.register(DiversifiedAggregationBuilder.NAME,
+            List.of(CoreValuesSourceType.NUMERIC, CoreValuesSourceType.DATE, CoreValuesSourceType.BOOLEAN),
+            (DiversifiedAggregatorSupplier) (String name, int shardSize, AggregatorFactories factories, SearchContext context,
+                                             Aggregator parent, Map<String, Object> metadata, ValuesSource valuesSource,
+                                             int maxDocsPerValue, String executionHint) ->
+                new DiversifiedNumericSamplerAggregator(name, shardSize, factories, context, parent, metadata, valuesSource,
+                    maxDocsPerValue)
+        );
+
+        builder.register(DiversifiedAggregationBuilder.NAME, CoreValuesSourceType.BYTES,
+            (DiversifiedAggregatorSupplier) (String name, int shardSize, AggregatorFactories factories, SearchContext context,
+                                             Aggregator parent, Map<String, Object> metadata, ValuesSource valuesSource,
+                                             int maxDocsPerValue, String executionHint) -> {
+                ExecutionMode execution = null;
+                if (executionHint != null) {
+                    execution = ExecutionMode.fromString(executionHint);
+                }
+
+                // In some cases using ordinals is just not supported: override it
+                if (execution == null) {
+                    execution = ExecutionMode.GLOBAL_ORDINALS;
+                }
+                if ((execution.needsGlobalOrdinals()) && (!(valuesSource instanceof ValuesSource.Bytes.WithOrdinals))) {
+                    execution = ExecutionMode.MAP;
+                }
+                return execution.create(name, factories, shardSize, maxDocsPerValue, valuesSource, context, parent, metadata);
+        });
+
+    }
+
     private final int shardSize;
     private final int maxDocsPerValue;
     private final String executionHint;
@@ -60,30 +92,14 @@ public class DiversifiedAggregatorFactory extends ValuesSourceAggregatorFactory
                                             boolean collectsFromSingleBucket,
                                             Map<String, Object> metadata) throws IOException {
 
-        if (valuesSource instanceof ValuesSource.Numeric) {
-            return new DiversifiedNumericSamplerAggregator(name, shardSize, factories, searchContext, parent, metadata,
-                    (Numeric) valuesSource, maxDocsPerValue);
-        }
-
-        if (valuesSource instanceof ValuesSource.Bytes) {
-            ExecutionMode execution = null;
-            if (executionHint != null) {
-                execution = ExecutionMode.fromString(executionHint);
-            }
-
-            // In some cases using ordinals is just not supported: override
-            // it
-            if (execution == null) {
-                execution = ExecutionMode.GLOBAL_ORDINALS;
-            }
-            if ((execution.needsGlobalOrdinals()) && (!(valuesSource instanceof ValuesSource.Bytes.WithOrdinals))) {
-                execution = ExecutionMode.MAP;
-            }
-            return execution.create(name, factories, shardSize, maxDocsPerValue, valuesSource, searchContext, parent, metadata);
+        AggregatorSupplier supplier = queryShardContext.getValuesSourceRegistry().getAggregator(config.valueSourceType(),
+            DiversifiedAggregationBuilder.NAME);
+        if (supplier instanceof DiversifiedAggregatorSupplier == false) {
+            throw new AggregationExecutionException("Registry miss-match - expected " + DiversifiedAggregatorSupplier.class.toString() +
+                ", found [" + supplier.getClass().toString() + "]");
         }
-
-        throw new AggregationExecutionException("Sampler aggregation cannot be applied to field [" + config.fieldContext().field()
-                + "]. It can only be applied to numeric or string fields.");
+        return ((DiversifiedAggregatorSupplier) supplier).build(name, shardSize, factories, searchContext, parent, metadata, valuesSource,
+            maxDocsPerValue, executionHint);
     }
 
     @Override
@@ -99,10 +115,4 @@ public class DiversifiedAggregatorFactory extends ValuesSourceAggregatorFactory
             }
         };
     }
-
-    @Override
-    public String getStatsSubtype() {
-        // DiversifiedAggregatorFactory doesn't register itself with ValuesSourceRegistry
-        return OTHER_SUBTYPE;
-    }
 }

+ 42 - 0
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedAggregatorSupplier.java

@@ -0,0 +1,42 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.bucket.sampler;
+
+import org.elasticsearch.search.aggregations.Aggregator;
+import org.elasticsearch.search.aggregations.AggregatorFactories;
+import org.elasticsearch.search.aggregations.support.AggregatorSupplier;
+import org.elasticsearch.search.aggregations.support.ValuesSource;
+import org.elasticsearch.search.internal.SearchContext;
+
+import java.io.IOException;
+import java.util.Map;
+
+public interface DiversifiedAggregatorSupplier extends AggregatorSupplier {
+    Aggregator build(
+        String name,
+        int shardSize,
+        AggregatorFactories factories,
+        SearchContext context,
+        Aggregator parent,
+        Map<String, Object> metadata,
+        ValuesSource valuesSource,
+        int maxDocsPerValue,
+        String executionHint) throws IOException;
+}

+ 2 - 2
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/DiversifiedNumericSamplerAggregator.java

@@ -44,9 +44,9 @@ public class DiversifiedNumericSamplerAggregator extends SamplerAggregator {
 
     DiversifiedNumericSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
             SearchContext context, Aggregator parent, Map<String, Object> metadata,
-            ValuesSource.Numeric valuesSource, int maxDocsPerValue) throws IOException {
+            ValuesSource valuesSource, int maxDocsPerValue) throws IOException {
         super(name, shardSize, factories, context, parent, metadata);
-        this.valuesSource = valuesSource;
+        this.valuesSource = (ValuesSource.Numeric) valuesSource;
         this.maxDocsPerValue = maxDocsPerValue;
     }