Browse Source

Aggs: Add validation to Bucket script pipeline agg (#132320) (#132440)

Closes https://github.com/elastic/elasticsearch/issues/132272

Docs are explicit on what the bucket_script agg requires:

> A parent pipeline aggregation which executes a script which can perform per bucket computations on specified metrics in the parent **_multi-bucket aggregation_**

But it's missing a validation.
Iván Cea Fontenla 2 months ago
parent
commit
078bfc555a

+ 6 - 0
docs/changelog/132320.yaml

@@ -0,0 +1,6 @@
+pr: 132320
+summary: "Aggs: Add validation to Bucket script pipeline agg"
+area: Aggregations
+type: bug
+issues:
+ - 132272

+ 28 - 0
modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/bucket_script.yml

@@ -340,3 +340,31 @@ top level fails:
                 buckets_path:
                   b: b
                 script: params.b + 12
+
+---
+invalid parent aggregation:
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ bucket_script_parent_multi_bucket_error ]
+      test_runner_features: [capabilities]
+      reason: "changed error 500 to 400"
+  - do:
+      catch: /Expected a multi bucket aggregation but got \[InternalFilter\] for aggregation \[d\]/
+      search:
+        body:
+          aggs:
+            a:
+              filter:
+                term:
+                  a: 1
+              aggs:
+                b:
+                  sum:
+                    field: b
+                d:
+                  bucket_script:
+                    buckets_path:
+                      b: b
+                    script: params.b + 12

+ 2 - 0
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -49,6 +49,7 @@ public final class SearchCapabilities {
     private static final String SIGNIFICANT_TERMS_ON_NESTED_FIELDS = "significant_terms_on_nested_fields";
     private static final String EXCLUDE_VECTORS_PARAM = "exclude_vectors_param";
     private static final String DENSE_VECTOR_UPDATABLE_BBQ = "dense_vector_updatable_bbq";
+    private static final String BUCKET_SCRIPT_PARENT_MULTI_BUCKET_ERROR = "bucket_script_parent_multi_bucket_error";
 
     public static final Set<String> CAPABILITIES;
     static {
@@ -70,6 +71,7 @@ public final class SearchCapabilities {
         capabilities.add(SIGNIFICANT_TERMS_ON_NESTED_FIELDS);
         capabilities.add(EXCLUDE_VECTORS_PARAM);
         capabilities.add(DENSE_VECTOR_UPDATABLE_BBQ);
+        capabilities.add(BUCKET_SCRIPT_PARENT_MULTI_BUCKET_ERROR);
         CAPABILITIES = Set.copyOf(capabilities);
     }
 }

+ 18 - 3
server/src/main/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptPipelineAggregator.java

@@ -21,6 +21,7 @@ import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 
 import static org.elasticsearch.search.aggregations.pipeline.BucketHelpers.resolveBucketValue;
@@ -47,10 +48,24 @@ public class BucketScriptPipelineAggregator extends PipelineAggregator {
     }
 
     @Override
+    @SuppressWarnings({ "rawtypes", "unchecked" })
     public InternalAggregation reduce(InternalAggregation aggregation, AggregationReduceContext reduceContext) {
-        @SuppressWarnings({ "rawtypes", "unchecked" })
-        InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg =
-            (InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket>) aggregation;
+
+        InternalMultiBucketAggregation<InternalMultiBucketAggregation, InternalMultiBucketAggregation.InternalBucket> originalAgg;
+
+        if (aggregation instanceof InternalMultiBucketAggregation multiBucketAggregation) {
+            originalAgg = multiBucketAggregation;
+        } else {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "Expected a multi bucket aggregation but got [%s] for aggregation [%s]",
+                    aggregation.getClass().getSimpleName(),
+                    name()
+                )
+            );
+        }
+
         List<? extends InternalMultiBucketAggregation.InternalBucket> buckets = originalAgg.getBuckets();
 
         BucketAggregationScript.Factory factory = reduceContext.scriptService().compile(script, BucketAggregationScript.CONTEXT);

+ 37 - 1
server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptAggregatorTests.java

@@ -30,7 +30,9 @@ import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.script.ScriptModule;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.script.ScriptType;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
+import org.elasticsearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.filter.InternalFilters;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
@@ -95,8 +97,42 @@ public class BucketScriptAggregatorTests extends AggregatorTestCase {
         );
     }
 
+    public void testNonMultiBucketParent() {
+        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("number_field", NumberFieldMapper.NumberType.INTEGER);
+        MappedFieldType fieldType1 = new KeywordFieldMapper.KeywordFieldType("the_field");
+
+        FilterAggregationBuilder filter = new FilterAggregationBuilder("placeholder", new MatchAllQueryBuilder()).subAggregation(
+            new TermsAggregationBuilder("the_terms").userValueTypeHint(ValueType.STRING)
+                .field("the_field")
+                .subAggregation(new AvgAggregationBuilder("the_avg").field("number_field"))
+        )
+            .subAggregation(
+                new BucketScriptPipelineAggregationBuilder(
+                    "bucket_script",
+                    Collections.singletonMap("the_avg", "the_terms['test1']>the_avg.value"),
+                    new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap())
+                )
+            );
+
+        assertThrows(
+            "Expected a multi bucket aggregation but got [InternalFilter] for aggregation [bucket_script]",
+            IllegalArgumentException.class,
+            () -> testCase(filter, new MatchAllDocsQuery(), iw -> {
+                Document doc = new Document();
+                doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test1")));
+                doc.add(new SortedNumericDocValuesField("number_field", 19));
+                iw.addDocument(doc);
+
+                doc = new Document();
+                doc.add(new SortedSetDocValuesField("the_field", new BytesRef("test2")));
+                doc.add(new SortedNumericDocValuesField("number_field", 55));
+                iw.addDocument(doc);
+            }, f -> fail("This shouldn't be called"), fieldType, fieldType1)
+        );
+    }
+
     private void testCase(
-        FiltersAggregationBuilder aggregationBuilder,
+        AggregationBuilder aggregationBuilder,
         Query query,
         CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
         Consumer<InternalFilters> verify,