Răsfoiți Sursa

[ML] add the ability to include and exclude values in Frequent items (#92414)

This PR adds `include` and `exclude` parameters to the frequent items aggregation, which makes it possible to filter values out of the analysis.
Hendrik Muhs 2 ani în urmă
părinte
comite
b9c0315d24

+ 5 - 0
docs/changelog/92414.yaml

@@ -0,0 +1,5 @@
+pr: 92414
+summary: Add the ability to include and exclude values in Frequent items
+area: Machine Learning
+type: enhancement
+issues: []

+ 6 - 0
docs/reference/aggregations/bucket/frequent-items-aggregation.asciidoc

@@ -66,6 +66,12 @@ fields.
 If the combined cardinality of the analyzed fields are high, then the 
 aggregation might require a significant amount of system resources.
 
+It is possible to filter the values for each field by using the `include` and
+`exclude` parameters, which accept either a regular expression string or an
+array of exact terms. This functionality mirrors the features described in the
+<<search-aggregations-bucket-terms-aggregation,terms aggregation>> documentation.
+Values that are filtered out are removed from the analysis, which reduces the runtime.
+
 [discrete]
 [[frequent-items-minimum-set-size]]
 ==== Minimum set size

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsAggregationBuilder.java

@@ -65,7 +65,7 @@ public final class FrequentItemSetsAggregationBuilder extends AbstractAggregatio
             false,  // timezone aware
             false,  // filtered (not defined per field, but for all fields below)
             false,  // format
-            false   // includes and excludes
+            true    // includes and excludes
         );
         PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, n) -> fieldsParser.parse(p, null).build(), FIELDS);
         PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SUPPORT);

+ 23 - 16
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsAggregatorFactory.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.aggs.frequentitemsets;
 
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.AggregationExecutionException;
@@ -15,6 +16,7 @@ import org.elasticsearch.search.aggregations.Aggregator;
 import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
 import org.elasticsearch.search.aggregations.AggregatorFactory;
 import org.elasticsearch.search.aggregations.CardinalityUpperBound;
+import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
 import org.elasticsearch.search.aggregations.support.AggregationContext;
 import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
