Преглед на файлове

Support multiple associated groups for TopN (#108409)

In 8.14 we have introduced the ability to retrieve associated groups (services,
transactions) for a TopN function by a user-defined field. With this commit we
enhance this ability to allow to specify an arbitrary number of fields 
(restricted by request validation to two fields to avoid uncontrolled bucket 
explosion due to a misconfiguration). We also introduce a BWC layer so the old
request format is still accepted and the old response format automatically used.

Closes #108018
Daniel Mitterdorfer преди 1 година
родител
ревизия
418b9f4e41
променени са 19 файла, в които са добавени 902 реда и са изтрити 121 реда
  1. 6 0
      docs/changelog/108409.yaml
  2. 1 0
      x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetFlameGraphActionIT.java
  3. 9 24
      x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetStackTracesActionIT.java
  4. 3 0
      x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetTopNFunctionsActionIT.java
  5. 56 7
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesRequest.java
  6. 1 1
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java
  7. 1 1
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/StackTrace.java
  8. 142 0
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/SubGroup.java
  9. 156 0
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/SubGroupCollector.java
  10. 16 14
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TopNFunction.java
  11. 1 3
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TraceEvent.java
  12. 12 40
      x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java
  13. 204 0
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesRequestTests.java
  14. 8 3
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetTopNFunctionsResponseTests.java
  15. 5 0
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/ResamplerTests.java
  16. 149 0
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/SubGroupCollectorTests.java
  17. 102 0
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/SubGroupTests.java
  18. 29 26
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/TopNFunctionTests.java
  19. 1 2
      x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/TransportGetTopNFunctionsActionTests.java

+ 6 - 0
docs/changelog/108409.yaml

@@ -0,0 +1,6 @@
+pr: 108409
+summary: Support multiple associated groups for TopN
+area: Application
+type: enhancement
+issues:
+ - 108018

+ 1 - 0
x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetFlameGraphActionIT.java

@@ -22,6 +22,7 @@ public class GetFlameGraphActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetFlamegraphResponse response = client().execute(GetFlamegraphAction.INSTANCE, request).get();

+ 9 - 24
x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetStackTracesActionIT.java

@@ -28,6 +28,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);
@@ -72,6 +73,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);
@@ -91,7 +93,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
         assertEquals(18, stackTrace.typeIds.length);
         assertEquals(0.0000048475146d, stackTrace.annualCO2Tons, 0.0000000001d);
         assertEquals(0.18834d, stackTrace.annualCostsUSD, 0.00001d);
-        assertEquals(Long.valueOf(2L), stackTrace.subGroups.get("basket"));
+        assertEquals(Long.valueOf(2L), stackTrace.subGroups.getCount("basket"));
 
         assertNotNull(response.getStackFrames());
         StackFrame stackFrame = response.getStackFrames().get("8NlMClggx8jaziUTJXlmWAAAAAAAAIYI");
@@ -101,28 +103,6 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
         assertEquals("vmlinux", response.getExecutables().get("lHp5_WAgpLy2alrUVab6HA"));
     }
 
-    public void testGetStackTracesGroupedByInvalidField() {
-        GetStackTracesRequest request = new GetStackTracesRequest(
-            1000,
-            600.0d,
-            1.0d,
-            1.0d,
-            null,
-            null,
-            null,
-            // only service.name is supported (note the trailing "s")
-            "service.names",
-            null,
-            null,
-            null,
-            null,
-            null
-        );
-        request.setAdjustSampleCount(true);
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, client().execute(GetStackTracesAction.INSTANCE, request));
-        assertEquals("Requested custom event aggregation field [service.names] but only [service.name] is supported.", e.getMessage());
-    }
-
     public void testGetStackTracesFromAPMWithMatchNoDownsampling() throws Exception {
         BoolQueryBuilder query = QueryBuilders.boolQuery();
         query.must().add(QueryBuilders.termQuery("transaction.name", "encodeSha1"));
@@ -142,6 +122,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetStackTracesResponse response = client().execute(GetStackTracesAction.INSTANCE, request).get();
@@ -161,7 +142,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
         assertEquals(39, stackTrace.typeIds.length);
         assertTrue(stackTrace.annualCO2Tons > 0.0d);
         assertTrue(stackTrace.annualCostsUSD > 0.0d);
-        assertEquals(Long.valueOf(3L), stackTrace.subGroups.get("encodeSha1"));
+        assertEquals(Long.valueOf(3L), stackTrace.subGroups.getCount("encodeSha1"));
 
         assertNotNull(response.getStackFrames());
         StackFrame stackFrame = response.getStackFrames().get("fhsEKXDuxJ-jIJrZpdRuSAAAAAAAAFtj");
@@ -187,6 +168,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         // ensures consistent results in the random sampler aggregation that is used internally
@@ -237,6 +219,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetStackTracesResponse response = client().execute(GetStackTracesAction.INSTANCE, request).get();
@@ -259,6 +242,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetStackTracesResponse response = client().execute(GetStackTracesAction.INSTANCE, request).get();
@@ -281,6 +265,7 @@ public class GetStackTracesActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetStackTracesResponse response = client().execute(GetStackTracesAction.INSTANCE, request).get();

+ 3 - 0
x-pack/plugin/profiling/src/internalClusterTest/java/org/elasticsearch/xpack/profiling/action/GetTopNFunctionsActionIT.java

@@ -25,6 +25,7 @@ public class GetTopNFunctionsActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);
@@ -46,6 +47,7 @@ public class GetTopNFunctionsActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);
@@ -73,6 +75,7 @@ public class GetTopNFunctionsActionIT extends ProfilingTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         GetTopNFunctionsResponse response = client().execute(GetTopNFunctionsAction.INSTANCE, request).get();

+ 56 - 7
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesRequest.java

@@ -13,6 +13,7 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.action.support.TransportAction;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.UpdateForV9;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
@@ -42,7 +43,9 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
     public static final ParseField LIMIT_FIELD = new ParseField("limit");
     public static final ParseField INDICES_FIELD = new ParseField("indices");
     public static final ParseField STACKTRACE_IDS_FIELD = new ParseField("stacktrace_ids_field");
+    @UpdateForV9 // Remove this BWC layer and allow only AGGREGATION_FIELDS
     public static final ParseField AGGREGATION_FIELD = new ParseField("aggregation_field");
+    public static final ParseField AGGREGATION_FIELDS = new ParseField("aggregation_fields");
     public static final ParseField REQUESTED_DURATION_FIELD = new ParseField("requested_duration");
     public static final ParseField AWS_COST_FACTOR_FIELD = new ParseField("aws_cost_factor");
     public static final ParseField AZURE_COST_FACTOR_FIELD = new ParseField("azure_cost_factor");
@@ -59,7 +62,9 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
     private String[] indices;
     private boolean userProvidedIndices;
     private String stackTraceIdsField;
+    @UpdateForV9 // Remove this BWC layer and allow only aggregationFields
     private String aggregationField;
+    private String[] aggregationFields;
     private Double requestedDuration;
     private Double awsCostFactor;
     private Double azureCostFactor;
@@ -78,7 +83,7 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
     private Integer shardSeed;
 
     public GetStackTracesRequest() {
-        this(null, null, null, null, null, null, null, null, null, null, null, null, null);
+        this(null, null, null, null, null, null, null, null, null, null, null, null, null, null);
     }
 
     public GetStackTracesRequest(
@@ -90,6 +95,7 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
         String[] indices,
         String stackTraceIdsField,
         String aggregationField,
+        String[] aggregationFields,
         Double customCO2PerKWH,
         Double customDatacenterPUE,
         Double customPerCoreWattX86,
@@ -105,6 +111,7 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
         this.userProvidedIndices = indices != null && indices.length > 0;
         this.stackTraceIdsField = stackTraceIdsField;
         this.aggregationField = aggregationField;
+        this.aggregationFields = aggregationFields;
         this.customCO2PerKWH = customCO2PerKWH;
         this.customDatacenterPUE = customDatacenterPUE;
         this.customPerCoreWattX86 = customPerCoreWattX86;
@@ -181,6 +188,19 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
         return aggregationField;
     }
 
+    public String[] getAggregationFields() {
+        return aggregationField != null ? new String[] { aggregationField } : aggregationFields;
+    }
+
+    public boolean hasAggregationFields() {
+        String[] f = getAggregationFields();
+        return f != null && f.length > 0;
+    }
+
+    public boolean isLegacyAggregationField() {
+        return aggregationField != null;
+    }
+
     public boolean isAdjustSampleCount() {
         return Boolean.TRUE.equals(adjustSampleCount);
     }
@@ -244,8 +264,10 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
                 }
             } else if (token == XContentParser.Token.START_ARRAY) {
                 if (INDICES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
-                    this.indices = parseIndices(parser);
+                    this.indices = parseToStringArray(parser, INDICES_FIELD);
                     this.userProvidedIndices = true;
+                } else if (AGGREGATION_FIELDS.match(currentFieldName, parser.getDeprecationHandler())) {
+                    this.aggregationFields = parseToStringArray(parser, AGGREGATION_FIELDS);
                 } else {
                     throw new ParsingException(parser.getTokenLocation(), "Unexpected token " + token + " in [" + currentFieldName + "].");
                 }
@@ -260,12 +282,12 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
         }
     }
 
-    private String[] parseIndices(XContentParser parser) throws IOException {
+    private String[] parseToStringArray(XContentParser parser, ParseField parseField) throws IOException {
         XContentParser.Token token;
-        List<String> indices = new ArrayList<>();
+        List<String> values = new ArrayList<>();
         while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
             if (token == XContentParser.Token.VALUE_STRING) {
-                indices.add(parser.text());
+                values.add(parser.text());
             } else {
                 throw new ParsingException(
                     parser.getTokenLocation(),
@@ -274,12 +296,12 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
                         + "] but found ["
                         + token
                         + "] in ["
-                        + INDICES_FIELD.getPreferredName()
+                        + parseField.getPreferredName()
                         + "]."
                 );
             }
         }
-        return indices.toArray(new String[0]);
+        return values.toArray(new String[0]);
     }
 
     @Override
@@ -300,6 +322,32 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
                 );
             }
         }
+        if (aggregationField != null && aggregationFields != null) {
+            validationException = addValidationError(
+                "["
+                    + AGGREGATION_FIELD.getPreferredName()
+                    + "] must not be set when ["
+                    + AGGREGATION_FIELDS.getPreferredName()
+                    + "] is also set",
+                validationException
+            );
+
+        }
+        if (aggregationFields != null) {
+            // limit so we avoid an explosion of buckets
+            if (aggregationFields.length < 1 || aggregationFields.length > 2) {
+                validationException = addValidationError(
+                    "["
+                        + AGGREGATION_FIELDS.getPreferredName()
+                        + "] must contain either one or two elements but contains ["
+                        + aggregationFields.length
+                        + "] elements.",
+                    validationException
+                );
+            }
+
+        }
+
         if (aggregationField != null && aggregationField.isBlank()) {
             validationException = addValidationError(
                 "[" + AGGREGATION_FIELD.getPreferredName() + "] must be non-empty",
@@ -339,6 +387,7 @@ public class GetStackTracesRequest extends ActionRequest implements IndicesReque
                 appendField(sb, "indices", indices);
                 appendField(sb, "stacktrace_ids_field", stackTraceIdsField);
                 appendField(sb, "aggregation_field", aggregationField);
+                appendField(sb, "aggregation_fields", aggregationFields);
                 appendField(sb, "sample_size", sampleSize);
                 appendField(sb, "limit", limit);
                 appendField(sb, "requested_duration", requestedDuration);

+ 1 - 1
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/GetStackTracesResponseBuilder.java

@@ -155,7 +155,7 @@ class GetStackTracesResponseBuilder {
                 if (event != null) {
                     StackTrace stackTrace = entry.getValue();
                     stackTrace.count = event.count;
-                    if (event.subGroups.isEmpty() == false) {
+                    if (event.subGroups != null) {
                         stackTrace.subGroups = event.subGroups;
                     }
                     stackTrace.annualCO2Tons = event.annualCO2Tons;

+ 1 - 1
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/StackTrace.java

@@ -26,7 +26,7 @@ final class StackTrace implements ToXContentObject {
     String[] fileIds;
     String[] frameIds;
     int[] typeIds;
-    Map<String, Long> subGroups;
+    SubGroup subGroups;
     double annualCO2Tons;
     double annualCostsUSD;
     long count;

+ 142 - 0
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/SubGroup.java

@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.profiling.action;
+
+import org.elasticsearch.core.UpdateForV9;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+public class SubGroup implements ToXContentFragment {
+    private final String name;
+    private Long count;
+    @UpdateForV9 // remove legacy XContent rendering
+    private final boolean renderLegacyXContent;
+    private final Map<String, SubGroup> subgroups;
+
+    public static SubGroup root(String name, boolean renderLegacyXContent) {
+        return new SubGroup(name, null, renderLegacyXContent, new HashMap<>());
+    }
+
+    public SubGroup(String name, Long count, boolean renderLegacyXContent, Map<String, SubGroup> subgroups) {
+        this.name = name;
+        this.count = count;
+        this.renderLegacyXContent = renderLegacyXContent;
+        this.subgroups = subgroups;
+    }
+
+    public SubGroup addCount(String name, long count) {
+        if (this.subgroups.containsKey(name) == false) {
+            this.subgroups.put(name, new SubGroup(name, count, renderLegacyXContent, new HashMap<>()));
+        } else {
+            SubGroup s = this.subgroups.get(name);
+            s.count += count;
+        }
+        return this;
+    }
+
+    public SubGroup getOrAddChild(String name) {
+        if (subgroups.containsKey(name) == false) {
+            this.subgroups.put(name, new SubGroup(name, null, renderLegacyXContent, new HashMap<>()));
+        }
+        return this.subgroups.get(name);
+    }
+
+    public Long getCount(String name) {
+        SubGroup subGroup = this.subgroups.get(name);
+        return subGroup != null ? subGroup.count : null;
+    }
+
+    public SubGroup getSubGroup(String name) {
+        return this.subgroups.get(name);
+    }
+
+    public SubGroup copy() {
+        Map<String, SubGroup> copy = new HashMap<>(subgroups.size());
+        for (Map.Entry<String, SubGroup> subGroup : subgroups.entrySet()) {
+            copy.put(subGroup.getKey(), subGroup.getValue().copy());
+        }
+        return new SubGroup(name, count, renderLegacyXContent, copy);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (renderLegacyXContent) {
+            // This assumes that we only have one level of sub groups
+            if (subgroups != null && subgroups.isEmpty() == false) {
+                for (SubGroup subgroup : subgroups.values()) {
+                    builder.field(subgroup.name, subgroup.count);
+                }
+            }
+            return builder;
+        } else {
+            builder.startObject(name);
+            // only the root node has no count
+            if (count != null) {
+                builder.field("count", count);
+            }
+            if (subgroups != null && subgroups.isEmpty() == false) {
+                for (SubGroup subgroup : subgroups.values()) {
+                    subgroup.toXContent(builder, params);
+                }
+            }
+            return builder.endObject();
+        }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        SubGroup subGroup = (SubGroup) o;
+        return Objects.equals(name, subGroup.name)
+            && Objects.equals(count, subGroup.count)
+            && Objects.equals(subgroups, subGroup.subgroups);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(name, count, subgroups);
+    }
+
+    @Override
+    public String toString() {
+        return name;
+    }
+
+    public void merge(SubGroup s) {
+        if (s == null) {
+            return;
+        }
+        // must have the same name
+        if (this.name.equals(s.name)) {
+            if (this.count != null && s.count != null) {
+                this.count += s.count;
+            } else if (this.count == null) {
+                this.count = s.count;
+            }
+            for (SubGroup subGroup : s.subgroups.values()) {
+                if (this.subgroups.containsKey(subGroup.name)) {
+                    // merge
+                    this.subgroups.get(subGroup.name).merge(subGroup);
+                } else {
+                    // add sub group as is (recursively)
+                    this.subgroups.put(subGroup.name, subGroup.copy());
+                }
+            }
+        }
+    }
+}

+ 156 - 0
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/SubGroupCollector.java

@@ -0,0 +1,156 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.profiling.action;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
+import org.elasticsearch.search.aggregations.InternalAggregations;
+import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+
+import java.util.Iterator;
+
+public final class SubGroupCollector {
+    /**
+     * Users may provide a custom field via the API that is used to sub-divide profiling events. This is useful in the context of TopN
+     * where we want to provide additional breakdown of where a certain function has been called (e.g. a certain service or transaction).
+     */
+    static final String CUSTOM_EVENT_SUB_AGGREGATION_NAME = "custom_event_group_";
+
+    private static final Logger log = LogManager.getLogger(SubGroupCollector.class);
+
+    private final String[] aggregationFields;
+    private final boolean legacyAggregationField;
+
+    public static SubGroupCollector attach(
+        AbstractAggregationBuilder<?> parentAggregation,
+        String[] aggregationFields,
+        boolean legacyAggregationField
+    ) {
+        SubGroupCollector c = new SubGroupCollector(aggregationFields, legacyAggregationField);
+        c.addAggregations(parentAggregation);
+        return c;
+    }
+
+    private SubGroupCollector(String[] aggregationFields, boolean legacyAggregationField) {
+        this.aggregationFields = aggregationFields;
+        this.legacyAggregationField = legacyAggregationField;
+    }
+
+    private boolean hasAggregationFields() {
+        return aggregationFields != null && aggregationFields.length > 0;
+    }
+
+    private void addAggregations(AbstractAggregationBuilder<?> parentAggregation) {
+        if (hasAggregationFields()) {
+            // cast to Object to disambiguate this from a varargs call
+            log.trace("Grouping stacktrace events by {}.", (Object) aggregationFields);
+            AbstractAggregationBuilder<?> parentAgg = parentAggregation;
+            for (String aggregationField : aggregationFields) {
+                String aggName = CUSTOM_EVENT_SUB_AGGREGATION_NAME + aggregationField;
+                TermsAggregationBuilder agg = new TermsAggregationBuilder(aggName).field(aggregationField);
+                parentAgg.subAggregation(agg);
+                parentAgg = agg;
+            }
+        }
+    }
+
+    void collectResults(MultiBucketsAggregation.Bucket bucket, TraceEvent event) {
+        collectResults(new BucketAdapter(bucket), event);
+    }
+
+    void collectResults(Bucket bucket, TraceEvent event) {
+        if (hasAggregationFields()) {
+            if (event.subGroups == null) {
+                event.subGroups = SubGroup.root(aggregationFields[0], legacyAggregationField);
+            }
+            collectInternal(bucket.getAggregations(), event.subGroups, 0);
+        }
+    }
+
+    private void collectInternal(Agg parentAgg, SubGroup parentGroup, int aggField) {
+        if (aggField == aggregationFields.length) {
+            return;
+        }
+        String aggName = CUSTOM_EVENT_SUB_AGGREGATION_NAME + aggregationFields[aggField];
+        for (Bucket b : parentAgg.getBuckets(aggName)) {
+            String subGroupName = b.getKey();
+            parentGroup.addCount(subGroupName, b.getCount());
+            SubGroup currentGroup = parentGroup.getSubGroup(subGroupName);
+            int nextAggField = aggField + 1;
+            if (nextAggField < aggregationFields.length) {
+                collectInternal(b.getAggregations(), currentGroup.getOrAddChild(aggregationFields[nextAggField]), nextAggField);
+            }
+        }
+    }
+
+    // The sole purpose of the code below is to abstract our code from the aggs framework to make it unit-testable
+    interface Agg {
+        Iterable<Bucket> getBuckets(String aggName);
+
+    }
+
+    interface Bucket {
+        String getKey();
+
+        long getCount();
+
+        Agg getAggregations();
+    }
+
+    static class InternalAggregationAdapter implements Agg {
+        private final InternalAggregations agg;
+
+        InternalAggregationAdapter(InternalAggregations agg) {
+            this.agg = agg;
+        }
+
+        @Override
+        public Iterable<Bucket> getBuckets(String aggName) {
+            MultiBucketsAggregation multiBucketsAggregation = agg.get(aggName);
+            return () -> {
+                Iterator<? extends MultiBucketsAggregation.Bucket> it = multiBucketsAggregation.getBuckets().iterator();
+                return new Iterator<>() {
+                    @Override
+                    public boolean hasNext() {
+                        return it.hasNext();
+                    }
+
+                    @Override
+                    public Bucket next() {
+                        return new BucketAdapter(it.next());
+                    }
+                };
+            };
+        }
+    }
+
+    static class BucketAdapter implements Bucket {
+        private final MultiBucketsAggregation.Bucket bucket;
+
+        BucketAdapter(MultiBucketsAggregation.Bucket bucket) {
+            this.bucket = bucket;
+        }
+
+        @Override
+        public String getKey() {
+            return bucket.getKeyAsString();
+        }
+
+        @Override
+        public long getCount() {
+            return bucket.getDocCount();
+        }
+
+        @Override
+        public Agg getAggregations() {
+            return new InternalAggregationAdapter(bucket.getAggregations());
+        }
+    }
+}

+ 16 - 14
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TopNFunction.java

@@ -11,11 +11,9 @@ import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
 import java.util.Objects;
 
-final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopNFunction> {
+final class TopNFunction implements ToXContentObject, Comparable<TopNFunction> {
     private final String id;
     private int rank;
     private final int frameType;
@@ -31,7 +29,7 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
     private double totalAnnualCO2Tons;
     private double selfAnnualCostsUSD;
     private double totalAnnualCostsUSD;
-    private final Map<String, Long> subGroups;
+    private SubGroup subGroups;
 
     TopNFunction(
         String id,
@@ -59,7 +57,7 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
             0.0d,
             0.0d,
             0.0d,
-            new HashMap<>()
+            null
         );
     }
 
@@ -79,7 +77,7 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
         double totalAnnualCO2Tons,
         double selfAnnualCostsUSD,
         double totalAnnualCostsUSD,
-        Map<String, Long> subGroups
+        SubGroup subGroups
     ) {
         this.id = id;
         this.rank = rank;
@@ -147,15 +145,15 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
         this.totalAnnualCostsUSD += costs;
     }
 
-    public void addSubGroups(Map<String, Long> subGroups) {
-        for (Map.Entry<String, Long> subGroup : subGroups.entrySet()) {
-            long count = this.subGroups.getOrDefault(subGroup.getKey(), 0L);
-            this.subGroups.put(subGroup.getKey(), count + subGroup.getValue());
+    public void addSubGroups(SubGroup subGroups) {
+        if (this.subGroups == null) {
+            this.subGroups = subGroups.copy();
+        } else {
+            this.subGroups.merge(subGroups);
         }
     }
 
-    @Override
-    protected TopNFunction clone() {
+    public TopNFunction copy() {
         return new TopNFunction(
             id,
             rank,
@@ -172,7 +170,7 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
             totalAnnualCO2Tons,
             selfAnnualCostsUSD,
             totalAnnualCostsUSD,
-            new HashMap<>(subGroups)
+            subGroups.copy()
         );
     }
 
@@ -190,7 +188,11 @@ final class TopNFunction implements Cloneable, ToXContentObject, Comparable<TopN
         builder.field("line_number", this.sourceLine);
         builder.field("executable_file_name", this.exeFilename);
         builder.endObject();
-        builder.field("sub_groups", subGroups);
+        if (subGroups != null) {
+            builder.startObject("sub_groups");
+            subGroups.toXContent(builder, params);
+            builder.endObject();
+        }
         builder.field("self_count", this.selfCount);
         builder.field("total_count", this.totalCount);
         builder.field("self_annual_co2_tons").rawValue(NumberUtils.doubleToString(selfAnnualCO2Tons));

+ 1 - 3
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TraceEvent.java

@@ -7,8 +7,6 @@
 
 package org.elasticsearch.xpack.profiling.action;
 
-import java.util.HashMap;
-import java.util.Map;
 import java.util.Objects;
 
 final class TraceEvent {
@@ -16,7 +14,7 @@ final class TraceEvent {
     double annualCO2Tons;
     double annualCostsUSD;
     long count;
-    final Map<String, Long> subGroups = new HashMap<>();
+    SubGroup subGroups;
 
     TraceEvent(String stacktraceID) {
         this(stacktraceID, 0);

+ 12 - 40
x-pack/plugin/profiling/src/main/java/org/elasticsearch/xpack/profiling/action/TransportGetStackTracesAction.java

@@ -254,19 +254,11 @@ public class TransportGetStackTracesAction extends TransportAction<GetStackTrace
         CountedTermsAggregationBuilder groupByStackTraceId = new CountedTermsAggregationBuilder("group_by").size(
             MAX_TRACE_EVENTS_RESULT_SIZE
         ).field(request.getStackTraceIdsField());
-        if (request.getAggregationField() != null) {
-            String aggregationField = request.getAggregationField();
-            log.trace("Grouping stacktrace events by [{}].", aggregationField);
-            // be strict about the accepted field names to avoid downstream errors or leaking unintended information
-            if (aggregationField.equals("transaction.name") == false) {
-                throw new IllegalArgumentException(
-                    "Requested custom event aggregation field [" + aggregationField + "] but only [transaction.name] is supported."
-                );
-            }
-            groupByStackTraceId.subAggregation(
-                new TermsAggregationBuilder(CUSTOM_EVENT_SUB_AGGREGATION_NAME).field(request.getAggregationField())
-            );
-        }
+        SubGroupCollector subGroups = SubGroupCollector.attach(
+            groupByStackTraceId,
+            request.getAggregationFields(),
+            request.isLegacyAggregationField()
+        );
         RandomSamplerAggregationBuilder randomSampler = new RandomSamplerAggregationBuilder("sample").setSeed(request.hashCode())
             .setProbability(responseBuilder.getSamplingRate())
             .subAggregation(groupByStackTraceId);
@@ -307,14 +299,7 @@ public class TransportGetStackTracesAction extends TransportAction<GetStackTrace
                         stackTraceEvents.put(stackTraceID, event);
                     }
                     event.count += count;
-                    if (request.getAggregationField() != null) {
-                        Terms eventSubGroup = stacktraceBucket.getAggregations().get(CUSTOM_EVENT_SUB_AGGREGATION_NAME);
-                        for (Terms.Bucket b : eventSubGroup.getBuckets()) {
-                            String subGroupName = b.getKeyAsString();
-                            long subGroupCount = event.subGroups.getOrDefault(subGroupName, 0L);
-                            event.subGroups.put(subGroupName, subGroupCount + b.getDocCount());
-                        }
-                    }
+                    subGroups.collectResults(stacktraceBucket, event);
                 }
                 responseBuilder.setTotalSamples(totalSamples);
                 responseBuilder.setHostEventCounts(hostEventCounts);
@@ -340,17 +325,11 @@ public class TransportGetStackTracesAction extends TransportAction<GetStackTrace
             // Especially with high cardinality fields, this makes aggregations really slow.
             .executionHint("map")
             .subAggregation(new SumAggregationBuilder("count").field("Stacktrace.count"));
-        if (request.getAggregationField() != null) {
-            String aggregationField = request.getAggregationField();
-            log.trace("Grouping stacktrace events by [{}].", aggregationField);
-            // be strict about the accepted field names to avoid downstream errors or leaking unintended information
-            if (aggregationField.equals("service.name") == false) {
-                throw new IllegalArgumentException(
-                    "Requested custom event aggregation field [" + aggregationField + "] but only [service.name] is supported."
-                );
-            }
-            groupByStackTraceId.subAggregation(new TermsAggregationBuilder(CUSTOM_EVENT_SUB_AGGREGATION_NAME).field(aggregationField));
-        }
+        SubGroupCollector subGroups = SubGroupCollector.attach(
+            groupByStackTraceId,
+            request.getAggregationFields(),
+            request.isLegacyAggregationField()
+        );
         client.prepareSearch(eventsIndex.getName())
             .setTrackTotalHits(false)
             .setSize(0)
@@ -412,14 +391,7 @@ public class TransportGetStackTracesAction extends TransportAction<GetStackTrace
                             stackTraceEvents.put(stackTraceID, event);
                         }
                         event.count += finalCount;
-                        if (request.getAggregationField() != null) {
-                            Terms subGroup = stacktraceBucket.getAggregations().get(CUSTOM_EVENT_SUB_AGGREGATION_NAME);
-                            for (Terms.Bucket b : subGroup.getBuckets()) {
-                                String subGroupName = b.getKeyAsString();
-                                long subGroupCount = event.subGroups.getOrDefault(subGroupName, 0L);
-                                event.subGroups.put(subGroupName, subGroupCount + b.getDocCount());
-                            }
-                        }
+                        subGroups.collectResults(stacktraceBucket, event);
                     }
                 }
                 responseBuilder.setTotalSamples(totalFinalCount);

+ 204 - 0
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetStackTracesRequestTests.java

@@ -51,6 +51,8 @@ public class GetStackTracesRequestTests extends ESTestCase {
             // Expect the default values
             assertNull(request.getIndices());
             assertNull(request.getStackTraceIdsField());
+            assertFalse(request.isLegacyAggregationField());
+            assertNull(request.getAggregationFields());
             assertNull(request.getAwsCostFactor());
             assertNull(request.getAzureCostFactor());
             assertNull(request.getCustomCO2PerKWH());
@@ -90,6 +92,92 @@ public class GetStackTracesRequestTests extends ESTestCase {
 
             // Expect the default values
             assertNull(request.getRequestedDuration());
+            assertFalse(request.isLegacyAggregationField());
+            assertNull(request.getAggregationFields());
+            assertNull(request.getAwsCostFactor());
+            assertNull(request.getAzureCostFactor());
+            assertNull(request.getCustomCO2PerKWH());
+            assertNull(request.getCustomDatacenterPUE());
+            assertNull(request.getCustomCostPerCoreHour());
+            assertNull(request.getCustomPerCoreWattX86());
+            assertNull(request.getCustomPerCoreWattARM64());
+        }
+    }
+
+    public void testParseValidXContentWithOneAggregationField() throws IOException {
+        try (XContentParser content = createParser(XContentFactory.jsonBuilder()
+        //tag::noformat
+            .startObject()
+                .field("sample_size", 2000)
+                .field("indices", new String[] {"my-traces"})
+                .field("stacktrace_ids_field", "stacktraces")
+                .field("aggregation_field", "service")
+                .startObject("query")
+                    .startObject("range")
+                        .startObject("@timestamp")
+                            .field("gte", "2022-10-05")
+                        .endObject()
+                    .endObject()
+                .endObject()
+            .endObject()
+        //end::noformat
+        )) {
+
+            GetStackTracesRequest request = new GetStackTracesRequest();
+            request.parseXContent(content);
+
+            assertEquals(2000, request.getSampleSize());
+            assertArrayEquals(new String[] { "my-traces" }, request.getIndices());
+            assertEquals("stacktraces", request.getStackTraceIdsField());
+            assertArrayEquals(new String[] { "service" }, request.getAggregationFields());
+            assertTrue(request.isLegacyAggregationField());
+            // a basic check suffices here
+            assertEquals("@timestamp", ((RangeQueryBuilder) request.getQuery()).fieldName());
+
+            // Expect the default values
+            assertNull(request.getRequestedDuration());
+            assertNull(request.getAwsCostFactor());
+            assertNull(request.getAzureCostFactor());
+            assertNull(request.getCustomCO2PerKWH());
+            assertNull(request.getCustomDatacenterPUE());
+            assertNull(request.getCustomCostPerCoreHour());
+            assertNull(request.getCustomPerCoreWattX86());
+            assertNull(request.getCustomPerCoreWattARM64());
+        }
+    }
+
+    public void testParseValidXContentWithMultipleAggregationFields() throws IOException {
+        try (XContentParser content = createParser(XContentFactory.jsonBuilder()
+        //tag::noformat
+            .startObject()
+                .field("sample_size", 2000)
+                .field("indices", new String[] {"my-traces"})
+                .field("stacktrace_ids_field", "stacktraces")
+                .field("aggregation_fields", new String[] {"service", "transaction"})
+                .startObject("query")
+                    .startObject("range")
+                        .startObject("@timestamp")
+                            .field("gte", "2022-10-05")
+                        .endObject()
+                    .endObject()
+                .endObject()
+            .endObject()
+        //end::noformat
+        )) {
+
+            GetStackTracesRequest request = new GetStackTracesRequest();
+            request.parseXContent(content);
+
+            assertEquals(2000, request.getSampleSize());
+            assertArrayEquals(new String[] { "my-traces" }, request.getIndices());
+            assertEquals("stacktraces", request.getStackTraceIdsField());
+            assertArrayEquals(new String[] { "service", "transaction" }, request.getAggregationFields());
+            // a basic check suffices here
+            assertEquals("@timestamp", ((RangeQueryBuilder) request.getQuery()).fieldName());
+
+            // Expect the default values
+            assertNull(request.getRequestedDuration());
+            assertFalse(request.isLegacyAggregationField());
             assertNull(request.getAwsCostFactor());
             assertNull(request.getAzureCostFactor());
             assertNull(request.getCustomCO2PerKWH());
@@ -143,6 +231,8 @@ public class GetStackTracesRequestTests extends ESTestCase {
             // Expect the default values
             assertNull(request.getIndices());
             assertNull(request.getStackTraceIdsField());
+            assertFalse(request.isLegacyAggregationField());
+            assertNull(request.getAggregationFields());
         }
     }
 
@@ -255,6 +345,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         List<String> validationErrors = request.validate().validationErrors();
@@ -276,6 +367,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         assertNull("Expecting no validation errors", request.validate());
@@ -295,6 +387,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         List<String> validationErrors = request.validate().validationErrors();
@@ -316,6 +409,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         List<String> validationErrors = request.validate().validationErrors();
@@ -323,6 +417,114 @@ public class GetStackTracesRequestTests extends ESTestCase {
         assertEquals("[stacktrace_ids_field] is mandatory", validationErrors.get(0));
     }
 
+    public void testValidateEmptyAggregationField() {
+        GetStackTracesRequest request = new GetStackTracesRequest(
+            null,
+            1.0d,
+            1.0d,
+            1.0d,
+            null,
+            new String[] { randomAlphaOfLength(5) },
+            randomAlphaOfLength(5),
+            "",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        List<String> validationErrors = request.validate().validationErrors();
+        assertEquals(1, validationErrors.size());
+        assertEquals("[aggregation_field] must be non-empty", validationErrors.get(0));
+    }
+
+    public void testValidateAggregationFieldAndAggregationFields() {
+        GetStackTracesRequest request = new GetStackTracesRequest(
+            null,
+            1.0d,
+            1.0d,
+            1.0d,
+            null,
+            new String[] { randomAlphaOfLength(5) },
+            randomAlphaOfLength(5),
+            "transaction.name",
+            new String[] { "transaction.name", "service.name" },
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        List<String> validationErrors = request.validate().validationErrors();
+        assertEquals(1, validationErrors.size());
+        assertEquals("[aggregation_field] must not be set when [aggregation_fields] is also set", validationErrors.get(0));
+    }
+
+    public void testValidateAggregationFieldsContainsTooFewElements() {
+        GetStackTracesRequest request = new GetStackTracesRequest(
+            null,
+            1.0d,
+            1.0d,
+            1.0d,
+            null,
+            new String[] { randomAlphaOfLength(5) },
+            randomAlphaOfLength(5),
+            null,
+            new String[] {},
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        List<String> validationErrors = request.validate().validationErrors();
+        assertEquals(1, validationErrors.size());
+        assertEquals("[aggregation_fields] must contain either one or two elements but contains [0] elements.", validationErrors.get(0));
+    }
+
+    public void testValidateAggregationFieldsContainsTooManyElements() {
+        GetStackTracesRequest request = new GetStackTracesRequest(
+            null,
+            1.0d,
+            1.0d,
+            1.0d,
+            null,
+            new String[] { randomAlphaOfLength(5) },
+            randomAlphaOfLength(5),
+            null,
+            new String[] { "application", "service", "transaction" },
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        List<String> validationErrors = request.validate().validationErrors();
+        assertEquals(1, validationErrors.size());
+        assertEquals("[aggregation_fields] must contain either one or two elements but contains [3] elements.", validationErrors.get(0));
+    }
+
+    public void testValidateAggregationFieldsContainsEnoughElements() {
+        GetStackTracesRequest request = new GetStackTracesRequest(
+            null,
+            1.0d,
+            1.0d,
+            1.0d,
+            null,
+            new String[] { randomAlphaOfLength(5) },
+            randomAlphaOfLength(5),
+            null,
+            new String[] { "service", "service" },
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+        assertNull("Expecting no validation errors", request.validate());
+    }
+
     public void testConsidersCustomIndicesInRelatedIndices() {
         String customIndex = randomAlphaOfLength(5);
         GetStackTracesRequest request = new GetStackTracesRequest(
@@ -338,6 +540,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         String[] indices = request.indices();
@@ -359,6 +562,7 @@ public class GetStackTracesRequestTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         String[] indices = request.indices();

+ 8 - 3
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/GetTopNFunctionsResponseTests.java

@@ -16,7 +16,6 @@ import org.elasticsearch.xcontent.XContentType;
 
 import java.io.IOException;
 import java.util.List;
-import java.util.Map;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 
@@ -56,7 +55,13 @@ public class GetTopNFunctionsResponseTests extends ESTestCase {
                             .field("line_number", sourceLine)
                             .field("executable_file_name", exeFilename)
                         .endObject()
-                        .field("sub_groups", Map.of("basket", 7L))
+                        .startObject("sub_groups")
+                            .startObject("transaction.name")
+                                .startObject("basket")
+                                    .field("count", 7L)
+                                .endObject()
+                            .endObject()
+                        .endObject()
                         .field("self_count", 1)
                         .field("total_count", 10)
                         .field("self_annual_co2_tons").rawValue("2.2000")
@@ -85,7 +90,7 @@ public class GetTopNFunctionsResponseTests extends ESTestCase {
             22.0d,
             12.0d,
             120.0d,
-            Map.of("basket", 7L)
+            SubGroup.root("transaction.name", false).addCount("basket", 7L)
         );
         GetTopNFunctionsResponse response = new GetTopNFunctionsResponse(1, 10, 2.2d, 12.0d, List.of(topNFunction));
         response.toXContent(actualResponse, ToXContent.EMPTY_PARAMS);

+ 5 - 0
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/ResamplerTests.java

@@ -43,6 +43,7 @@ public class ResamplerTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(false);
@@ -72,6 +73,7 @@ public class ResamplerTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);
@@ -101,6 +103,7 @@ public class ResamplerTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(false);
@@ -133,6 +136,7 @@ public class ResamplerTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
 
@@ -162,6 +166,7 @@ public class ResamplerTests extends ESTestCase {
             null,
             null,
             null,
+            null,
             null
         );
         request.setAdjustSampleCount(true);

+ 149 - 0
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/SubGroupCollectorTests.java

@@ -0,0 +1,149 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.profiling.action;
+
+import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.profiling.action.SubGroupCollector.CUSTOM_EVENT_SUB_AGGREGATION_NAME;
+
+public class SubGroupCollectorTests extends ESTestCase {
+    public void testNoAggs() {
+        TermsAggregationBuilder stackTraces = new TermsAggregationBuilder("stacktraces").field("stacktrace.id");
+        TraceEvent traceEvent = new TraceEvent("1");
+
+        SubGroupCollector collector = SubGroupCollector.attach(stackTraces, new String[0], false);
+        assertTrue("Sub aggregations attached", stackTraces.getSubAggregations().isEmpty());
+
+        SubGroupCollector.Bucket currentStackTrace = bucket("1", 5);
+        collector.collectResults(currentStackTrace, traceEvent);
+
+        assertNull(traceEvent.subGroups);
+    }
+
+    public void testMultipleAggsInSingleStackTrace() {
+        TermsAggregationBuilder stackTraces = new TermsAggregationBuilder("stacktraces").field("stacktrace.id");
+        TraceEvent traceEvent = new TraceEvent("1");
+
+        SubGroupCollector collector = SubGroupCollector.attach(stackTraces, new String[] { "service.name", "transaction.name" }, false);
+        assertFalse("No sub aggregations attached", stackTraces.getSubAggregations().isEmpty());
+
+        StaticAgg services = new StaticAgg();
+        SubGroupCollector.Bucket currentStackTrace = bucket("1", 5, services);
+        // tag::noformat
+        services.addBuckets(CUSTOM_EVENT_SUB_AGGREGATION_NAME + "service.name",
+            bucket("basket", 7L,
+                agg(CUSTOM_EVENT_SUB_AGGREGATION_NAME + "transaction.name",
+                    bucket("add-to-basket", 4L),
+                    bucket("delete-from-basket", 3L)
+                )
+            ),
+            bucket("checkout", 4L,
+                agg(CUSTOM_EVENT_SUB_AGGREGATION_NAME + "transaction.name",
+                    bucket("enter-address", 4L),
+                    bucket("submit-order", 3L)
+                )
+            )
+        );
+        // end::noformat
+
+        collector.collectResults(currentStackTrace, traceEvent);
+
+        assertNotNull(traceEvent.subGroups);
+        assertEquals(Long.valueOf(7L), traceEvent.subGroups.getCount("basket"));
+        assertEquals(Long.valueOf(4L), traceEvent.subGroups.getCount("checkout"));
+        SubGroup basketTransactionNames = traceEvent.subGroups.getSubGroup("basket").getSubGroup("transaction.name");
+        assertEquals(Long.valueOf(4L), basketTransactionNames.getCount("add-to-basket"));
+        assertEquals(Long.valueOf(3L), basketTransactionNames.getCount("delete-from-basket"));
+        SubGroup checkoutTransactionNames = traceEvent.subGroups.getSubGroup("checkout").getSubGroup("transaction.name");
+        assertEquals(Long.valueOf(4L), checkoutTransactionNames.getCount("enter-address"));
+        assertEquals(Long.valueOf(3L), checkoutTransactionNames.getCount("submit-order"));
+    }
+
+    public void testSingleAggInMultipleStackTraces() {
+        TermsAggregationBuilder stackTraces = new TermsAggregationBuilder("stacktraces").field("stacktrace.id");
+        TraceEvent traceEvent = new TraceEvent("1");
+
+        SubGroupCollector collector = SubGroupCollector.attach(stackTraces, new String[] { "service.name" }, false);
+        assertFalse("No sub aggregations attached", stackTraces.getSubAggregations().isEmpty());
+
+        StaticAgg services1 = new StaticAgg();
+        SubGroupCollector.Bucket currentStackTrace1 = bucket("1", 5, services1);
+        services1.addBuckets(CUSTOM_EVENT_SUB_AGGREGATION_NAME + "service.name", bucket("basket", 7L));
+
+        collector.collectResults(currentStackTrace1, traceEvent);
+
+        StaticAgg services2 = new StaticAgg();
+        SubGroupCollector.Bucket currentStackTrace2 = bucket("1", 3, services2);
+        services2.addBuckets(CUSTOM_EVENT_SUB_AGGREGATION_NAME + "service.name", bucket("basket", 1L), bucket("checkout", 5L));
+
+        collector.collectResults(currentStackTrace2, traceEvent);
+
+        assertNotNull(traceEvent.subGroups);
+        assertEquals(Long.valueOf(8L), traceEvent.subGroups.getCount("basket"));
+        assertEquals(Long.valueOf(5L), traceEvent.subGroups.getCount("checkout"));
+    }
+
+    private SubGroupCollector.Bucket bucket(String key, long count) {
+        return bucket(key, count, null);
+    }
+
+    private SubGroupCollector.Bucket bucket(String key, long count, SubGroupCollector.Agg aggregations) {
+        return new StaticBucket(key, count, aggregations);
+    }
+
+    private SubGroupCollector.Agg agg(String name, SubGroupCollector.Bucket... buckets) {
+        StaticAgg a = new StaticAgg();
+        a.addBuckets(name, buckets);
+        return a;
+    }
+
+    private static class StaticBucket implements SubGroupCollector.Bucket {
+        private final String key;
+        private final long count;
+        private SubGroupCollector.Agg aggregations;
+
+        private StaticBucket(String key, long count, SubGroupCollector.Agg aggregations) {
+            this.key = key;
+            this.count = count;
+            this.aggregations = aggregations;
+        }
+
+        @Override
+        public String getKey() {
+            return key;
+        }
+
+        @Override
+        public long getCount() {
+            return count;
+        }
+
+        @Override
+        public SubGroupCollector.Agg getAggregations() {
+            return aggregations;
+        }
+    }
+
+    private static class StaticAgg implements SubGroupCollector.Agg {
+        private final Map<String, List<SubGroupCollector.Bucket>> buckets = new HashMap<>();
+
+        public void addBuckets(String name, SubGroupCollector.Bucket... buckets) {
+            this.buckets.put(name, List.of(buckets));
+        }
+
+        @Override
+        public Iterable<SubGroupCollector.Bucket> getBuckets(String aggName) {
+            return buckets.get(aggName);
+        }
+    }
+}

+ 102 - 0
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/SubGroupTests.java

@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.profiling.action;
+
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
+
+public class SubGroupTests extends ESTestCase {
+    public void testToXContent() throws IOException {
+        XContentType contentType = randomFrom(XContentType.values());
+        // tag::noformat
+        XContentBuilder expectedRequest = XContentFactory.contentBuilder(contentType)
+            .startObject()
+                .startObject("transaction.name")
+                    .startObject("basket")
+                        .field("count", 7L)
+                    .endObject()
+                .endObject()
+            .endObject();
+        // end::noformat
+
+        XContentBuilder actualRequest = XContentFactory.contentBuilder(contentType);
+        actualRequest.startObject();
+        SubGroup g = SubGroup.root("transaction.name", false).addCount("basket", 7L);
+        g.toXContent(actualRequest, ToXContent.EMPTY_PARAMS);
+        actualRequest.endObject();
+
+        assertToXContentEquivalent(BytesReference.bytes(expectedRequest), BytesReference.bytes(actualRequest), contentType);
+    }
+
+    public void testRenderLegacyXContent() throws IOException {
+        XContentType contentType = randomFrom(XContentType.values());
+        // tag::noformat
+        XContentBuilder expectedRequest = XContentFactory.contentBuilder(contentType)
+            .startObject()
+                .field("basket", 7L)
+            .endObject();
+        // end::noformat
+
+        XContentBuilder actualRequest = XContentFactory.contentBuilder(contentType);
+        actualRequest.startObject();
+        SubGroup g = SubGroup.root("transaction.name", true).addCount("basket", 7L);
+        g.toXContent(actualRequest, ToXContent.EMPTY_PARAMS);
+        actualRequest.endObject();
+
+        assertToXContentEquivalent(BytesReference.bytes(expectedRequest), BytesReference.bytes(actualRequest), contentType);
+    }
+
+    public void testMergeNoCommonRoot() {
+        SubGroup root1 = SubGroup.root("transaction.name", false);
+        SubGroup root2 = SubGroup.root("service.name", false);
+
+        SubGroup toMerge = root1.copy();
+
+        toMerge.merge(root2);
+
+        assertEquals(root1, toMerge);
+    }
+
+    public void testMergeIdenticalTree() {
+        SubGroup g = SubGroup.root("transaction.name", false);
+        g.addCount("basket", 5L);
+        g.addCount("checkout", 7L);
+
+        SubGroup g2 = g.copy();
+
+        g.merge(g2);
+
+        assertEquals(Long.valueOf(10L), g.getCount("basket"));
+        assertEquals(Long.valueOf(14L), g.getCount("checkout"));
+    }
+
+    public void testMergeMixedTree() {
+        SubGroup g1 = SubGroup.root("transaction.name", false);
+        g1.addCount("basket", 5L);
+        g1.addCount("checkout", 7L);
+
+        SubGroup g2 = SubGroup.root("transaction.name", false);
+        g2.addCount("catalog", 8L);
+        g2.addCount("basket", 5L);
+        g2.addCount("checkout", 2L);
+
+        g1.merge(g2);
+
+        assertEquals(Long.valueOf(8L), g1.getCount("catalog"));
+        assertEquals(Long.valueOf(10L), g1.getCount("basket"));
+        assertEquals(Long.valueOf(9L), g1.getCount("checkout"));
+    }
+}

+ 29 - 26
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/TopNFunctionTests.java

@@ -16,7 +16,6 @@ import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 
 import java.io.IOException;
-import java.util.Map;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 
@@ -35,31 +34,35 @@ public class TopNFunctionTests extends ESTestCase {
         String frameGroupID = FrameGroupID.create(fileID, addressOrLine, exeFilename, sourceFilename, functionName);
 
         XContentType contentType = randomFrom(XContentType.values());
+        // tag::noformat
         XContentBuilder expectedRequest = XContentFactory.contentBuilder(contentType)
             .startObject()
-            .field("id", frameGroupID)
-            .field("rank", 1)
-            .startObject("frame")
-            .field("frame_type", frameType)
-            .field("inline", inline)
-            .field("address_or_line", addressOrLine)
-            .field("function_name", functionName)
-            .field("file_name", sourceFilename)
-            .field("line_number", sourceLine)
-            .field("executable_file_name", exeFilename)
-            .endObject()
-            .field("sub_groups", Map.of("basket", 7L))
-            .field("self_count", 1)
-            .field("total_count", 10)
-            .field("self_annual_co2_tons")
-            .rawValue("2.2000")
-            .field("total_annual_co2_tons")
-            .rawValue("22.0000")
-            .field("self_annual_costs_usd")
-            .rawValue("12.0000")
-            .field("total_annual_costs_usd")
-            .rawValue("120.0000")
+                .field("id", frameGroupID)
+                .field("rank", 1)
+                .startObject("frame")
+                    .field("frame_type", frameType)
+                    .field("inline", inline)
+                    .field("address_or_line", addressOrLine)
+                    .field("function_name", functionName)
+                    .field("file_name", sourceFilename)
+                    .field("line_number", sourceLine)
+                    .field("executable_file_name", exeFilename)
+                .endObject()
+                .startObject("sub_groups")
+                    .startObject("transaction.name")
+                        .startObject("basket")
+                            .field("count", 7L)
+                        .endObject()
+                    .endObject()
+                .endObject()
+                .field("self_count", 1)
+                .field("total_count", 10)
+                .field("self_annual_co2_tons").rawValue("2.2000")
+                .field("total_annual_co2_tons").rawValue("22.0000")
+                .field("self_annual_costs_usd").rawValue("12.0000")
+                .field("total_annual_costs_usd").rawValue("120.0000")
             .endObject();
+        // end::noformat
 
         XContentBuilder actualRequest = XContentFactory.contentBuilder(contentType);
         TopNFunction topNFunction = new TopNFunction(
@@ -78,7 +81,7 @@ public class TopNFunctionTests extends ESTestCase {
             22.0d,
             12.0d,
             120.0d,
-            Map.of("basket", 7L)
+            SubGroup.root("transaction.name", false).addCount("basket", 7L)
         );
         topNFunction.toXContent(actualRequest, ToXContent.EMPTY_PARAMS);
 
@@ -113,8 +116,8 @@ public class TopNFunctionTests extends ESTestCase {
             4.0d,
             23.2d,
             12.0d,
-            Map.of("checkout", 4L, "basket", 12L)
+            SubGroup.root("transaction.name", false).addCount("checkout", 4L).addCount("basket", 12L)
         );
-        EqualsHashCodeTestUtils.checkEqualsAndHashCode(topNFunction, (TopNFunction::clone));
+        EqualsHashCodeTestUtils.checkEqualsAndHashCode(topNFunction, (TopNFunction::copy));
     }
 }

+ 1 - 2
x-pack/plugin/profiling/src/test/java/org/elasticsearch/xpack/profiling/action/TransportGetTopNFunctionsActionTests.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.profiling.action;
 
 import org.elasticsearch.test.ESTestCase;
 
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
@@ -165,7 +164,7 @@ public class TransportGetTopNFunctionsActionTests extends ESTestCase {
             annualCO2TonsInclusive,
             annualCostsUSDExclusive,
             annualCostsUSDInclusive,
-            Collections.emptyMap()
+            null
         );
     }