@@ -27,6 +29,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.core.Tuple.tuple;
+
 /**
  * Factory for frequent items aggregation
  *
@@ -60,7 +64,7 @@ public class FrequentItemSetsAggregatorFactory extends AggregatorFactory {
     private final double minimumSupport;
     private final int minimumSetSize;
     private final int size;
-    private final QueryBuilder filter;
+    private final QueryBuilder documentFilter;
 
     public FrequentItemSetsAggregatorFactory(
         String name,
@@ -72,32 +76,35 @@ public class FrequentItemSetsAggregatorFactory extends AggregatorFactory {
         double minimumSupport,
         int minimumSetSize,
         int size,
-        QueryBuilder filter
+        QueryBuilder documentFilter
     ) throws IOException {
         super(name, context, parent, subFactoriesBuilder, metadata);
         this.fields = fields;
         this.minimumSupport = minimumSupport;
         this.minimumSetSize = minimumSetSize;
         this.size = size;
-        this.filter = filter;
+        this.documentFilter = documentFilter;
     }
 
     @Override
     protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
         throws IOException {
 
-        List<ValuesSourceConfig> configs = new ArrayList<>(fields.size());
+        List<Tuple<ValuesSourceConfig, IncludeExclude>> configsAndFilters = new ArrayList<>(fields.size());
         for (MultiValuesSourceFieldConfig field : fields) {
-            configs.add(
-                ValuesSourceConfig.resolve(
-                    context,
-                    field.getUserValueTypeHint(),
-                    field.getFieldName(),
-                    field.getScript(),
-                    field.getMissing(),
-                    field.getTimeZone(),
-                    field.getFormat(),
-                    CoreValuesSourceType.KEYWORD
+            configsAndFilters.add(
+                tuple(
+                    ValuesSourceConfig.resolve(
+                        context,
+                        field.getUserValueTypeHint(),
+                        field.getFieldName(),
+                        field.getScript(),
+                        field.getMissing(),
+                        field.getTimeZone(),
+                        field.getFormat(),
+                        CoreValuesSourceType.KEYWORD
+                    ),
+                    field.getIncludeExclude()
                 )
             );
         }
@@ -113,8 +120,8 @@ public class FrequentItemSetsAggregatorFactory extends AggregatorFactory {
                 parent,
                 metadata,
                 new EclatMapReducer(FrequentItemSetsAggregationBuilder.NAME, minimumSupport, minimumSetSize, size, context.profiling()),
-                configs,
-                filter
+                configsAndFilters,
+                documentFilter
             ) {
         };
     }

+ 13 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/ItemSetMapReduceAggregator.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.common.util.LongObjectPagedHashMap;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.AggregationExecutionContext;
 import org.elasticsearch.search.aggregations.Aggregator;
@@ -26,6 +27,7 @@ import org.elasticsearch.search.aggregations.CardinalityUpperBound;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.LeafBucketCollector;
 import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
+import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
 import org.elasticsearch.search.aggregations.support.AggregationContext;
 import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
 import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
@@ -47,7 +49,7 @@ public abstract class ItemSetMapReduceAggregator<
     Result extends ToXContent & Writeable> extends AggregatorBase {
 
     private final List<ItemSetMapReduceValueSource> extractors;
-    private final Weight weightFilter;
+    private final Weight weightDocumentFilter;
     private final List<Field> fields;
     private final AbstractItemSetMapReducer<MapContext, MapFinalContext, ReduceContext, Result> mapReducer;
     private final BigArrays bigArraysForMapReduce;
@@ -62,8 +64,8 @@ public abstract class ItemSetMapReduceAggregator<
         Aggregator parent,
         Map<String, Object> metadata,
         AbstractItemSetMapReducer<MapContext, MapFinalContext, ReduceContext, Result> mapReducer,
-        List<ValuesSourceConfig> configs,
-        QueryBuilder filter
+        List<Tuple<ValuesSourceConfig, IncludeExclude>> configsAndValueFilters,
+        QueryBuilder documentFilter
     ) throws IOException {
         super(name, AggregatorFactories.EMPTY, context, parent, CardinalityUpperBound.NONE, metadata);
 
@@ -72,12 +74,14 @@ public abstract class ItemSetMapReduceAggregator<
         IndexSearcher contextSearcher = context.searcher();
 
         int id = 0;
-        this.weightFilter = filter != null
-            ? contextSearcher.createWeight(contextSearcher.rewrite(context.buildQuery(filter)), ScoreMode.COMPLETE_NO_SCORES, 1f)
+        this.weightDocumentFilter = documentFilter != null
+            ? contextSearcher.createWeight(contextSearcher.rewrite(context.buildQuery(documentFilter)), ScoreMode.COMPLETE_NO_SCORES, 1f)
             : null;
 
-        for (ValuesSourceConfig c : configs) {
-            ItemSetMapReduceValueSource e = context.getValuesSourceRegistry().getAggregator(registryKey, c).build(c, id++);
+        for (var c : configsAndValueFilters) {
+            ItemSetMapReduceValueSource e = context.getValuesSourceRegistry()
+                .getAggregator(registryKey, c.v1())
+                .build(c.v1(), id++, c.v2());
             if (e.getField().getName() != null) {
                 fields.add(e.getField());
                 extractors.add(e);
@@ -115,10 +119,10 @@ public abstract class ItemSetMapReduceAggregator<
     @Override
     protected LeafBucketCollector getLeafCollector(AggregationExecutionContext ctx, LeafBucketCollector sub) throws IOException {
 
-        final Bits bits = weightFilter != null
+        final Bits bits = weightDocumentFilter != null
             ? Lucene.asSequentialAccessBits(
                 ctx.getLeafReaderContext().reader().maxDoc(),
-                weightFilter.scorerSupplier(ctx.getLeafReaderContext())
+                weightDocumentFilter.scorerSupplier(ctx.getLeafReaderContext())
             )
             : null;
 

+ 18 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/mr/ItemSetMapReduceValueSource.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
 import org.elasticsearch.search.DocValueFormat;
+import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
 import org.elasticsearch.search.aggregations.support.ValuesSource;
 import org.elasticsearch.search.aggregations.support.ValuesSource.Bytes;
 import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
@@ -35,7 +36,7 @@ public abstract class ItemSetMapReduceValueSource {
 
     @FunctionalInterface
     public interface ValueSourceSupplier {
-        ItemSetMapReduceValueSource build(ValuesSourceConfig config, int id);
+        ItemSetMapReduceValueSource build(ValuesSourceConfig config, int id, IncludeExclude includeExclude);
     }
 
     enum ValueFormatter {
@@ -120,7 +121,7 @@ public abstract class ItemSetMapReduceValueSource {
             return Objects.hash(id, valueFormatter, name, format);
         }
 
-    };
+    }
 
     private final Field field;
 
@@ -128,7 +129,6 @@ public abstract class ItemSetMapReduceValueSource {
 
     ItemSetMapReduceValueSource(ValuesSourceConfig config, int id, ValueFormatter valueFormatter) {
         String fieldName = config.fieldContext() != null ? config.fieldContext().field() : null;
-
         if (Strings.isNullOrEmpty(fieldName)) {
             throw new IllegalArgumentException("scripts are not supported");
         }
@@ -142,10 +142,12 @@ public abstract class ItemSetMapReduceValueSource {
 
     public static class KeywordValueSource extends ItemSetMapReduceValueSource {
         private final ValuesSource.Bytes source;
+        private final IncludeExclude.StringFilter stringFilter;
 
-        public KeywordValueSource(ValuesSourceConfig config, int id) {
+        public KeywordValueSource(ValuesSourceConfig config, int id, IncludeExclude includeExclude) {
             super(config, id, ValueFormatter.BYTES_REF);
             this.source = (Bytes) config.getValuesSource();
+            this.stringFilter = includeExclude == null ? null : includeExclude.convertToStringFilter(config.format());
         }
 
         @Override
@@ -157,20 +159,26 @@ public abstract class ItemSetMapReduceValueSource {
                 List<Object> objects = new ArrayList<>(valuesCount);
 
                 for (int i = 0; i < valuesCount; ++i) {
-                    objects.add(BytesRef.deepCopyOf(values.nextValue()));
+                    BytesRef v = values.nextValue();
+                    if (stringFilter == null || stringFilter.accept(v)) {
+                        objects.add(BytesRef.deepCopyOf(v));
+                    }
                 }
                 return new Tuple<>(getField(), objects);
             }
             return new Tuple<>(getField(), Collections.emptyList());
         }
+
     }
 
     public static class NumericValueSource extends ItemSetMapReduceValueSource {
         private final ValuesSource.Numeric source;
+        private final IncludeExclude.LongFilter longFilter;
 
-        public NumericValueSource(ValuesSourceConfig config, int id) {
+        public NumericValueSource(ValuesSourceConfig config, int id, IncludeExclude includeExclude) {
             super(config, id, ValueFormatter.LONG);
             this.source = (Numeric) config.getValuesSource();
+            this.longFilter = includeExclude == null ? null : includeExclude.convertToLongFilter(config.format());
         }
 
         @Override
@@ -182,7 +190,10 @@ public abstract class ItemSetMapReduceValueSource {
                 List<Object> objects = new ArrayList<>(valuesCount);
 
                 for (int i = 0; i < valuesCount; ++i) {
-                    objects.add(values.nextValue());
+                    long v = values.nextValue();
+                    if (longFilter == null || longFilter.accept(v)) {
+                        objects.add(v);
+                    }
                 }
                 return new Tuple<>(getField(), objects);
             }

+ 31 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsAggregationBuilderTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.BaseAggregationBuilder;
+import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
 import org.elasticsearch.test.AbstractXContentSerializingTestCase;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -27,6 +28,7 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.TreeSet;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.Matchers.hasSize;
@@ -47,6 +49,10 @@ public class FrequentItemSetsAggregationBuilderTests extends AbstractXContentSer
                 field.setMissing(randomAlphaOfLength(5));
             }
 
+            if (randomBoolean()) {
+                field.setIncludeExclude(randomIncludeExclude());
+            }
+
             return field.build();
         }).collect(Collectors.toList());
 
@@ -186,4 +192,29 @@ public class FrequentItemSetsAggregationBuilderTests extends AbstractXContentSer
         assertEquals("Aggregator [fi] of type [frequent_items] cannot accept sub-aggregations", e.getMessage());
     }
 
+    private static IncludeExclude randomIncludeExclude() {
+        switch (randomInt(7)) {
+            case 0:
+                return new IncludeExclude("incl*de", null, null, null);
+            case 1:
+                return new IncludeExclude("incl*de", "excl*de", null, null);
+            case 2:
+                return new IncludeExclude("incl*de", null, null, new TreeSet<>(Set.of(newBytesRef("exclude"))));
+            case 3:
+                return new IncludeExclude(null, "excl*de", null, null);
+            case 4:
+                return new IncludeExclude(null, "excl*de", new TreeSet<>(Set.of(newBytesRef("include"))), null);
+            case 5:
+                return new IncludeExclude(null, null, new TreeSet<>(Set.of(newBytesRef("include"))), null);
+            case 6:
+                return new IncludeExclude(
+                    null,
+                    null,
+                    new TreeSet<>(Set.of(newBytesRef("include"))),
+                    new TreeSet<>(Set.of(newBytesRef("exclude")))
+                );
+            default:
+                return new IncludeExclude(null, null, null, new TreeSet<>(Set.of(newBytesRef("exclude"))));
+        }
+    }
 }

+ 138 - 12
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/frequentitemsets/FrequentItemSetsAggregatorTests.java

@@ -19,6 +19,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.network.InetAddresses;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.time.DateFormatter;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.index.mapper.DateFieldMapper;
 import org.elasticsearch.index.mapper.IpFieldMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
@@ -27,6 +28,7 @@ import org.elasticsearch.index.mapper.NumberFieldMapper;
 import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
+import org.elasticsearch.search.aggregations.bucket.terms.IncludeExclude;
 import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
 import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig;
 import org.elasticsearch.search.aggregations.support.ValuesSourceType;
@@ -39,11 +41,16 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.TreeSet;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.core.Tuple.tuple;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 
 public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
@@ -87,7 +94,14 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
     public void testKeywordsArray() throws IOException {
         List<MultiValuesSourceFieldConfig> fields = new ArrayList<>();
 
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD1).build());
+        String exclude = randomBoolean() ? randomFrom("item-3", "item-4", "item-5", "item-99") : null;
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD1)
+                .setIncludeExclude(
+                    exclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(exclude)))) : null
+                )
+                .build()
+        );
 
         double minimumSupport = randomDoubleBetween(0.13, 0.41, true);
         int minimumSetSize = randomIntBetween(2, 5);
@@ -205,19 +219,62 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
             );
         }, (InternalItemSetMapReduceAggregation<?, ?, ?, EclatResult> results) -> {
             assertNotNull(results);
-            assertResults(expectedResults, results.getMapReduceResult().getFrequentItemSets(), minimumSupport, minimumSetSize, size);
+            assertResults(
+                expectedResults,
+                results.getMapReduceResult().getFrequentItemSets(),
+                minimumSupport,
+                minimumSetSize,
+                size,
+                exclude,
+                null
+            );
         }, new AggTestConfig(builder, keywordType).withQuery(query));
     }
 
     public void testMixedSingleValues() throws IOException {
         List<MultiValuesSourceFieldConfig> fields = new ArrayList<>();
 
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD1).build());
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD2).build());
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD3).build());
+        String stringExclude = randomBoolean() ? randomFrom("host-2", "192.168.0.1", "client-2", "127.0.0.1") : null;
+        Integer intExclude = randomBoolean() ? randomIntBetween(0, 10) : null;
+
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD1)
+                .setIncludeExclude(
+                    stringExclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(stringExclude)))) : null
+                )
+                .build()
+        );
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD2)
+                .setIncludeExclude(
+                    stringExclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(stringExclude)))) : null
+                )
+                .build()
+        );
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD3)
+                .setIncludeExclude(
+                    stringExclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(stringExclude)))) : null
+                )
+                .build()
+        );
         fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(FLOAT_FIELD).build());
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(INT_FIELD).build());
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(IP_FIELD).build());
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(INT_FIELD)
+                .setIncludeExclude(
+                    intExclude != null
+                        ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(String.valueOf(intExclude)))))
+                        : null
+                )
+                .build()
+        );
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(IP_FIELD)
+                .setIncludeExclude(
+                    stringExclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(stringExclude)))) : null
+                )
+                .build()
+        );
 
         double minimumSupport = randomDoubleBetween(0.13, 0.51, true);
         int minimumSetSize = randomIntBetween(2, 6);
@@ -382,7 +439,15 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
             );
         }, (InternalItemSetMapReduceAggregation<?, ?, ?, EclatResult> results) -> {
             assertNotNull(results);
-            assertResults(expectedResults, results.getMapReduceResult().getFrequentItemSets(), minimumSupport, minimumSetSize, size);
+            assertResults(
+                expectedResults,
+                results.getMapReduceResult().getFrequentItemSets(),
+                minimumSupport,
+                minimumSetSize,
+                size,
+                stringExclude,
+                intExclude
+            );
         }, new AggTestConfig(builder, keywordType1, keywordType2, keywordType3, intType, floatType, ipType).withQuery(query));
 
     }
@@ -390,10 +455,18 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
     public void testSingleValueWithDate() throws IOException {
         List<MultiValuesSourceFieldConfig> fields = new ArrayList<>();
 
+        String dateExclude = randomBoolean() ? randomFrom("2022-06-02", "2022-06-03", "1970-01-01") : null;
+
         fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD1).build());
         fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD2).build());
         fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(KEYWORD_FIELD3).build());
-        fields.add(new MultiValuesSourceFieldConfig.Builder().setFieldName(DATE_FIELD).build());
+        fields.add(
+            new MultiValuesSourceFieldConfig.Builder().setFieldName(DATE_FIELD)
+                .setIncludeExclude(
+                    dateExclude != null ? new IncludeExclude(null, null, null, new TreeSet<>(Set.of(new BytesRef(dateExclude)))) : null
+                )
+                .build()
+        );
 
         double minimumSupport = randomDoubleBetween(0.13, 0.51, true);
         int minimumSetSize = randomIntBetween(2, 6);
@@ -571,7 +644,15 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
             );
         }, (InternalItemSetMapReduceAggregation<?, ?, ?, EclatResult> results) -> {
             assertNotNull(results);
-            assertResults(expectedResults, results.getMapReduceResult().getFrequentItemSets(), minimumSupport, minimumSetSize, size);
+            assertResults(
+                expectedResults,
+                results.getMapReduceResult().getFrequentItemSets(),
+                minimumSupport,
+                minimumSetSize,
+                size,
+                dateExclude,
+                null
+            );
         }, new AggTestConfig(builder, keywordType1, keywordType2, keywordType3, dateType, ipType).withQuery(query));
 
     }
@@ -594,11 +675,45 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
         );
     }
 
-    private void assertResults(List<FrequentItemSet> expected, FrequentItemSet[] actual, double minSupport, int minimumSetSize, int size) {
+    private void assertResults(
+        List<FrequentItemSet> expected,
+        FrequentItemSet[] actual,
+        double minSupport,
+        int minimumSetSize,
+        int size,
+        String stringExclude,
+        Integer intExclude
+    ) {
         // sort the expected results descending by doc count
         expected.get(0).getFields().values().stream().mapToLong(v -> v.stream().count()).sum();
 
-        List<FrequentItemSet> filteredExpected = expected.stream()
+        List<FrequentItemSet> filteredExpectedWithDups = expected.stream()
+            .map(
+                fi -> new FrequentItemSet(
+                    fi.getFields()
+                        .entrySet()
+                        .stream()
+                        .map(
+                            // filter the string exclude from the list of objects
+                            keyValues -> tuple(
+                                keyValues.getKey(),
+                                keyValues.getValue().stream().filter(v -> v.equals(stringExclude) == false).collect(Collectors.toList())
+                            )
+                        )
+                        .map(
+                            // filter the int exclude
+                            keyValues -> tuple(
+                                keyValues.v1(),
+                                keyValues.v2().stream().filter(v -> v.equals(intExclude) == false).collect(Collectors.toList())
+                            )
+                        )
+                        // after filtering out excludes the list of objects might be empty
+                        .filter(t -> t.v2().size() > 0)
+                        .collect(Collectors.toMap(Tuple::v1, Tuple::v2)),
+                    fi.getDocCount(),
+                    fi.getSupport()
+                )
+            )
             .filter(fi -> fi.getSupport() >= minSupport)
             .filter(
                 fi -> {
@@ -638,6 +753,17 @@ public class FrequentItemSetsAggregatorTests extends AggregatorTestCase {
             })
             .collect(Collectors.toList());
 
+        // after removing excluded items there might be duplicate entries; collapse them using a very simple hash of the fields
+        List<FrequentItemSet> filteredExpected = new ArrayList<>();
+        Set<Integer> valuesSeen = new HashSet<>();
+        for (FrequentItemSet fi : filteredExpectedWithDups) {
+            if (valuesSeen.add(
+                fi.getFields().entrySet().stream().mapToInt(v -> Objects.hash(v.getKey(), v.getValue())).reduce(13, (t, s) -> 41 * t + s)
+            )) {
+                filteredExpected.add(fi);
+            }
+        }
+
         // if size applies, cut the list, however if sets have the same number of items it's unclear which ones are returned
         int additionalSetsThatShareTheSameDocCount = 0;
         if (size < filteredExpected.size()) {

+ 62 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/frequent_items_agg.yml

@@ -368,6 +368,68 @@ setup:
   - match: { aggregations.fi.buckets.0.support: 0.4 }
   - match: { aggregations.fi.buckets.0.key.error_message: ["compressor low pressure"] }
 
+---
+"Test frequent items exclude":
+
+  - do:
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "fi": {
+                "frequent_items": {
+                  "minimum_set_size": 3,
+                  "minimum_support": 0.3,
+                  "fields": [
+                    {"field": "features"},
+                    {
+                      "field": "error_message",
+                      "exclude": "engine overheated"
+                    }
+                  ]
+                }
+              }
+            }
+          }
+  - length: { aggregations.fi.buckets: 3 }
+  - match: { aggregations.fi.buckets.0.doc_count: 5 }
+  - match: { aggregations.fi.buckets.0.support: 0.5 }
+  - match: { aggregations.fi.buckets.0.key.error_message: ["compressor low pressure"] }
+  - match: { aggregations.fi.buckets.1.doc_count: 3 }
+  - match: { aggregations.fi.buckets.1.support: 0.3 }
+
+---
+"Test frequent items include":
+
+  - do:
+      search:
+        index: store
+        body: >
+          {
+            "size": 0,
+            "aggs": {
+              "fi": {
+                "frequent_items": {
+                  "minimum_set_size": 3,
+                  "minimum_support": 0.3,
+                  "fields": [
+                    {"field": "features"},
+                    {
+                      "field": "error_message",
+                      "include": "en.*ed"
+                    }
+                  ]
+                }
+              }
+            }
+          }
+  - length: { aggregations.fi.buckets: 3 }
+  - match: { aggregations.fi.buckets.0.doc_count: 4 }
+  - match: { aggregations.fi.buckets.0.support: 0.4 }
+  - match: { aggregations.fi.buckets.0.key.error_message: ["engine overheated"] }
+
 ---
 "Test frequent items unsupported types":
   - do: