Browse Source

[ML] Add per-partition categorization option (#57683)

This PR adds the initial Java side changes to enable
use of the per-partition categorization functionality
added in elastic/ml-cpp#1293.

There will be a followup change to complete the work,
as there cannot be any end-to-end integration tests
until elastic/ml-cpp#1293 is merged, and also
elastic/ml-cpp#1293 does not implement some of the
more peripheral functionality, like stop_on_warn and
per-partition stats documents.

The changes so far cover REST APIs, results object
formats, HLRC and docs.
David Roberts 5 years ago
parent
commit
605b4d0ea9
47 changed files with 1086 additions and 142 deletions
  1. 23 4
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetCategoriesRequest.java
  2. 26 5
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/AnalysisConfig.java
  3. 29 2
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/JobUpdate.java
  4. 95 0
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/PerPartitionCategorizationConfig.java
  5. 32 1
      client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/CategoryDefinition.java
  6. 3 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetCategoriesRequestTests.java
  7. 9 1
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/config/AnalysisConfigTests.java
  8. 42 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/config/PerPartitionCategorizationConfigTests.java
  9. 6 0
      client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/results/CategoryDefinitionTests.java
  10. 23 3
      docs/reference/ml/anomaly-detection/apis/get-category.asciidoc
  11. 18 0
      docs/reference/ml/anomaly-detection/apis/put-job.asciidoc
  12. 18 0
      docs/reference/ml/anomaly-detection/apis/update-job.asciidoc
  13. 19 0
      docs/reference/ml/ml-shared.asciidoc
  14. 25 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetCategoriesAction.java
  15. 19 3
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateProcessAction.java
  16. 78 5
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java
  17. 46 6
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobUpdate.java
  18. 105 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/PerPartitionCategorizationConfig.java
  19. 1 6
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/AnomalyRecord.java
  20. 47 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/CategoryDefinition.java
  21. 4 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java
  22. 10 0
      x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json
  23. 3 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetCategoriesRequestTests.java
  24. 1 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateJobActionRequestTests.java
  25. 9 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateProcessActionRequestTests.java
  26. 84 3
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfigTests.java
  27. 18 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobUpdateTests.java
  28. 45 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/PerPartitionCategorizationConfigTests.java
  29. 6 3
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java
  30. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetCategoriesAction.java
  31. 1 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateProcessAction.java
  32. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/UpdateJobProcessNotifier.java
  33. 16 7
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java
  34. 4 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java
  35. 9 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java
  36. 5 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java
  37. 7 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java
  38. 21 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateParams.java
  39. 18 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateProcessMessage.java
  40. 7 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriter.java
  41. 1 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/results/RestGetCategoriesAction.java
  42. 48 59
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProviderTests.java
  43. 8 6
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateParamsTests.java
  44. 19 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java
  45. 18 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/results/CategoryDefinitionTests.java
  46. 4 0
      x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_categories.json
  47. 52 6
      x-pack/plugin/src/test/resources/rest-api-spec/test/ml/jobs_get_result_categories.yml

+ 23 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetCategoriesRequest.java

@@ -21,6 +21,7 @@ package org.elasticsearch.client.ml;
 import org.elasticsearch.client.Validatable;
 import org.elasticsearch.client.core.PageParams;
 import org.elasticsearch.client.ml.job.config.Job;
+import org.elasticsearch.client.ml.job.results.CategoryDefinition;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -34,8 +35,8 @@ import java.util.Objects;
  */
 public class GetCategoriesRequest implements Validatable, ToXContentObject {
 
-
-    public static final ParseField CATEGORY_ID = new ParseField("category_id");
+    public static final ParseField CATEGORY_ID = CategoryDefinition.CATEGORY_ID;
+    public static final ParseField PARTITION_FIELD_VALUE = CategoryDefinition.PARTITION_FIELD_VALUE;
 
     public static final ConstructingObjectParser<GetCategoriesRequest, Void> PARSER = new ConstructingObjectParser<>(
         "get_categories_request", a -> new GetCategoriesRequest((String) a[0]));
@@ -45,11 +46,13 @@ public class GetCategoriesRequest implements Validatable, ToXContentObject {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
         PARSER.declareLong(GetCategoriesRequest::setCategoryId, CATEGORY_ID);
         PARSER.declareObject(GetCategoriesRequest::setPageParams, PageParams.PARSER, PageParams.PAGE);
+        PARSER.declareString(GetCategoriesRequest::setPartitionFieldValue, PARTITION_FIELD_VALUE);
     }
 
     private final String jobId;
     private Long categoryId;
     private PageParams pageParams;
+    private String partitionFieldValue;
 
     /**
      * Constructs a request to retrieve category information from a given job
@@ -87,6 +90,18 @@ public class GetCategoriesRequest implements Validatable, ToXContentObject {
         this.pageParams = pageParams;
     }
 
+    public String getPartitionFieldValue() {
+        return partitionFieldValue;
+    }
+
+    /**
+     * Sets the partition field value
+     * @param partitionFieldValue the partition field value
+     */
+    public void setPartitionFieldValue(String partitionFieldValue) {
+        this.partitionFieldValue = partitionFieldValue;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -97,6 +112,9 @@ public class GetCategoriesRequest implements Validatable, ToXContentObject {
         if (pageParams != null) {
             builder.field(PageParams.PAGE.getPreferredName(), pageParams);
         }
+        if (partitionFieldValue != null) {
+            builder.field(PARTITION_FIELD_VALUE.getPreferredName(), partitionFieldValue);
+        }
         builder.endObject();
         return builder;
     }
@@ -112,11 +130,12 @@ public class GetCategoriesRequest implements Validatable, ToXContentObject {
         GetCategoriesRequest request = (GetCategoriesRequest) obj;
         return Objects.equals(jobId, request.jobId)
             && Objects.equals(categoryId, request.categoryId)
-            && Objects.equals(pageParams, request.pageParams);
+            && Objects.equals(pageParams, request.pageParams)
+            && Objects.equals(partitionFieldValue, request.partitionFieldValue);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(jobId, categoryId, pageParams);
+        return Objects.hash(jobId, categoryId, pageParams, partitionFieldValue);
     }
 }

+ 26 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/AnalysisConfig.java

@@ -56,6 +56,7 @@ public class AnalysisConfig implements ToXContentObject {
     public static final ParseField CATEGORIZATION_FIELD_NAME = new ParseField("categorization_field_name");
     public static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters");
     public static final ParseField CATEGORIZATION_ANALYZER = CategorizationAnalyzerConfig.CATEGORIZATION_ANALYZER;
+    public static final ParseField PER_PARTITION_CATEGORIZATION = new ParseField("per_partition_categorization");
     public static final ParseField LATENCY = new ParseField("latency");
     public static final ParseField SUMMARY_COUNT_FIELD_NAME = new ParseField("summary_count_field_name");
     public static final ParseField DETECTORS = new ParseField("detectors");
@@ -78,6 +79,8 @@ public class AnalysisConfig implements ToXContentObject {
         PARSER.declareField(Builder::setCategorizationAnalyzerConfig,
             (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p),
             CATEGORIZATION_ANALYZER, ObjectParser.ValueType.OBJECT_OR_STRING);
+        PARSER.declareObject(Builder::setPerPartitionCategorizationConfig, PerPartitionCategorizationConfig.PARSER,
+            PER_PARTITION_CATEGORIZATION);
         PARSER.declareString((builder, val) ->
             builder.setLatency(TimeValue.parseTimeValue(val, LATENCY.getPreferredName())), LATENCY);
         PARSER.declareString(Builder::setSummaryCountFieldName, SUMMARY_COUNT_FIELD_NAME);
@@ -92,6 +95,7 @@ public class AnalysisConfig implements ToXContentObject {
     private final String categorizationFieldName;
     private final List<String> categorizationFilters;
     private final CategorizationAnalyzerConfig categorizationAnalyzerConfig;
+    private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     private final TimeValue latency;
     private final String summaryCountFieldName;
     private final List<Detector> detectors;
@@ -99,13 +103,15 @@ public class AnalysisConfig implements ToXContentObject {
     private final Boolean multivariateByFields;
 
     private AnalysisConfig(TimeValue bucketSpan, String categorizationFieldName, List<String> categorizationFilters,
-                           CategorizationAnalyzerConfig categorizationAnalyzerConfig, TimeValue latency, String summaryCountFieldName,
-                           List<Detector> detectors, List<String> influencers, Boolean multivariateByFields) {
+                           CategorizationAnalyzerConfig categorizationAnalyzerConfig,
+                           PerPartitionCategorizationConfig perPartitionCategorizationConfig, TimeValue latency,
+                           String summaryCountFieldName, List<Detector> detectors, List<String> influencers, Boolean multivariateByFields) {
         this.detectors = Collections.unmodifiableList(detectors);
         this.bucketSpan = bucketSpan;
         this.latency = latency;
         this.categorizationFieldName = categorizationFieldName;
         this.categorizationAnalyzerConfig = categorizationAnalyzerConfig;
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.categorizationFilters = categorizationFilters == null ? null : Collections.unmodifiableList(categorizationFilters);
         this.summaryCountFieldName = summaryCountFieldName;
         this.influencers = Collections.unmodifiableList(influencers);
@@ -133,6 +139,10 @@ public class AnalysisConfig implements ToXContentObject {
         return categorizationAnalyzerConfig;
     }
 
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     /**
      * The latency interval during which out-of-order records should be handled.
      *
@@ -226,6 +236,9 @@ public class AnalysisConfig implements ToXContentObject {
             // gets written as a single string.
             categorizationAnalyzerConfig.toXContent(builder, params);
         }
+        if (perPartitionCategorizationConfig != null) {
+            builder.field(PER_PARTITION_CATEGORIZATION.getPreferredName(), perPartitionCategorizationConfig);
+        }
         if (latency != null) {
             builder.field(LATENCY.getPreferredName(), latency.getStringRep());
         }
@@ -261,6 +274,7 @@ public class AnalysisConfig implements ToXContentObject {
             Objects.equals(categorizationFieldName, that.categorizationFieldName) &&
             Objects.equals(categorizationFilters, that.categorizationFilters) &&
             Objects.equals(categorizationAnalyzerConfig, that.categorizationAnalyzerConfig) &&
+            Objects.equals(perPartitionCategorizationConfig, that.perPartitionCategorizationConfig) &&
             Objects.equals(summaryCountFieldName, that.summaryCountFieldName) &&
             Objects.equals(detectors, that.detectors) &&
             Objects.equals(influencers, that.influencers) &&
@@ -270,8 +284,8 @@ public class AnalysisConfig implements ToXContentObject {
     @Override
     public int hashCode() {
         return Objects.hash(
-            bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig, latency,
-            summaryCountFieldName, detectors, influencers, multivariateByFields);
+            bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig, perPartitionCategorizationConfig,
+            latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
     }
 
     public static Builder builder(List<Detector> detectors) {
@@ -286,6 +300,7 @@ public class AnalysisConfig implements ToXContentObject {
         private String categorizationFieldName;
         private List<String> categorizationFilters;
         private CategorizationAnalyzerConfig categorizationAnalyzerConfig;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         private String summaryCountFieldName;
         private List<String> influencers = new ArrayList<>();
         private Boolean multivariateByFields;
@@ -302,6 +317,7 @@ public class AnalysisConfig implements ToXContentObject {
             this.categorizationFilters = analysisConfig.categorizationFilters == null ? null
                 : new ArrayList<>(analysisConfig.categorizationFilters);
             this.categorizationAnalyzerConfig = analysisConfig.categorizationAnalyzerConfig;
+            this.perPartitionCategorizationConfig = analysisConfig.perPartitionCategorizationConfig;
             this.summaryCountFieldName = analysisConfig.summaryCountFieldName;
             this.influencers = new ArrayList<>(analysisConfig.influencers);
             this.multivariateByFields = analysisConfig.multivariateByFields;
@@ -351,6 +367,11 @@ public class AnalysisConfig implements ToXContentObject {
             return this;
         }
 
+        public Builder setPerPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
+            return this;
+        }
+
         public Builder setSummaryCountFieldName(String summaryCountFieldName) {
             this.summaryCountFieldName = summaryCountFieldName;
             return this;
@@ -369,7 +390,7 @@ public class AnalysisConfig implements ToXContentObject {
         public AnalysisConfig build() {
 
             return new AnalysisConfig(bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig,
-                latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
+                perPartitionCategorizationConfig, latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
         }
     }
 }

+ 29 - 2
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/JobUpdate.java

@@ -54,6 +54,8 @@ public class JobUpdate implements ToXContentObject {
         PARSER.declareLong(Builder::setModelSnapshotRetentionDays, Job.MODEL_SNAPSHOT_RETENTION_DAYS);
         PARSER.declareLong(Builder::setDailyModelSnapshotRetentionAfterDays, Job.DAILY_MODEL_SNAPSHOT_RETENTION_AFTER_DAYS);
         PARSER.declareStringArray(Builder::setCategorizationFilters, AnalysisConfig.CATEGORIZATION_FILTERS);
+        PARSER.declareObject(Builder::setPerPartitionCategorizationConfig, PerPartitionCategorizationConfig.PARSER,
+                AnalysisConfig.PER_PARTITION_CATEGORIZATION);
         PARSER.declareField(Builder::setCustomSettings, (p, c) -> p.map(), Job.CUSTOM_SETTINGS, ObjectParser.ValueType.OBJECT);
         PARSER.declareBoolean(Builder::setAllowLazyOpen, Job.ALLOW_LAZY_OPEN);
     }
@@ -70,6 +72,7 @@ public class JobUpdate implements ToXContentObject {
     private final Long dailyModelSnapshotRetentionAfterDays;
     private final Long resultsRetentionDays;
     private final List<String> categorizationFilters;
+    private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     private final Map<String, Object> customSettings;
     private final Boolean allowLazyOpen;
 
@@ -79,6 +82,7 @@ public class JobUpdate implements ToXContentObject {
                       @Nullable Long renormalizationWindowDays, @Nullable Long resultsRetentionDays,
                       @Nullable Long modelSnapshotRetentionDays, @Nullable Long dailyModelSnapshotRetentionAfterDays,
                       @Nullable List<String> categorizationFilters,
+                      @Nullable PerPartitionCategorizationConfig perPartitionCategorizationConfig,
                       @Nullable Map<String, Object> customSettings, @Nullable Boolean allowLazyOpen) {
         this.jobId = jobId;
         this.groups = groups;
@@ -92,6 +96,7 @@ public class JobUpdate implements ToXContentObject {
         this.dailyModelSnapshotRetentionAfterDays = dailyModelSnapshotRetentionAfterDays;
         this.resultsRetentionDays = resultsRetentionDays;
         this.categorizationFilters = categorizationFilters;
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.customSettings = customSettings;
         this.allowLazyOpen = allowLazyOpen;
     }
@@ -140,6 +145,10 @@ public class JobUpdate implements ToXContentObject {
         return categorizationFilters;
     }
 
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     public Map<String, Object> getCustomSettings() {
         return customSettings;
     }
@@ -185,6 +194,9 @@ public class JobUpdate implements ToXContentObject {
         if (categorizationFilters != null) {
             builder.field(AnalysisConfig.CATEGORIZATION_FILTERS.getPreferredName(), categorizationFilters);
         }
+        if (perPartitionCategorizationConfig != null) {
+            builder.field(AnalysisConfig.PER_PARTITION_CATEGORIZATION.getPreferredName(), perPartitionCategorizationConfig);
+        }
         if (customSettings != null) {
             builder.field(Job.CUSTOM_SETTINGS.getPreferredName(), customSettings);
         }
@@ -219,6 +231,7 @@ public class JobUpdate implements ToXContentObject {
             && Objects.equals(this.dailyModelSnapshotRetentionAfterDays, that.dailyModelSnapshotRetentionAfterDays)
             && Objects.equals(this.resultsRetentionDays, that.resultsRetentionDays)
             && Objects.equals(this.categorizationFilters, that.categorizationFilters)
+            && Objects.equals(this.perPartitionCategorizationConfig, that.perPartitionCategorizationConfig)
             && Objects.equals(this.customSettings, that.customSettings)
             && Objects.equals(this.allowLazyOpen, that.allowLazyOpen);
     }
@@ -227,7 +240,7 @@ public class JobUpdate implements ToXContentObject {
     public int hashCode() {
         return Objects.hash(jobId, groups, description, detectorUpdates, modelPlotConfig, analysisLimits, renormalizationWindowDays,
             backgroundPersistInterval, modelSnapshotRetentionDays, dailyModelSnapshotRetentionAfterDays, resultsRetentionDays,
-            categorizationFilters, customSettings, allowLazyOpen);
+            categorizationFilters, perPartitionCategorizationConfig, customSettings, allowLazyOpen);
     }
 
     public static class DetectorUpdate implements ToXContentObject {
@@ -323,6 +336,7 @@ public class JobUpdate implements ToXContentObject {
         private Long dailyModelSnapshotRetentionAfterDays;
         private Long resultsRetentionDays;
         private List<String> categorizationFilters;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         private Map<String, Object> customSettings;
         private Boolean allowLazyOpen;
 
@@ -468,6 +482,19 @@ public class JobUpdate implements ToXContentObject {
             return this;
         }
 
+        /**
+         * Sets the per-partition categorization options on the {@link Job}
+         *
+         * Updates the {@link AnalysisConfig#perPartitionCategorizationConfig} setting.
+         * Requires {@link AnalysisConfig#perPartitionCategorizationConfig} to have been set on the existing Job.
+         *
+         * @param perPartitionCategorizationConfig per-partition categorization options for the Job's {@link AnalysisConfig}
+         */
+        public Builder setPerPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
+            return this;
+        }
+
         /**
          * Contains custom meta data about the job.
          *
@@ -488,7 +515,7 @@ public class JobUpdate implements ToXContentObject {
         public JobUpdate build() {
             return new JobUpdate(jobId, groups, description, detectorUpdates, modelPlotConfig, analysisLimits, backgroundPersistInterval,
                 renormalizationWindowDays, resultsRetentionDays, modelSnapshotRetentionDays, dailyModelSnapshotRetentionAfterDays,
-                categorizationFilters, customSettings, allowLazyOpen);
+                categorizationFilters, perPartitionCategorizationConfig, customSettings, allowLazyOpen);
         }
     }
 }

+ 95 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/config/PerPartitionCategorizationConfig.java

@@ -0,0 +1,95 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.job.config;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class PerPartitionCategorizationConfig implements ToXContentObject {
+
+    public static final ParseField TYPE_FIELD = new ParseField("per_partition_categorization");
+    public static final ParseField ENABLED_FIELD = new ParseField("enabled");
+    public static final ParseField STOP_ON_WARN = new ParseField("stop_on_warn");
+
+    public static final ConstructingObjectParser<PerPartitionCategorizationConfig, Void> PARSER =
+        new ConstructingObjectParser<>(TYPE_FIELD.getPreferredName(), true,
+            a -> new PerPartitionCategorizationConfig((boolean) a[0], (Boolean) a[1]));
+
+    static {
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED_FIELD);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), STOP_ON_WARN);
+    }
+
+    private final boolean enabled;
+    private final boolean stopOnWarn;
+
+    public PerPartitionCategorizationConfig() {
+        this(false, null);
+    }
+
+    public PerPartitionCategorizationConfig(boolean enabled, Boolean stopOnWarn) {
+        this.enabled = enabled;
+        this.stopOnWarn = (stopOnWarn == null) ? false : stopOnWarn;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+        builder.field(ENABLED_FIELD.getPreferredName(), enabled);
+        if (enabled) {
+            builder.field(STOP_ON_WARN.getPreferredName(), stopOnWarn);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    public boolean isEnabled() {
+        return enabled;
+    }
+
+    public boolean isStopOnWarn() {
+        return stopOnWarn;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+
+        if (other instanceof PerPartitionCategorizationConfig == false) {
+            return false;
+        }
+
+        PerPartitionCategorizationConfig that = (PerPartitionCategorizationConfig) other;
+        return this.enabled == that.enabled && this.stopOnWarn == that.stopOnWarn;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(enabled, stopOnWarn);
+    }
+}

+ 32 - 1
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/results/CategoryDefinition.java

@@ -38,6 +38,8 @@ public class CategoryDefinition implements ToXContentObject {
     public static final ParseField TYPE = new ParseField("category_definition");
 
     public static final ParseField CATEGORY_ID = new ParseField("category_id");
+    public static final ParseField PARTITION_FIELD_NAME = new ParseField("partition_field_name");
+    public static final ParseField PARTITION_FIELD_VALUE = new ParseField("partition_field_value");
     public static final ParseField TERMS = new ParseField("terms");
     public static final ParseField REGEX = new ParseField("regex");
     public static final ParseField MAX_MATCHING_LENGTH = new ParseField("max_matching_length");
@@ -55,6 +57,8 @@ public class CategoryDefinition implements ToXContentObject {
     static {
         PARSER.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
         PARSER.declareLong(CategoryDefinition::setCategoryId, CATEGORY_ID);
+        PARSER.declareString(CategoryDefinition::setPartitionFieldName, PARTITION_FIELD_NAME);
+        PARSER.declareString(CategoryDefinition::setPartitionFieldValue, PARTITION_FIELD_VALUE);
         PARSER.declareString(CategoryDefinition::setTerms, TERMS);
         PARSER.declareString(CategoryDefinition::setRegex, REGEX);
         PARSER.declareLong(CategoryDefinition::setMaxMatchingLength, MAX_MATCHING_LENGTH);
@@ -66,6 +70,8 @@ public class CategoryDefinition implements ToXContentObject {
 
     private final String jobId;
     private long categoryId = 0L;
+    private String partitionFieldName;
+    private String partitionFieldValue;
     private String terms = "";
     private String regex = "";
     private long maxMatchingLength = 0L;
@@ -90,6 +96,22 @@ public class CategoryDefinition implements ToXContentObject {
         this.categoryId = categoryId;
     }
 
+    public String getPartitionFieldName() {
+        return partitionFieldName;
+    }
+
+    public void setPartitionFieldName(String partitionFieldName) {
+        this.partitionFieldName = partitionFieldName;
+    }
+
+    public String getPartitionFieldValue() {
+        return partitionFieldValue;
+    }
+
+    public void setPartitionFieldValue(String partitionFieldValue) {
+        this.partitionFieldValue = partitionFieldValue;
+    }
+
     public String getTerms() {
         return terms;
     }
@@ -156,6 +178,12 @@ public class CategoryDefinition implements ToXContentObject {
         builder.startObject();
         builder.field(Job.ID.getPreferredName(), jobId);
         builder.field(CATEGORY_ID.getPreferredName(), categoryId);
+        if (partitionFieldName != null) {
+            builder.field(PARTITION_FIELD_NAME.getPreferredName(), partitionFieldName);
+        }
+        if (partitionFieldValue != null) {
+            builder.field(PARTITION_FIELD_VALUE.getPreferredName(), partitionFieldValue);
+        }
         builder.field(TERMS.getPreferredName(), terms);
         builder.field(REGEX.getPreferredName(), regex);
         builder.field(MAX_MATCHING_LENGTH.getPreferredName(), maxMatchingLength);
@@ -182,6 +210,8 @@ public class CategoryDefinition implements ToXContentObject {
         CategoryDefinition that = (CategoryDefinition) other;
         return Objects.equals(this.jobId, that.jobId)
             && Objects.equals(this.categoryId, that.categoryId)
+            && Objects.equals(this.partitionFieldName, that.partitionFieldName)
+            && Objects.equals(this.partitionFieldValue, that.partitionFieldValue)
             && Objects.equals(this.terms, that.terms)
             && Objects.equals(this.regex, that.regex)
             && Objects.equals(this.maxMatchingLength, that.maxMatchingLength)
@@ -193,6 +223,7 @@ public class CategoryDefinition implements ToXContentObject {
 
     @Override
     public int hashCode() {
-        return Objects.hash(jobId, categoryId, terms, regex, maxMatchingLength, examples, preferredToCategories, numMatches, grokPattern);
+        return Objects.hash(jobId, categoryId, partitionFieldName, partitionFieldValue, terms, regex, maxMatchingLength, examples,
+            preferredToCategories, numMatches, grokPattern);
     }
 }

+ 3 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetCategoriesRequestTests.java

@@ -36,6 +36,9 @@ public class GetCategoriesRequestTests extends AbstractXContentTestCase<GetCateg
             int size = randomInt(10000);
             request.setPageParams(new PageParams(from, size));
         }
+        if (randomBoolean()) {
+            request.setPartitionFieldValue(randomAlphaOfLength(10));
+        }
         return request;
     }
 

+ 9 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/config/AnalysisConfigTests.java

@@ -40,7 +40,10 @@ public class AnalysisConfigTests extends AbstractXContentTestCase<AnalysisConfig
         int numDetectors = randomIntBetween(1, 10);
         for (int i = 0; i < numDetectors; i++) {
             Detector.Builder builder = new Detector.Builder("count", null);
-            builder.setPartitionFieldName(isCategorization ? "mlcategory" : "part");
+            if (isCategorization) {
+                builder.setByFieldName("mlcategory");
+            }
+            builder.setPartitionFieldName("part");
             detectors.add(builder.build());
         }
         AnalysisConfig.Builder builder = new AnalysisConfig.Builder(detectors);
@@ -82,6 +85,11 @@ public class AnalysisConfigTests extends AbstractXContentTestCase<AnalysisConfig
                 }
                 builder.setCategorizationAnalyzerConfig(analyzerBuilder.build());
             }
+            if (randomBoolean()) {
+                boolean enabled = randomBoolean();
+                builder.setPerPartitionCategorizationConfig(
+                    new PerPartitionCategorizationConfig(enabled, enabled && randomBoolean()));
+            }
         }
         if (randomBoolean()) {
             builder.setLatency(TimeValue.timeValueSeconds(randomIntBetween(1, 1_000_000)));

+ 42 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/config/PerPartitionCategorizationConfigTests.java

@@ -0,0 +1,42 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.job.config;
+
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractXContentTestCase;
+
+public class PerPartitionCategorizationConfigTests extends AbstractXContentTestCase<PerPartitionCategorizationConfig> {
+
+    @Override
+    protected PerPartitionCategorizationConfig createTestInstance() {
+        boolean enabled = randomBoolean();
+        return new PerPartitionCategorizationConfig(enabled, randomBoolean() ? null : enabled && randomBoolean());
+    }
+
+    @Override
+    protected PerPartitionCategorizationConfig doParseInstance(XContentParser parser) {
+        return PerPartitionCategorizationConfig.PARSER.apply(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return true;
+    }
+}

+ 6 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/results/CategoryDefinitionTests.java

@@ -31,6 +31,10 @@ public class CategoryDefinitionTests extends AbstractXContentTestCase<CategoryDe
     public static CategoryDefinition createTestInstance(String jobId) {
         CategoryDefinition categoryDefinition = new CategoryDefinition(jobId);
         categoryDefinition.setCategoryId(randomLong());
+        if (randomBoolean()) {
+            categoryDefinition.setPartitionFieldName(randomAlphaOfLength(10));
+            categoryDefinition.setPartitionFieldValue(randomAlphaOfLength(20));
+        }
         categoryDefinition.setTerms(randomAlphaOfLength(10));
         categoryDefinition.setRegex(randomAlphaOfLength(10));
         categoryDefinition.setMaxMatchingLength(randomLong());
@@ -128,6 +132,8 @@ public class CategoryDefinitionTests extends AbstractXContentTestCase<CategoryDe
     private static CategoryDefinition createFullyPopulatedCategoryDefinition() {
         CategoryDefinition category = new CategoryDefinition("jobName");
         category.setCategoryId(42);
+        category.setPartitionFieldName("p");
+        category.setPartitionFieldValue("v");
         category.setTerms("foo bar");
         category.setRegex(".*?foo.*?bar.*");
         category.setMaxMatchingLength(120L);

+ 23 - 3
docs/reference/ml/anomaly-detection/apis/get-category.asciidoc

@@ -47,8 +47,11 @@ examine the description and examples of that category. For more information, see
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=job-id-anomaly-detection]
 
 `<category_id>`::
-(Optional, long) Identifier for the category. If you do not specify this
-parameter, the API returns information about all categories in the {anomaly-job}.
+(Optional, long) Identifier for the category, which is unique in the job. If you
+specify neither the category ID nor the `partition_field_value`, the API returns
+information about all categories. If you specify only the
+`partition_field_value`, it returns information about all categories for the
+specified partition.
 
 [[ml-get-category-request-body]]
 ==== {api-request-body-title}
@@ -58,13 +61,18 @@ parameter, the API returns information about all categories in the {anomaly-job}
 `page`.`size`::
 (Optional, integer) Specifies the maximum number of categories to obtain.
 
+`partition_field_value`::
+(Optional, string) Only return categories for the specified partition.
+
 [[ml-get-category-results]]
 ==== {api-response-body-title}
 
 The API returns an array of category objects, which have the following properties:
 
 `category_id`::
-(unsigned integer) A unique identifier for the category.
+(unsigned integer) A unique identifier for the category. `category_id` is unique
+at the job level, even when per-partition categorization is enabled.
+
 
 `examples`::
 (array) A list of examples of actual values that matched the category.
@@ -85,6 +93,18 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=job-id-anomaly-detection]
 The value is increased by 10% to enable matching for similar fields that have
 not been analyzed.
 
+// This doesn't use the shared description because there are
+// categorization-specific aspects to its use in this context
+`partition_field_name`::
+(string) If per-partition categorization is enabled, this property identifies
+the field used to segment the categorization. It is not present when
+per-partition categorization is disabled.
+
+`partition_field_value`::
+(string) If per-partition categorization is enabled, this property identifies
+the value of the `partition_field_name` for the category. It is not present when
+per-partition categorization is disabled.
+
 `regex`::
 (string) A regular expression that is used to search for values that match the
 category.

+ 18 - 0
docs/reference/ml/anomaly-detection/apis/put-job.asciidoc

@@ -186,6 +186,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=latency]
 (boolean)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=multivariate-by-fields]
 
+//Begin analysis_config.per_partition_categorization
+`per_partition_categorization`:::
+(Optional, object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization]
++
+.Properties of `per_partition_categorization`
+[%collapsible%open]
+=====
+`enabled`::::
+(boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization-enabled]
+
+`stop_on_warn`::::
+(boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization-stop-on-warn]
+=====
+//End analysis_config.per_partition_categorization
+
 `summary_count_field_name`:::
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=summary-count-field-name]

+ 18 - 0
docs/reference/ml/anomaly-detection/apis/update-job.asciidoc

@@ -198,6 +198,24 @@ include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-plot-config-terms]
 (long)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-snapshot-retention-days]
 
+//Begin per_partition_categorization
+`per_partition_categorization`:::
+(object)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization]
++
+.Properties of `per_partition_categorization`
+[%collapsible%open]
+====
+`enabled`:::
+(boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization-enabled]
+
+`stop_on_warn`:::
+(boolean)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=per-partition-categorization-stop-on-warn]
+====
+//End per_partition_categorization
+
 `renormalization_window_days`::
 (long)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=renormalization-window-days]

+ 19 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -1113,6 +1113,25 @@ The field used to segment the analysis. When you use this property, you have
 completely independent baselines for each value of this field.
 end::partition-field-name[]
 
+tag::per-partition-categorization[]
+Settings related to how categorization interacts with partition fields.
+end::per-partition-categorization[]
+
+tag::per-partition-categorization-enabled[]
+To enable this setting, you must also set the partition_field_name property to
+the same value in every detector that uses the keyword mlcategory. Otherwise,
+job creation fails.
+end::per-partition-categorization-enabled[]
+
+tag::per-partition-categorization-stop-on-warn[]
+This setting can be set to true only if per-partition categorization is enabled.
+If true, both categorization and subsequent anomaly detection stops for
+partitions where the categorization status changes to `warn`. This setting makes
+it viable to have a job where it is expected that categorization works well for
+some partitions but not others; you do not pay the cost of bad categorization
+forever in the partitions where it works badly.
+end::per-partition-categorization-stop-on-warn[]
+
 tag::prediction-field-name[]
 Defines the name of the prediction field in the results. 
 Defaults to `<dependent_variable>_prediction`.

+ 25 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetCategoriesAction.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.action.ActionRequestBuilder;
 import org.elasticsearch.action.ActionRequestValidationException;
@@ -40,9 +41,10 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
 
     public static class Request extends ActionRequest implements ToXContentObject {
 
-        public static final ParseField CATEGORY_ID = new ParseField("category_id");
+        public static final ParseField CATEGORY_ID = CategoryDefinition.CATEGORY_ID;
         public static final ParseField FROM = new ParseField("from");
         public static final ParseField SIZE = new ParseField("size");
+        public static final ParseField PARTITION_FIELD_VALUE = CategoryDefinition.PARTITION_FIELD_VALUE;
 
         private static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
 
@@ -50,6 +52,7 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             PARSER.declareString((request, jobId) -> request.jobId = jobId, Job.ID);
             PARSER.declareLong(Request::setCategoryId, CATEGORY_ID);
             PARSER.declareObject(Request::setPageParams, PageParams.PARSER, PageParams.PAGE);
+            PARSER.declareString(Request::setPartitionFieldValue, PARTITION_FIELD_VALUE);
         }
 
         public static Request parseRequest(String jobId, XContentParser parser) {
@@ -63,6 +66,7 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
         private String jobId;
         private Long categoryId;
         private PageParams pageParams;
+        private String partitionFieldValue;
 
         public Request(String jobId) {
             this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName());
@@ -76,6 +80,9 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             jobId = in.readString();
             categoryId = in.readOptionalLong();
             pageParams = in.readOptionalWriteable(PageParams::new);
+            if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+                partitionFieldValue = in.readOptionalString();
+            }
         }
 
         public String getJobId() { return jobId; }
@@ -100,6 +107,14 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             this.pageParams = pageParams;
         }
 
+        public String getPartitionFieldValue() {
+            return partitionFieldValue;
+        }
+
+        public void setPartitionFieldValue(String partitionFieldValue) {
+            this.partitionFieldValue = partitionFieldValue;
+        }
+
         @Override
         public ActionRequestValidationException validate() {
             ActionRequestValidationException validationException = null;
@@ -117,6 +132,9 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             out.writeString(jobId);
             out.writeOptionalLong(categoryId);
             out.writeOptionalWriteable(pageParams);
+            if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+                out.writeOptionalString(partitionFieldValue);
+            }
         }
 
         @Override
@@ -129,6 +147,9 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             if (pageParams != null) {
                 builder.field(PageParams.PAGE.getPreferredName(), pageParams);
             }
+            if (partitionFieldValue != null) {
+                builder.field(PARTITION_FIELD_VALUE.getPreferredName(), partitionFieldValue);
+            }
             builder.endObject();
             return builder;
         }
@@ -142,12 +163,13 @@ public class GetCategoriesAction extends ActionType<GetCategoriesAction.Response
             Request request = (Request) o;
             return Objects.equals(jobId, request.jobId)
                     && Objects.equals(categoryId, request.categoryId)
-                    && Objects.equals(pageParams, request.pageParams);
+                    && Objects.equals(pageParams, request.pageParams)
+                    && Objects.equals(partitionFieldValue, request.partitionFieldValue);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(jobId, categoryId, pageParams);
+            return Objects.hash(jobId, categoryId, pageParams, partitionFieldValue);
         }
     }
 

+ 19 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateProcessAction.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.action;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.ActionRequestBuilder;
 import org.elasticsearch.action.support.tasks.BaseTasksResponse;
@@ -18,6 +19,7 @@ import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.job.config.JobUpdate;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 
 import java.io.IOException;
 import java.util.List;
@@ -98,6 +100,7 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
     public static class Request extends JobTaskRequest<Request> {
 
         private ModelPlotConfig modelPlotConfig;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         private List<JobUpdate.DetectorUpdate> detectorUpdates;
         private MlFilter filter;
         private boolean updateScheduledEvents = false;
@@ -107,6 +110,9 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
         public Request(StreamInput in) throws IOException {
             super(in);
             modelPlotConfig = in.readOptionalWriteable(ModelPlotConfig::new);
+            if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+                perPartitionCategorizationConfig = in.readOptionalWriteable(PerPartitionCategorizationConfig::new);
+            }
             if (in.readBoolean()) {
                 detectorUpdates = in.readList(JobUpdate.DetectorUpdate::new);
             }
@@ -118,6 +124,9 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
         public void writeTo(StreamOutput out) throws IOException {
             super.writeTo(out);
             out.writeOptionalWriteable(modelPlotConfig);
+            if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+                out.writeOptionalWriteable(perPartitionCategorizationConfig);
+            }
             boolean hasDetectorUpdates = detectorUpdates != null;
             out.writeBoolean(hasDetectorUpdates);
             if (hasDetectorUpdates) {
@@ -127,10 +136,11 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
             out.writeBoolean(updateScheduledEvents);
         }
 
-        public Request(String jobId, ModelPlotConfig modelPlotConfig, List<JobUpdate.DetectorUpdate> detectorUpdates, MlFilter filter,
-                       boolean updateScheduledEvents) {
+        public Request(String jobId, ModelPlotConfig modelPlotConfig, PerPartitionCategorizationConfig perPartitionCategorizationConfig,
+                       List<JobUpdate.DetectorUpdate> detectorUpdates, MlFilter filter, boolean updateScheduledEvents) {
             super(jobId);
             this.modelPlotConfig = modelPlotConfig;
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
             this.detectorUpdates = detectorUpdates;
             this.filter = filter;
             this.updateScheduledEvents = updateScheduledEvents;
@@ -140,6 +150,10 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
             return modelPlotConfig;
         }
 
+        public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+            return perPartitionCategorizationConfig;
+        }
+
         public List<JobUpdate.DetectorUpdate> getDetectorUpdates() {
             return detectorUpdates;
         }
@@ -154,7 +168,8 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
 
         @Override
         public int hashCode() {
-            return Objects.hash(getJobId(), modelPlotConfig, detectorUpdates, filter, updateScheduledEvents);
+            return Objects.hash(getJobId(), modelPlotConfig, perPartitionCategorizationConfig, detectorUpdates, filter,
+                updateScheduledEvents);
         }
 
         @Override
@@ -169,6 +184,7 @@ public class UpdateProcessAction extends ActionType<UpdateProcessAction.Response
 
             return Objects.equals(getJobId(), other.getJobId()) &&
                     Objects.equals(modelPlotConfig, other.modelPlotConfig) &&
+                    Objects.equals(perPartitionCategorizationConfig, other.perPartitionCategorizationConfig) &&
                     Objects.equals(detectorUpdates, other.detectorUpdates) &&
                     Objects.equals(filter, other.filter) &&
                     Objects.equals(updateScheduledEvents, other.updateScheduledEvents);

+ 78 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfig.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.job.config;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -29,6 +30,7 @@ import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.regex.Pattern;
 import java.util.regex.PatternSyntaxException;
@@ -56,6 +58,7 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
     public static final ParseField CATEGORIZATION_FIELD_NAME = new ParseField("categorization_field_name");
     public static final ParseField CATEGORIZATION_FILTERS = new ParseField("categorization_filters");
     public static final ParseField CATEGORIZATION_ANALYZER = CategorizationAnalyzerConfig.CATEGORIZATION_ANALYZER;
+    public static final ParseField PER_PARTITION_CATEGORIZATION = new ParseField("per_partition_categorization");
     public static final ParseField LATENCY = new ParseField("latency");
     public static final ParseField SUMMARY_COUNT_FIELD_NAME = new ParseField("summary_count_field_name");
     public static final ParseField DETECTORS = new ParseField("detectors");
@@ -85,6 +88,9 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
         parser.declareField(Builder::setCategorizationAnalyzerConfig,
             (p, c) -> CategorizationAnalyzerConfig.buildFromXContentFragment(p, ignoreUnknownFields),
             CATEGORIZATION_ANALYZER, ObjectParser.ValueType.OBJECT_OR_STRING);
+        parser.declareObject(Builder::setPerPartitionCategorizationConfig,
+            ignoreUnknownFields ? PerPartitionCategorizationConfig.LENIENT_PARSER : PerPartitionCategorizationConfig.STRICT_PARSER,
+            PER_PARTITION_CATEGORIZATION);
         parser.declareString((builder, val) ->
             builder.setLatency(TimeValue.parseTimeValue(val, LATENCY.getPreferredName())), LATENCY);
         parser.declareString(Builder::setSummaryCountFieldName, SUMMARY_COUNT_FIELD_NAME);
@@ -101,6 +107,7 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
     private final String categorizationFieldName;
     private final List<String> categorizationFilters;
     private final CategorizationAnalyzerConfig categorizationAnalyzerConfig;
+    private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     private final TimeValue latency;
     private final String summaryCountFieldName;
     private final List<Detector> detectors;
@@ -108,14 +115,16 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
     private final Boolean multivariateByFields;
 
     private AnalysisConfig(TimeValue bucketSpan, String categorizationFieldName, List<String> categorizationFilters,
-                           CategorizationAnalyzerConfig categorizationAnalyzerConfig, TimeValue latency, String summaryCountFieldName,
-                           List<Detector> detectors, List<String> influencers, Boolean multivariateByFields) {
+                           CategorizationAnalyzerConfig categorizationAnalyzerConfig,
+                           PerPartitionCategorizationConfig perPartitionCategorizationConfig, TimeValue latency,
+                           String summaryCountFieldName, List<Detector> detectors, List<String> influencers, Boolean multivariateByFields) {
         this.detectors = detectors;
         this.bucketSpan = bucketSpan;
         this.latency = latency;
         this.categorizationFieldName = categorizationFieldName;
         this.categorizationAnalyzerConfig = categorizationAnalyzerConfig;
         this.categorizationFilters = categorizationFilters == null ? null : Collections.unmodifiableList(categorizationFilters);
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.summaryCountFieldName = summaryCountFieldName;
         this.influencers = Collections.unmodifiableList(influencers);
         this.multivariateByFields = multivariateByFields;
@@ -126,6 +135,11 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
         categorizationFieldName = in.readOptionalString();
         categorizationFilters = in.readBoolean() ? Collections.unmodifiableList(in.readStringList()) : null;
         categorizationAnalyzerConfig = in.readOptionalWriteable(CategorizationAnalyzerConfig::new);
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            perPartitionCategorizationConfig = new PerPartitionCategorizationConfig(in);
+        } else {
+            perPartitionCategorizationConfig = new PerPartitionCategorizationConfig();
+        }
         latency = in.readOptionalTimeValue();
         summaryCountFieldName = in.readOptionalString();
         detectors = Collections.unmodifiableList(in.readList(Detector::new));
@@ -145,6 +159,9 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
             out.writeBoolean(false);
         }
         out.writeOptionalWriteable(categorizationAnalyzerConfig);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            perPartitionCategorizationConfig.writeTo(out);
+        }
         out.writeOptionalTimeValue(latency);
         out.writeOptionalString(summaryCountFieldName);
         out.writeList(detectors);
@@ -174,6 +191,10 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
         return categorizationAnalyzerConfig;
     }
 
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     /**
      * The latency interval during which out-of-order records should be handled.
      *
@@ -325,6 +346,11 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
             // gets written as a single string.
             categorizationAnalyzerConfig.toXContent(builder, params);
         }
+        // perPartitionCategorizationConfig is never null on the server side (it can be in the equivalent client class),
+        // but is not useful to know when categorization is not being used
+        if (categorizationFieldName != null) {
+            builder.field(PER_PARTITION_CATEGORIZATION.getPreferredName(), perPartitionCategorizationConfig);
+        }
         if (latency != null) {
             builder.field(LATENCY.getPreferredName(), latency.getStringRep());
         }
@@ -354,6 +380,7 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
                 Objects.equals(categorizationFieldName, that.categorizationFieldName) &&
                 Objects.equals(categorizationFilters, that.categorizationFilters) &&
                 Objects.equals(categorizationAnalyzerConfig, that.categorizationAnalyzerConfig) &&
+                Objects.equals(perPartitionCategorizationConfig, that.perPartitionCategorizationConfig) &&
                 Objects.equals(summaryCountFieldName, that.summaryCountFieldName) &&
                 Objects.equals(detectors, that.detectors) &&
                 Objects.equals(influencers, that.influencers) &&
@@ -363,8 +390,8 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
     @Override
     public int hashCode() {
         return Objects.hash(
-                bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig, latency,
-                summaryCountFieldName, detectors, influencers, multivariateByFields);
+                bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig, perPartitionCategorizationConfig,
+                latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
     }
 
     public static class Builder {
@@ -377,6 +404,7 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
         private String categorizationFieldName;
         private List<String> categorizationFilters;
         private CategorizationAnalyzerConfig categorizationAnalyzerConfig;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig = new PerPartitionCategorizationConfig();
         private String summaryCountFieldName;
         private List<String> influencers = new ArrayList<>();
         private Boolean multivariateByFields;
@@ -393,6 +421,7 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
             this.categorizationFilters = analysisConfig.categorizationFilters == null ? null
                     : new ArrayList<>(analysisConfig.categorizationFilters);
             this.categorizationAnalyzerConfig = analysisConfig.categorizationAnalyzerConfig;
+            this.perPartitionCategorizationConfig = analysisConfig.perPartitionCategorizationConfig;
             this.summaryCountFieldName = analysisConfig.summaryCountFieldName;
             this.influencers = new ArrayList<>(analysisConfig.influencers);
             this.multivariateByFields = analysisConfig.multivariateByFields;
@@ -445,6 +474,12 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder setPerPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig =
+                ExceptionsHelper.requireNonNull(perPartitionCategorizationConfig, PER_PARTITION_CATEGORIZATION.getPreferredName());
+            return this;
+        }
+
         public Builder setSummaryCountFieldName(String summaryCountFieldName) {
             this.summaryCountFieldName = summaryCountFieldName;
             return this;
@@ -485,13 +520,51 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
             verifyMlCategoryIsUsedWhenCategorizationFieldNameIsSet();
             verifyCategorizationAnalyzer();
             verifyCategorizationFilters();
+            verifyConfigConsistentWithPerPartitionCategorization();
 
             verifyNoMetricFunctionsWhenSummaryCountFieldNameIsSet();
 
             verifyNoInconsistentNestedFieldNames();
 
             return new AnalysisConfig(bucketSpan, categorizationFieldName, categorizationFilters, categorizationAnalyzerConfig,
-                    latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
+                perPartitionCategorizationConfig, latency, summaryCountFieldName, detectors, influencers, multivariateByFields);
+        }
+
+        private void verifyConfigConsistentWithPerPartitionCategorization() {
+            if (perPartitionCategorizationConfig.isEnabled() == false) {
+                return;
+            }
+
+            if (categorizationFieldName == null) {
+                throw ExceptionsHelper.badRequestException(CATEGORIZATION_FIELD_NAME.getPreferredName()
+                    + " must be set when per-partition categorization is enabled");
+            }
+
+            AtomicReference<String> singlePartitionFieldName = new AtomicReference<>();
+            detectors.forEach(d -> {
+                String thisDetectorPartitionFieldName = d.getPartitionFieldName();
+                if (d.getByOverPartitionTerms().contains(ML_CATEGORY_FIELD)) {
+                    if (ML_CATEGORY_FIELD.equals(d.getPartitionFieldName())) {
+                        throw ExceptionsHelper.badRequestException(ML_CATEGORY_FIELD + " cannot be used as a "
+                            + Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName()
+                            + " when per-partition categorization is enabled");
+                    }
+                    if (thisDetectorPartitionFieldName == null) {
+                        throw ExceptionsHelper.badRequestException(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName()
+                            + " must be set for detectors that reference " + ML_CATEGORY_FIELD
+                            + " when per-partition categorization is enabled");
+                    }
+                }
+                if (thisDetectorPartitionFieldName != null) {
+                    String previousPartitionFieldName = singlePartitionFieldName.getAndSet(thisDetectorPartitionFieldName);
+                    if (previousPartitionFieldName != null &&
+                        previousPartitionFieldName.equals(thisDetectorPartitionFieldName) == false) {
+                        throw ExceptionsHelper.badRequestException(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName()
+                            + " cannot vary between detectors when per-partition categorization is enabled: ["
+                            + previousPartitionFieldName + "] and [" + thisDetectorPartitionFieldName + "] are used");
+                    }
+                }
+            });
         }
 
         private void verifyNoMetricFunctionsWhenSummaryCountFieldNameIsSet() {

+ 46 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/JobUpdate.java

@@ -54,6 +54,8 @@ public class JobUpdate implements Writeable, ToXContentObject {
             parser.declareLong(Builder::setModelSnapshotRetentionDays, Job.MODEL_SNAPSHOT_RETENTION_DAYS);
             parser.declareLong(Builder::setDailyModelSnapshotRetentionAfterDays, Job.DAILY_MODEL_SNAPSHOT_RETENTION_AFTER_DAYS);
             parser.declareStringArray(Builder::setCategorizationFilters, AnalysisConfig.CATEGORIZATION_FILTERS);
+            parser.declareObject(Builder::setPerPartitionCategorizationConfig, PerPartitionCategorizationConfig.STRICT_PARSER,
+                    AnalysisConfig.PER_PARTITION_CATEGORIZATION);
             parser.declareField(Builder::setCustomSettings, (p, c) -> p.map(), Job.CUSTOM_SETTINGS, ObjectParser.ValueType.OBJECT);
             parser.declareBoolean(Builder::setAllowLazyOpen, Job.ALLOW_LAZY_OPEN);
         }
@@ -76,6 +78,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
     private final Long dailyModelSnapshotRetentionAfterDays;
     private final Long resultsRetentionDays;
     private final List<String> categorizationFilters;
+    private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     private final Map<String, Object> customSettings;
     private final String modelSnapshotId;
     private final Version modelSnapshotMinVersion;
@@ -89,6 +92,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
                       @Nullable Long renormalizationWindowDays, @Nullable Long resultsRetentionDays,
                       @Nullable Long modelSnapshotRetentionDays, @Nullable Long dailyModelSnapshotRetentionAfterDays,
                       @Nullable List<String> categorizationFilters,
+                      @Nullable PerPartitionCategorizationConfig perPartitionCategorizationConfig,
                       @Nullable Map<String, Object> customSettings, @Nullable String modelSnapshotId,
                       @Nullable Version modelSnapshotMinVersion, @Nullable Version jobVersion, @Nullable Boolean clearJobFinishTime,
                       @Nullable Boolean allowLazyOpen) {
@@ -104,6 +108,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
         this.dailyModelSnapshotRetentionAfterDays = dailyModelSnapshotRetentionAfterDays;
         this.resultsRetentionDays = resultsRetentionDays;
         this.categorizationFilters = categorizationFilters;
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.customSettings = customSettings;
         this.modelSnapshotId = modelSnapshotId;
         this.modelSnapshotMinVersion = modelSnapshotMinVersion;
@@ -134,6 +139,11 @@ public class JobUpdate implements Writeable, ToXContentObject {
         } else {
             categorizationFilters = null;
         }
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            perPartitionCategorizationConfig = in.readOptionalWriteable(PerPartitionCategorizationConfig::new);
+        } else {
+            perPartitionCategorizationConfig = null;
+        }
         customSettings = in.readMap();
         modelSnapshotId = in.readOptionalString();
         if (in.readBoolean()) {
@@ -153,7 +163,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(jobId);
-        String[] groupsArray = groups == null ? null : groups.toArray(new String[groups.size()]);
+        String[] groupsArray = groups == null ? null : groups.toArray(new String[0]);
         out.writeOptionalStringArray(groupsArray);
         out.writeOptionalString(description);
         out.writeBoolean(detectorUpdates != null);
@@ -171,6 +181,9 @@ public class JobUpdate implements Writeable, ToXContentObject {
         if (categorizationFilters != null) {
             out.writeStringCollection(categorizationFilters);
         }
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeOptionalWriteable(perPartitionCategorizationConfig);
+        }
         out.writeMap(customSettings);
         out.writeOptionalString(modelSnapshotId);
         if (jobVersion != null) {
@@ -237,6 +250,10 @@ public class JobUpdate implements Writeable, ToXContentObject {
         return categorizationFilters;
     }
 
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     public Map<String, Object> getCustomSettings() {
         return customSettings;
     }
@@ -262,7 +279,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
     }
 
     public boolean isAutodetectProcessUpdate() {
-        return modelPlotConfig != null || detectorUpdates != null || groups != null;
+        return modelPlotConfig != null || perPartitionCategorizationConfig != null || detectorUpdates != null || groups != null;
     }
 
     @Override
@@ -302,6 +319,9 @@ public class JobUpdate implements Writeable, ToXContentObject {
         if (categorizationFilters != null) {
             builder.field(AnalysisConfig.CATEGORIZATION_FILTERS.getPreferredName(), categorizationFilters);
         }
+        if (perPartitionCategorizationConfig != null) {
+            builder.field(AnalysisConfig.PER_PARTITION_CATEGORIZATION.getPreferredName(), perPartitionCategorizationConfig);
+        }
         if (customSettings != null) {
             builder.field(Job.CUSTOM_SETTINGS.getPreferredName(), customSettings);
         }
@@ -359,6 +379,9 @@ public class JobUpdate implements Writeable, ToXContentObject {
         if (categorizationFilters != null) {
             updateFields.add(AnalysisConfig.CATEGORIZATION_FILTERS.getPreferredName());
         }
+        if (perPartitionCategorizationConfig != null) {
+            updateFields.add(AnalysisConfig.PER_PARTITION_CATEGORIZATION.getPreferredName());
+        }
         if (customSettings != null) {
             updateFields.add(Job.CUSTOM_SETTINGS.getPreferredName());
         }
@@ -440,6 +463,14 @@ public class JobUpdate implements Writeable, ToXContentObject {
         if (categorizationFilters != null) {
             newAnalysisConfig.setCategorizationFilters(categorizationFilters);
         }
+        if (perPartitionCategorizationConfig != null) {
+            // Whether per-partition categorization is enabled cannot be changed, only the lower level details
+            if (perPartitionCategorizationConfig.isEnabled() !=
+                    currentAnalysisConfig.getPerPartitionCategorizationConfig().isEnabled()) {
+                throw ExceptionsHelper.badRequestException("analysis_config.per_partition_categorization.enabled cannot be updated");
+            }
+            newAnalysisConfig.setPerPartitionCategorizationConfig(perPartitionCategorizationConfig);
+        }
         if (customSettings != null) {
             builder.setCustomSettings(customSettings);
         }
@@ -477,6 +508,8 @@ public class JobUpdate implements Writeable, ToXContentObject {
                 && (resultsRetentionDays == null || Objects.equals(resultsRetentionDays, job.getResultsRetentionDays()))
                 && (categorizationFilters == null
                         || Objects.equals(categorizationFilters, job.getAnalysisConfig().getCategorizationFilters()))
+                && (perPartitionCategorizationConfig == null
+                        || Objects.equals(perPartitionCategorizationConfig, job.getAnalysisConfig().getPerPartitionCategorizationConfig()))
                 && (customSettings == null || Objects.equals(customSettings, job.getCustomSettings()))
                 && (modelSnapshotId == null || Objects.equals(modelSnapshotId, job.getModelSnapshotId()))
                 && (modelSnapshotMinVersion == null || Objects.equals(modelSnapshotMinVersion, job.getModelSnapshotMinVersion()))
@@ -527,6 +560,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
                 && Objects.equals(this.dailyModelSnapshotRetentionAfterDays, that.dailyModelSnapshotRetentionAfterDays)
                 && Objects.equals(this.resultsRetentionDays, that.resultsRetentionDays)
                 && Objects.equals(this.categorizationFilters, that.categorizationFilters)
+                && Objects.equals(this.perPartitionCategorizationConfig, that.perPartitionCategorizationConfig)
                 && Objects.equals(this.customSettings, that.customSettings)
                 && Objects.equals(this.modelSnapshotId, that.modelSnapshotId)
                 && Objects.equals(this.modelSnapshotMinVersion, that.modelSnapshotMinVersion)
@@ -539,8 +573,8 @@ public class JobUpdate implements Writeable, ToXContentObject {
     public int hashCode() {
         return Objects.hash(jobId, groups, description, detectorUpdates, modelPlotConfig, analysisLimits, renormalizationWindowDays,
                 backgroundPersistInterval, modelSnapshotRetentionDays, dailyModelSnapshotRetentionAfterDays, resultsRetentionDays,
-                categorizationFilters, customSettings, modelSnapshotId, modelSnapshotMinVersion, jobVersion, clearJobFinishTime,
-                allowLazyOpen);
+                categorizationFilters, perPartitionCategorizationConfig, customSettings, modelSnapshotId, modelSnapshotMinVersion,
+                jobVersion, clearJobFinishTime, allowLazyOpen);
     }
 
     public static class DetectorUpdate implements Writeable, ToXContentObject {
@@ -648,6 +682,7 @@ public class JobUpdate implements Writeable, ToXContentObject {
         private Long dailyModelSnapshotRetentionAfterDays;
         private Long resultsRetentionDays;
         private List<String> categorizationFilters;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         private Map<String, Object> customSettings;
         private String modelSnapshotId;
         private Version modelSnapshotMinVersion;
@@ -719,6 +754,11 @@ public class JobUpdate implements Writeable, ToXContentObject {
             return this;
         }
 
+        public Builder setPerPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
+            return this;
+        }
+
         public Builder setCustomSettings(Map<String, Object> customSettings) {
             this.customSettings = customSettings;
             return this;
@@ -762,8 +802,8 @@ public class JobUpdate implements Writeable, ToXContentObject {
         public JobUpdate build() {
             return new JobUpdate(jobId, groups, description, detectorUpdates, modelPlotConfig, analysisLimits, backgroundPersistInterval,
                     renormalizationWindowDays, resultsRetentionDays, modelSnapshotRetentionDays, dailyModelSnapshotRetentionAfterDays,
-                    categorizationFilters, customSettings, modelSnapshotId, modelSnapshotMinVersion, jobVersion, clearJobFinishTime,
-                    allowLazyOpen);
+                    categorizationFilters, perPartitionCategorizationConfig, customSettings, modelSnapshotId, modelSnapshotMinVersion,
+                    jobVersion, clearJobFinishTime, allowLazyOpen);
         }
     }
 }

+ 105 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/PerPartitionCategorizationConfig.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.job.config;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class PerPartitionCategorizationConfig implements ToXContentObject, Writeable {
+
+    public static final ParseField TYPE_FIELD = new ParseField("per_partition_categorization");
+    public static final ParseField ENABLED_FIELD = new ParseField("enabled");
+    public static final ParseField STOP_ON_WARN = new ParseField("stop_on_warn");
+
+    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
+    public static final ConstructingObjectParser<PerPartitionCategorizationConfig, Void> LENIENT_PARSER = createParser(true);
+    public static final ConstructingObjectParser<PerPartitionCategorizationConfig, Void> STRICT_PARSER = createParser(false);
+
+    private static ConstructingObjectParser<PerPartitionCategorizationConfig, Void> createParser(boolean ignoreUnknownFields) {
+        ConstructingObjectParser<PerPartitionCategorizationConfig, Void> parser =
+            new ConstructingObjectParser<>(TYPE_FIELD.getPreferredName(), ignoreUnknownFields,
+                a -> new PerPartitionCategorizationConfig((boolean) a[0], (Boolean) a[1]));
+
+        parser.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED_FIELD);
+        parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), STOP_ON_WARN);
+
+        return parser;
+    }
+
+    private final boolean enabled;
+    private final boolean stopOnWarn;
+
+    public PerPartitionCategorizationConfig() {
+        this(false, null);
+    }
+
+    public PerPartitionCategorizationConfig(boolean enabled, Boolean stopOnWarn) {
+        this.enabled = enabled;
+        this.stopOnWarn = (stopOnWarn == null) ? false : stopOnWarn;
+        if (this.enabled == false && this.stopOnWarn) {
+            throw ExceptionsHelper.badRequestException(STOP_ON_WARN.getPreferredName() + " cannot be true in "
+                + TYPE_FIELD.getPreferredName() + " when " + ENABLED_FIELD.getPreferredName() + " is false");
+        }
+    }
+
+    public PerPartitionCategorizationConfig(StreamInput in) throws IOException {
+        enabled = in.readBoolean();
+        stopOnWarn = in.readBoolean();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeBoolean(enabled);
+        out.writeBoolean(stopOnWarn);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ENABLED_FIELD.getPreferredName(), enabled);
+        if (enabled) {
+            builder.field(STOP_ON_WARN.getPreferredName(), stopOnWarn);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    public boolean isEnabled() {
+        return enabled;
+    }
+
+    public boolean isStopOnWarn() {
+        return stopOnWarn;
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+
+        if (other instanceof PerPartitionCategorizationConfig == false) {
+            return false;
+        }
+
+        PerPartitionCategorizationConfig that = (PerPartitionCategorizationConfig) other;
+        return this.enabled == that.enabled && this.stopOnWarn == that.stopOnWarn;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(enabled, stopOnWarn);
+    }
+}

+ 1 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/AnomalyRecord.java

@@ -248,12 +248,6 @@ public class AnomalyRecord implements ToXContentObject, Writeable {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        innerToXContent(builder, params);
-        builder.endObject();
-        return builder;
-    }
-
-    XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
         builder.field(Job.ID.getPreferredName(), jobId);
         builder.field(Result.RESULT_TYPE.getPreferredName(), RESULT_TYPE_VALUE);
         builder.field(PROBABILITY.getPreferredName(), probability);
@@ -316,6 +310,7 @@ public class AnomalyRecord implements ToXContentObject, Writeable {
         for (String fieldName : inputFields.keySet()) {
             builder.field(fieldName, inputFields.get(fieldName));
         }
+        builder.endObject();
         return builder;
     }
 

+ 47 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/CategoryDefinition.java

@@ -32,6 +32,8 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
     public static final ParseField TYPE = new ParseField("category_definition");
 
     public static final ParseField CATEGORY_ID = new ParseField("category_id");
+    public static final ParseField PARTITION_FIELD_NAME = new ParseField("partition_field_name");
+    public static final ParseField PARTITION_FIELD_VALUE = new ParseField("partition_field_value");
     public static final ParseField TERMS = new ParseField("terms");
     public static final ParseField REGEX = new ParseField("regex");
     public static final ParseField MAX_MATCHING_LENGTH = new ParseField("max_matching_length");
@@ -52,6 +54,8 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
 
         parser.declareString(ConstructingObjectParser.constructorArg(), Job.ID);
         parser.declareLong(CategoryDefinition::setCategoryId, CATEGORY_ID);
+        parser.declareString(CategoryDefinition::setPartitionFieldName, PARTITION_FIELD_NAME);
+        parser.declareString(CategoryDefinition::setPartitionFieldValue, PARTITION_FIELD_VALUE);
         parser.declareString(CategoryDefinition::setTerms, TERMS);
         parser.declareString(CategoryDefinition::setRegex, REGEX);
         parser.declareLong(CategoryDefinition::setMaxMatchingLength, MAX_MATCHING_LENGTH);
@@ -64,6 +68,8 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
 
     private final String jobId;
     private long categoryId = 0L;
+    private String partitionFieldName;
+    private String partitionFieldValue;
     private String terms = "";
     private String regex = "";
     private long maxMatchingLength = 0L;
@@ -80,6 +86,10 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
     public CategoryDefinition(StreamInput in) throws IOException {
         jobId = in.readString();
         categoryId = in.readLong();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            partitionFieldName = in.readOptionalString();
+            partitionFieldValue = in.readOptionalString();
+        }
         terms = in.readString();
         regex = in.readString();
         maxMatchingLength = in.readLong();
@@ -95,6 +105,10 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(jobId);
         out.writeLong(categoryId);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeOptionalString(partitionFieldName);
+            out.writeOptionalString(partitionFieldValue);
+        }
         out.writeString(terms);
         out.writeString(regex);
         out.writeLong(maxMatchingLength);
@@ -122,6 +136,22 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
         this.categoryId = categoryId;
     }
 
+    public String getPartitionFieldName() {
+        return partitionFieldName;
+    }
+
+    public void setPartitionFieldName(String partitionFieldName) {
+        this.partitionFieldName = partitionFieldName;
+    }
+
+    public String getPartitionFieldValue() {
+        return partitionFieldValue;
+    }
+
+    public void setPartitionFieldValue(String partitionFieldValue) {
+        this.partitionFieldValue = partitionFieldValue;
+    }
+
     public String getTerms() {
         return terms;
     }
@@ -200,6 +230,12 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
         builder.startObject();
         builder.field(Job.ID.getPreferredName(), jobId);
         builder.field(CATEGORY_ID.getPreferredName(), categoryId);
+        if (partitionFieldName != null) {
+            builder.field(PARTITION_FIELD_NAME.getPreferredName(), partitionFieldName);
+        }
+        if (partitionFieldValue != null) {
+            builder.field(PARTITION_FIELD_VALUE.getPreferredName(), partitionFieldValue);
+        }
         builder.field(TERMS.getPreferredName(), terms);
         builder.field(REGEX.getPreferredName(), regex);
         builder.field(MAX_MATCHING_LENGTH.getPreferredName(), maxMatchingLength);
@@ -213,6 +249,13 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
         if (numMatches > 0) {
             builder.field(NUM_MATCHES.getPreferredName(), numMatches);
         }
+
+        // Copy the patten from AnomalyRecord that by/over/partition field values are added to results
+        // as key value pairs after all the fixed fields if they won't clash with reserved fields
+        if (partitionFieldName != null && partitionFieldValue != null && ReservedFieldNames.isValidFieldName(partitionFieldName)) {
+            builder.field(partitionFieldName, partitionFieldValue);
+        }
+
         builder.endObject();
         return builder;
     }
@@ -228,6 +271,8 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
         CategoryDefinition that = (CategoryDefinition) other;
         return Objects.equals(this.jobId, that.jobId)
                 && Objects.equals(this.categoryId, that.categoryId)
+                && Objects.equals(this.partitionFieldName, that.partitionFieldName)
+                && Objects.equals(this.partitionFieldValue, that.partitionFieldValue)
                 && Objects.equals(this.terms, that.terms)
                 && Objects.equals(this.regex, that.regex)
                 && Objects.equals(this.maxMatchingLength, that.maxMatchingLength)
@@ -241,6 +286,8 @@ public class CategoryDefinition implements ToXContentObject, Writeable {
     public int hashCode() {
         return Objects.hash(jobId,
             categoryId,
+            partitionFieldName,
+            partitionFieldValue,
             terms,
             regex,
             maxMatchingLength,

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java

@@ -25,6 +25,7 @@ import org.elasticsearch.xpack.core.ml.job.config.Detector;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.config.RuleCondition;
 import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
@@ -242,6 +243,7 @@ public final class ReservedFieldNames {
             AnalysisConfig.CATEGORIZATION_FIELD_NAME.getPreferredName(),
             AnalysisConfig.CATEGORIZATION_FILTERS.getPreferredName(),
             AnalysisConfig.CATEGORIZATION_ANALYZER.getPreferredName(),
+            AnalysisConfig.PER_PARTITION_CATEGORIZATION.getPreferredName(),
             AnalysisConfig.LATENCY.getPreferredName(),
             AnalysisConfig.SUMMARY_COUNT_FIELD_NAME.getPreferredName(),
             AnalysisConfig.DETECTORS.getPreferredName(),
@@ -279,6 +281,8 @@ public final class ReservedFieldNames {
             ModelPlotConfig.TERMS_FIELD.getPreferredName(),
             ModelPlotConfig.ANNOTATIONS_ENABLED_FIELD.getPreferredName(),
 
+            PerPartitionCategorizationConfig.STOP_ON_WARN.getPreferredName(),
+
             DatafeedConfig.ID.getPreferredName(),
             DatafeedConfig.QUERY_DELAY.getPreferredName(),
             DatafeedConfig.FREQUENCY.getPreferredName(),

+ 10 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json

@@ -193,6 +193,16 @@
           "multivariate_by_fields" : {
             "type" : "boolean"
           },
+          "per_partition_categorization" : {
+            "properties" : {
+              "enabled" : {
+                "type" : "boolean"
+              },
+              "stop_on_warn" : {
+                "type" : "boolean"
+              }
+            }
+          },
           "summary_count_field_name" : {
             "type" : "keyword"
           }

+ 3 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetCategoriesRequestTests.java

@@ -23,6 +23,9 @@ public class GetCategoriesRequestTests extends AbstractSerializingTestCase<GetCa
             int size = randomInt(10000);
             request.setPageParams(new PageParams(from, size));
         }
+        if (randomBoolean()) {
+            request.setPartitionFieldValue(randomAlphaOfLength(10));
+        }
         return request;
     }
 

+ 1 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateJobActionRequestTests.java

@@ -10,8 +10,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits;
 import org.elasticsearch.xpack.core.ml.job.config.JobUpdate;
 
-public class UpdateJobActionRequestTests
-        extends AbstractWireSerializingTestCase<UpdateJobAction.Request> {
+public class UpdateJobActionRequestTests extends AbstractWireSerializingTestCase<UpdateJobAction.Request> {
 
     @Override
     protected UpdateJobAction.Request createTestInstance() {

+ 9 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/UpdateProcessActionRequestTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfigTests;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -20,9 +21,13 @@ public class UpdateProcessActionRequestTests extends AbstractWireSerializingTest
 
     @Override
     protected UpdateProcessAction.Request createTestInstance() {
-        ModelPlotConfig config = null;
+        ModelPlotConfig modelPlotConfig = null;
         if (randomBoolean()) {
-            config = ModelPlotConfigTests.createRandomized();
+            modelPlotConfig = ModelPlotConfigTests.createRandomized();
+        }
+        PerPartitionCategorizationConfig perPartitionCategorizationConfig = null;
+        if (randomBoolean()) {
+            perPartitionCategorizationConfig = new PerPartitionCategorizationConfig(true, randomBoolean());
         }
         List<JobUpdate.DetectorUpdate> updates = null;
         if (randomBoolean()) {
@@ -36,7 +41,8 @@ public class UpdateProcessActionRequestTests extends AbstractWireSerializingTest
         if (randomBoolean()) {
             filter = MlFilterTests.createTestFilter();
         }
-        return new UpdateProcessAction.Request(randomAlphaOfLength(10), config, updates, filter, randomBoolean());
+        return new UpdateProcessAction.Request(randomAlphaOfLength(10), modelPlotConfig, perPartitionCategorizationConfig, updates,
+            filter, randomBoolean());
     }
 
     @Override

+ 84 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/AnalysisConfigTests.java

@@ -25,6 +25,7 @@ import java.util.Set;
 import java.util.TreeSet;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
 
 public class AnalysisConfigTests extends AbstractSerializingTestCase<AnalysisConfig> {
 
@@ -39,14 +40,16 @@ public class AnalysisConfigTests extends AbstractSerializingTestCase<AnalysisCon
         int numDetectors = randomIntBetween(1, 10);
         for (int i = 0; i < numDetectors; i++) {
             Detector.Builder builder = new Detector.Builder("count", null);
-            builder.setPartitionFieldName(isCategorization ? "mlcategory" : "part");
+            if (isCategorization) {
+                builder.setByFieldName("mlcategory");
+            }
+            builder.setPartitionFieldName("part");
             detectors.add(builder.build());
         }
         AnalysisConfig.Builder builder = new AnalysisConfig.Builder(detectors);
 
-        TimeValue bucketSpan = AnalysisConfig.Builder.DEFAULT_BUCKET_SPAN;
         if (randomBoolean()) {
-            bucketSpan = TimeValue.timeValueSeconds(randomIntBetween(1, 1_000));
+            TimeValue bucketSpan = TimeValue.timeValueSeconds(randomIntBetween(1, 1_000));
             builder.setBucketSpan(bucketSpan);
         }
         if (isCategorization) {
@@ -83,6 +86,11 @@ public class AnalysisConfigTests extends AbstractSerializingTestCase<AnalysisCon
                 }
                 builder.setCategorizationAnalyzerConfig(analyzerBuilder.build());
             }
+            if (randomBoolean()) {
+                boolean enabled = randomBoolean();
+                builder.setPerPartitionCategorizationConfig(
+                    new PerPartitionCategorizationConfig(enabled, enabled && randomBoolean()));
+            }
         }
         if (randomBoolean()) {
             builder.setLatency(TimeValue.timeValueSeconds(randomIntBetween(1, 1_000_000)));
@@ -611,6 +619,79 @@ public class AnalysisConfigTests extends AbstractSerializingTestCase<AnalysisCon
         assertEquals(Messages.getMessage(Messages.JOB_CONFIG_CATEGORIZATION_FILTERS_CONTAINS_INVALID_REGEX, "("), e.getMessage());
     }
 
+    public void testVerify_GivenPerPartitionCategorizationAndNoPartitions() {
+        AnalysisConfig.Builder analysisConfig = createValidCategorizationConfig();
+        analysisConfig.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
+
+        ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, analysisConfig::build);
+
+        assertEquals(
+            "partition_field_name must be set for detectors that reference mlcategory when per-partition categorization is enabled",
+            e.getMessage());
+    }
+
+    public void testVerify_GivenPerPartitionCategorizationAndMultiplePartitionFields() {
+
+        List<Detector> detectors = new ArrayList<>();
+        for (String partitionFieldValue : Arrays.asList("part1", "part2")) {
+            Detector.Builder detector = new Detector.Builder("count", null);
+            detector.setByFieldName("mlcategory");
+            detector.setPartitionFieldName(partitionFieldValue);
+            detectors.add(detector.build());
+        }
+        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(detectors);
+        analysisConfig.setCategorizationFieldName("msg");
+        analysisConfig.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
+
+        ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, analysisConfig::build);
+
+        assertEquals(
+            "partition_field_name cannot vary between detectors when per-partition categorization is enabled: [part1] and [part2] are used",
+            e.getMessage());
+    }
+
+    public void testVerify_GivenPerPartitionCategorizationAndNoPartitionFieldOnCategorizationDetector() {
+
+        List<Detector> detectors = new ArrayList<>();
+        Detector.Builder detector = new Detector.Builder("count", null);
+        detector.setByFieldName("mlcategory");
+        detectors.add(detector.build());
+        detector = new Detector.Builder("mean", "responsetime");
+        detector.setPartitionFieldName("airline");
+        detectors.add(detector.build());
+        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(detectors);
+        analysisConfig.setCategorizationFieldName("msg");
+        analysisConfig.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
+
+        ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, analysisConfig::build);
+
+        assertEquals(
+            "partition_field_name must be set for detectors that reference mlcategory when per-partition categorization is enabled",
+            e.getMessage());
+    }
+
+    public void testVerify_GivenComplexPerPartitionCategorizationConfig() {
+
+        List<Detector> detectors = new ArrayList<>();
+        Detector.Builder detector = new Detector.Builder("count", null);
+        detector.setByFieldName("mlcategory");
+        detector.setPartitionFieldName("event.dataset");
+        detectors.add(detector.build());
+        detector = new Detector.Builder("mean", "responsetime");
+        detector.setByFieldName("airline");
+        detectors.add(detector.build());
+        detector = new Detector.Builder("rare", null);
+        detector.setByFieldName("mlcategory");
+        detector.setPartitionFieldName("event.dataset");
+        detectors.add(detector.build());
+        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(detectors);
+        analysisConfig.setCategorizationFieldName("msg");
+        boolean stopOnWarn = randomBoolean();
+        analysisConfig.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, stopOnWarn));
+
+        assertThat(analysisConfig.build().getPerPartitionCategorizationConfig().isStopOnWarn(), is(stopOnWarn));
+    }
+
     private static AnalysisConfig.Builder createValidConfig() {
         List<Detector> detectors = new ArrayList<>();
         Detector detector = new Detector.Builder("count", null).build();

+ 18 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobUpdateTests.java

@@ -105,6 +105,9 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         if (randomBoolean() && jobSupportsCategorizationFilters(job)) {
             update.setCategorizationFilters(Arrays.asList(generateRandomStringArray(10, 10, false)));
         }
+        if (randomBoolean() && jobSupportsPerPartitionCategorization(job)) {
+            update.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
+        }
         if (randomBoolean()) {
             update.setCustomSettings(Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
         }
@@ -140,6 +143,13 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         return true;
     }
 
+    private static boolean jobSupportsPerPartitionCategorization(@Nullable Job job) {
+        if (job == null) {
+            return true;
+        }
+        return job.getAnalysisConfig().getPerPartitionCategorizationConfig().isEnabled();
+    }
+
     private static List<JobUpdate.DetectorUpdate> createRandomDetectorUpdates() {
         int size = randomInt(10);
         List<JobUpdate.DetectorUpdate> detectorUpdates = new ArrayList<>(size);
@@ -241,6 +251,7 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         updateBuilder.setDailyModelSnapshotRetentionAfterDays(randomLongBetween(0, newModelSnapshotRetentionDays));
         updateBuilder.setRenormalizationWindowDays(randomNonNegativeLong());
         updateBuilder.setCategorizationFilters(categorizationFilters);
+        updateBuilder.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
         updateBuilder.setCustomSettings(customSettings);
         updateBuilder.setModelSnapshotId(randomAlphaOfLength(10));
         updateBuilder.setJobVersion(VersionUtils.randomCompatibleVersion(random(), Version.CURRENT));
@@ -250,10 +261,12 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         jobBuilder.setGroups(Collections.singletonList("group-1"));
         Detector.Builder d1 = new Detector.Builder("info_content", "domain");
         d1.setOverFieldName("mlcategory");
+        d1.setPartitionFieldName("host");
         Detector.Builder d2 = new Detector.Builder("min", "field");
         d2.setOverFieldName("host");
         AnalysisConfig.Builder ac = new AnalysisConfig.Builder(Arrays.asList(d1.build(), d2.build()));
         ac.setCategorizationFieldName("cat_field");
+        ac.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, randomBoolean()));
         jobBuilder.setAnalysisConfig(ac);
         jobBuilder.setDataDescription(new DataDescription.Builder());
         jobBuilder.setCreateTime(new Date());
@@ -270,6 +283,8 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         assertEquals(update.getModelSnapshotRetentionDays(), updatedJob.getModelSnapshotRetentionDays());
         assertEquals(update.getResultsRetentionDays(), updatedJob.getResultsRetentionDays());
         assertEquals(update.getCategorizationFilters(), updatedJob.getAnalysisConfig().getCategorizationFilters());
+        assertEquals(update.getPerPartitionCategorizationConfig().isEnabled(),
+            updatedJob.getAnalysisConfig().getPerPartitionCategorizationConfig().isEnabled());
         assertEquals(update.getCustomSettings(), updatedJob.getCustomSettings());
         assertEquals(update.getModelSnapshotId(), updatedJob.getModelSnapshotId());
         assertEquals(update.getJobVersion(), updatedJob.getJobVersion());
@@ -306,6 +321,9 @@ public class JobUpdateTests extends AbstractSerializingTestCase<JobUpdate> {
         assertTrue(update.isAutodetectProcessUpdate());
         update = new JobUpdate.Builder("foo").setGroups(Collections.singletonList("bar")).build();
         assertTrue(update.isAutodetectProcessUpdate());
+        update = new JobUpdate.Builder("foo")
+            .setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, true)).build();
+        assertTrue(update.isAutodetectProcessUpdate());
     }
 
     public void testUpdateAnalysisLimitWithValueGreaterThanMax() {

+ 45 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/PerPartitionCategorizationConfigTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.core.ml.job.config;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+
+import static org.hamcrest.Matchers.is;
+
+public class PerPartitionCategorizationConfigTests extends AbstractSerializingTestCase<PerPartitionCategorizationConfig> {
+
+    public void testConstructorDefaults() {
+        assertThat(new PerPartitionCategorizationConfig().isEnabled(), is(false));
+        assertThat(new PerPartitionCategorizationConfig().isStopOnWarn(), is(false));
+    }
+
+    public void testValidation() {
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
+            () -> new PerPartitionCategorizationConfig(false, true));
+
+        assertThat(e.getMessage(), is("stop_on_warn cannot be true in per_partition_categorization when enabled is false"));
+    }
+
+    @Override
+    protected PerPartitionCategorizationConfig createTestInstance() {
+        boolean enabled = randomBoolean();
+        return new PerPartitionCategorizationConfig(enabled, randomBoolean() ? null : enabled && randomBoolean());
+    }
+
+    @Override
+    protected Writeable.Reader<PerPartitionCategorizationConfig> instanceReader() {
+        return PerPartitionCategorizationConfig::new;
+    }
+
+    @Override
+    protected PerPartitionCategorizationConfig doParseInstance(XContentParser parser) {
+        return PerPartitionCategorizationConfig.STRICT_PARSER.apply(parser, null);
+    }
+}

+ 6 - 3
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/AutodetectResultProcessorIT.java

@@ -216,7 +216,9 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase {
         QueryPage<Influencer> persistedInfluencers = getInfluencers();
         assertResultsAreSame(influencers, persistedInfluencers);
 
-        QueryPage<CategoryDefinition> persistedDefinition = getCategoryDefinition(categoryDefinition.getCategoryId());
+        QueryPage<CategoryDefinition> persistedDefinition =
+            getCategoryDefinition(randomBoolean() ? categoryDefinition.getCategoryId() : null,
+                randomBoolean() ? categoryDefinition.getPartitionFieldValue() : null);
         assertEquals(1, persistedDefinition.count());
         assertEquals(categoryDefinition, persistedDefinition.results().get(0));
 
@@ -597,11 +599,12 @@ public class AutodetectResultProcessorIT extends MlSingleNodeTestCase {
         return resultHolder.get();
     }
 
-    private QueryPage<CategoryDefinition> getCategoryDefinition(long categoryId) throws Exception {
+    private QueryPage<CategoryDefinition> getCategoryDefinition(Long categoryId, String partitionFieldValue) throws Exception {
         AtomicReference<Exception> errorHolder = new AtomicReference<>();
         AtomicReference<QueryPage<CategoryDefinition>> resultHolder = new AtomicReference<>();
         CountDownLatch latch = new CountDownLatch(1);
-        jobResultsProvider.categoryDefinitions(JOB_ID, categoryId, false, null, null, r -> {
+        jobResultsProvider.categoryDefinitions(JOB_ID, categoryId, partitionFieldValue, false, (categoryId == null) ? 0 : null,
+            (categoryId == null) ? 100 : null, r -> {
             resultHolder.set(r);
             latch.countDown();
         }, e -> {

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetCategoriesAction.java

@@ -37,8 +37,8 @@ public class TransportGetCategoriesAction extends HandledTransportAction<GetCate
                 jobExists -> {
                     Integer from = request.getPageParams() != null ? request.getPageParams().getFrom() : null;
                     Integer size = request.getPageParams() != null ? request.getPageParams().getSize() : null;
-                    jobResultsProvider.categoryDefinitions(request.getJobId(), request.getCategoryId(), true, from, size,
-                            r -> listener.onResponse(new GetCategoriesAction.Response(r)), listener::onFailure, client);
+                    jobResultsProvider.categoryDefinitions(request.getJobId(), request.getCategoryId(), request.getPartitionFieldValue(),
+                        true, from, size, r -> listener.onResponse(new GetCategoriesAction.Response(r)), listener::onFailure, client);
                 },
                 listener::onFailure
         ));

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateProcessAction.java

@@ -30,6 +30,7 @@ public class TransportUpdateProcessAction extends TransportJobTaskAction<UpdateP
                                  ActionListener<UpdateProcessAction.Response> listener) {
         UpdateParams updateParams = UpdateParams.builder(request.getJobId())
                 .modelPlotConfig(request.getModelPlotConfig())
+                .perPartitionCategorizationConfig(request.getPerPartitionCategorizationConfig())
                 .detectorUpdates(request.getDetectorUpdates())
                 .filter(request.getFilter())
                 .updateScheduledEvents(request.isUpdateScheduledEvents())

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/UpdateJobProcessNotifier.java

@@ -115,8 +115,8 @@ public class UpdateJobProcessNotifier {
             return;
         }
 
-        Request request = new Request(update.getJobId(), update.getModelPlotConfig(), update.getDetectorUpdates(), update.getFilter(),
-                update.isUpdateScheduledEvents());
+        Request request = new Request(update.getJobId(), update.getModelPlotConfig(), update.getPerPartitionCategorizationConfig(),
+            update.getDetectorUpdates(), update.getFilter(), update.isUpdateScheduledEvents());
 
         executeAsyncWithOrigin(client, ML_ORIGIN, UpdateProcessAction.INSTANCE, request,
                 new ActionListener<Response>() {

+ 16 - 7
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java

@@ -372,7 +372,7 @@ public class JobResultsProvider {
     @SuppressWarnings("unchecked")
     public static int countFields(Map<String, Object> mapping) {
         Object propertiesNode = mapping.get("properties");
-        if (propertiesNode != null && propertiesNode instanceof Map) {
+        if (propertiesNode instanceof Map) {
             mapping = (Map<String, Object>) propertiesNode;
         } else {
             return 0;
@@ -800,11 +800,13 @@ public class JobResultsProvider {
      * Get a page of {@linkplain CategoryDefinition}s for the given <code>jobId</code>.
      * Uses a supplied client, so may run as the currently authenticated user
      * @param jobId the job id
+     * @param categoryId a specific category ID to retrieve, or <code>null</code> to retrieve as many as possible
+     * @param partitionFieldValue the partition field value to filter on, or <code>null</code> for no filtering
      * @param augment Should the category definition be augmented with a Grok pattern?
      * @param from  Skip the first N categories. This parameter is for paging
      * @param size  Take only this number of categories
      */
-    public void categoryDefinitions(String jobId, Long categoryId, boolean augment, Integer from, Integer size,
+    public void categoryDefinitions(String jobId, Long categoryId, String partitionFieldValue, boolean augment, Integer from, Integer size,
                                     Consumer<QueryPage<CategoryDefinition>> handler,
                                     Consumer<Exception> errorHandler, Client client) {
         if (categoryId != null && (from != null || size != null)) {
@@ -817,16 +819,25 @@ public class JobResultsProvider {
 
         SearchRequest searchRequest = new SearchRequest(indexName);
         searchRequest.indicesOptions(MlIndicesUtils.addIgnoreUnavailable(searchRequest.indicesOptions()));
+        QueryBuilder categoryIdQuery;
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         if (categoryId != null) {
-            sourceBuilder.query(QueryBuilders.termQuery(CategoryDefinition.CATEGORY_ID.getPreferredName(), categoryId));
+            categoryIdQuery = QueryBuilders.termQuery(CategoryDefinition.CATEGORY_ID.getPreferredName(), categoryId);
         } else if (from != null && size != null) {
+            categoryIdQuery = QueryBuilders.existsQuery(CategoryDefinition.CATEGORY_ID.getPreferredName());
             sourceBuilder.from(from).size(size)
-                    .query(QueryBuilders.existsQuery(CategoryDefinition.CATEGORY_ID.getPreferredName()))
                     .sort(new FieldSortBuilder(CategoryDefinition.CATEGORY_ID.getPreferredName()).order(SortOrder.ASC));
         } else {
             throw new IllegalStateException("Both categoryId and pageParams are not specified");
         }
+        if (partitionFieldValue != null) {
+            QueryBuilder partitionQuery =
+                QueryBuilders.termQuery(CategoryDefinition.PARTITION_FIELD_VALUE.getPreferredName(), partitionFieldValue);
+            QueryBuilder combinedQuery = QueryBuilders.boolQuery().must(categoryIdQuery).must(partitionQuery);
+            sourceBuilder.query(combinedQuery);
+        } else {
+            sourceBuilder.query(categoryIdQuery);
+        }
         sourceBuilder.trackTotalHits(true);
         searchRequest.source(sourceBuilder);
         executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest,
@@ -1402,9 +1413,7 @@ public class JobResultsProvider {
 
                     executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, updateRequest,
                             ActionListener.<UpdateResponse>wrap(
-                                    response -> {
-                                        handler.accept(updatedCalendar);
-                                    },
+                                    response -> handler.accept(updatedCalendar),
                                     errorHandler)
                             , client::update);
 

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectCommunicator.java

@@ -220,6 +220,10 @@ public class AutodetectCommunicator implements Closeable {
                 autodetectProcess.writeUpdateModelPlotMessage(update.getModelPlotConfig());
             }
 
+            if (update.getPerPartitionCategorizationConfig() != null) {
+                autodetectProcess.writeUpdatePerPartitionCategorizationMessage(update.getPerPartitionCategorizationConfig());
+            }
+
             // Filters have to be written before detectors
             if (update.getFilter() != null) {
                 autodetectProcess.writeUpdateFiltersMessage(Collections.singletonList(update.getFilter()));

+ 9 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java

@@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.calendars.ScheduledEvent;
 import org.elasticsearch.xpack.core.ml.job.config.DetectionRule;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
@@ -50,6 +51,14 @@ public interface AutodetectProcess extends NativeProcess {
      */
     void writeUpdateModelPlotMessage(ModelPlotConfig modelPlotConfig) throws IOException;
 
+    /**
+     * Update the per-partition categorization configuration
+     *
+     * @param perPartitionCategorizationConfig New per-partition categorization config
+     * @throws IOException If the write fails
+     */
+    void writeUpdatePerPartitionCategorizationMessage(PerPartitionCategorizationConfig perPartitionCategorizationConfig) throws IOException;
+
     /**
      * Write message to update the detector rules
      *

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java

@@ -10,6 +10,7 @@ import org.elasticsearch.xpack.core.ml.calendars.ScheduledEvent;
 import org.elasticsearch.xpack.core.ml.job.config.DetectionRule;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.output.FlushAcknowledgement;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
@@ -84,6 +85,10 @@ public class BlackHoleAutodetectProcess implements AutodetectProcess {
     public void writeUpdateModelPlotMessage(ModelPlotConfig modelPlotConfig) throws IOException {
     }
 
+    @Override
+    public void writeUpdatePerPartitionCategorizationMessage(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+    }
+
     @Override
     public void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException {
     }

+ 7 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java

@@ -12,6 +12,7 @@ import org.elasticsearch.xpack.core.ml.calendars.ScheduledEvent;
 import org.elasticsearch.xpack.core.ml.job.config.DetectionRule;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
 import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
@@ -80,6 +81,12 @@ class NativeAutodetectProcess extends AbstractNativeProcess implements Autodetec
         newMessageWriter().writeUpdateModelPlotMessage(modelPlotConfig);
     }
 
+    @Override
+    public void writeUpdatePerPartitionCategorizationMessage(PerPartitionCategorizationConfig perPartitionCategorizationConfig)
+        throws IOException {
+        // TODO: write the control message once it's been implemented on the C++ side
+    }
+
     @Override
     public void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException {
         newMessageWriter().writeUpdateDetectorRulesMessage(detectorIndex, rules);

+ 21 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateParams.java

@@ -9,6 +9,7 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.xpack.core.ml.job.config.JobUpdate;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 
 import java.util.List;
 import java.util.Objects;
@@ -17,14 +18,18 @@ public final class UpdateParams {
 
     private final String jobId;
     private final ModelPlotConfig modelPlotConfig;
+    private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     private final List<JobUpdate.DetectorUpdate> detectorUpdates;
     private final MlFilter filter;
     private final boolean updateScheduledEvents;
 
-    private UpdateParams(String jobId, @Nullable ModelPlotConfig modelPlotConfig, @Nullable List<JobUpdate.DetectorUpdate> detectorUpdates,
+    private UpdateParams(String jobId, @Nullable ModelPlotConfig modelPlotConfig,
+                         @Nullable PerPartitionCategorizationConfig perPartitionCategorizationConfig,
+                         @Nullable List<JobUpdate.DetectorUpdate> detectorUpdates,
                          @Nullable MlFilter filter, boolean updateScheduledEvents) {
         this.jobId = Objects.requireNonNull(jobId);
         this.modelPlotConfig = modelPlotConfig;
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.detectorUpdates = detectorUpdates;
         this.filter = filter;
         this.updateScheduledEvents = updateScheduledEvents;
@@ -39,6 +44,11 @@ public final class UpdateParams {
         return modelPlotConfig;
     }
 
+    @Nullable
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     @Nullable
     public List<JobUpdate.DetectorUpdate> getDetectorUpdates() {
         return detectorUpdates;
@@ -55,7 +65,7 @@ public final class UpdateParams {
      * update to external resources a job uses (e.g. calendars, filters).
      */
     public boolean isJobUpdate() {
-        return modelPlotConfig != null || detectorUpdates != null;
+        return modelPlotConfig != null || detectorUpdates != null || perPartitionCategorizationConfig != null;
     }
 
     public boolean isUpdateScheduledEvents() {
@@ -65,6 +75,7 @@ public final class UpdateParams {
     public static UpdateParams fromJobUpdate(JobUpdate jobUpdate) {
         return new Builder(jobUpdate.getJobId())
                 .modelPlotConfig(jobUpdate.getModelPlotConfig())
+                .perPartitionCategorizationConfig(jobUpdate.getPerPartitionCategorizationConfig())
                 .detectorUpdates(jobUpdate.getDetectorUpdates())
                 .updateScheduledEvents(jobUpdate.getGroups() != null)
                 .build();
@@ -86,6 +97,7 @@ public final class UpdateParams {
 
         private String jobId;
         private ModelPlotConfig modelPlotConfig;
+        private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         private List<JobUpdate.DetectorUpdate> detectorUpdates;
         private MlFilter filter;
         private boolean updateScheduledEvents;
@@ -99,6 +111,11 @@ public final class UpdateParams {
             return this;
         }
 
+        public Builder perPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
+            return this;
+        }
+
         public Builder detectorUpdates(List<JobUpdate.DetectorUpdate> detectorUpdates) {
             this.detectorUpdates = detectorUpdates;
             return this;
@@ -115,7 +132,8 @@ public final class UpdateParams {
         }
 
         public UpdateParams build() {
-            return new UpdateParams(jobId, modelPlotConfig, detectorUpdates, filter, updateScheduledEvents);
+            return new UpdateParams(jobId, modelPlotConfig, perPartitionCategorizationConfig, detectorUpdates, filter,
+                updateScheduledEvents);
         }
     }
 }

+ 18 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateProcessMessage.java

@@ -10,19 +10,24 @@ import org.elasticsearch.xpack.core.ml.calendars.ScheduledEvent;
 import org.elasticsearch.xpack.core.ml.job.config.JobUpdate;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 
 import java.util.List;
 
 public final class UpdateProcessMessage {
 
     @Nullable private final ModelPlotConfig modelPlotConfig;
+    @Nullable private final PerPartitionCategorizationConfig perPartitionCategorizationConfig;
     @Nullable private final List<JobUpdate.DetectorUpdate> detectorUpdates;
     @Nullable private final MlFilter filter;
     @Nullable private final List<ScheduledEvent> scheduledEvents;
 
-    private UpdateProcessMessage(@Nullable ModelPlotConfig modelPlotConfig, @Nullable List<JobUpdate.DetectorUpdate> detectorUpdates,
+    private UpdateProcessMessage(@Nullable ModelPlotConfig modelPlotConfig,
+                                 @Nullable PerPartitionCategorizationConfig perPartitionCategorizationConfig,
+                                 @Nullable List<JobUpdate.DetectorUpdate> detectorUpdates,
                                  @Nullable MlFilter filter, List<ScheduledEvent> scheduledEvents) {
         this.modelPlotConfig = modelPlotConfig;
+        this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
         this.detectorUpdates = detectorUpdates;
         this.filter = filter;
         this.scheduledEvents = scheduledEvents;
@@ -33,6 +38,11 @@ public final class UpdateProcessMessage {
         return modelPlotConfig;
     }
 
+    @Nullable
+    public PerPartitionCategorizationConfig getPerPartitionCategorizationConfig() {
+        return perPartitionCategorizationConfig;
+    }
+
     @Nullable
     public List<JobUpdate.DetectorUpdate> getDetectorUpdates() {
         return detectorUpdates;
@@ -51,6 +61,7 @@ public final class UpdateProcessMessage {
     public static class Builder {
 
         @Nullable private ModelPlotConfig modelPlotConfig;
+        @Nullable private PerPartitionCategorizationConfig perPartitionCategorizationConfig;
         @Nullable private List<JobUpdate.DetectorUpdate> detectorUpdates;
         @Nullable private MlFilter filter;
         @Nullable private List<ScheduledEvent> scheduledEvents;
@@ -60,6 +71,11 @@ public final class UpdateProcessMessage {
             return this;
         }
 
+        public Builder setPerPartitionCategorizationConfig(PerPartitionCategorizationConfig perPartitionCategorizationConfig) {
+            this.perPartitionCategorizationConfig = perPartitionCategorizationConfig;
+            return this;
+        }
+
         public Builder setDetectorUpdates(List<JobUpdate.DetectorUpdate> detectorUpdates) {
             this.detectorUpdates = detectorUpdates;
             return this;
@@ -76,7 +92,7 @@ public final class UpdateProcessMessage {
         }
 
         public UpdateProcessMessage build() {
-            return new UpdateProcessMessage(modelPlotConfig, detectorUpdates, filter, scheduledEvents);
+            return new UpdateProcessMessage(modelPlotConfig, perPartitionCategorizationConfig, detectorUpdates, filter, scheduledEvents);
         }
     }
 }

+ 7 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriter.java

@@ -35,6 +35,7 @@ public class FieldConfigWriter {
     private static final String INFLUENCER_PREFIX = "influencer.";
     private static final String CATEGORIZATION_FIELD_OPTION = " categorizationfield=";
     private static final String CATEGORIZATION_FILTER_PREFIX = "categorizationfilter.";
+    private static final String PER_PARTITION_CATEGORIZATION_OPTION = " perpartitioncategorization=";
 
     // Note: for the Engine API summarycountfield is currently passed as a
     // command line option to autodetect rather than in the field config file
@@ -94,14 +95,16 @@ public class FieldConfigWriter {
     }
 
     private void writeDetectorClause(int detectorId, Detector detector, StringBuilder contents) {
-        contents.append(DETECTOR_PREFIX).append(detectorId)
-        .append(DETECTOR_CLAUSE_SUFFIX).append(EQUALS);
+        contents.append(DETECTOR_PREFIX).append(detectorId).append(DETECTOR_CLAUSE_SUFFIX).append(EQUALS);
 
         DefaultDetectorDescription.appendOn(detector, contents);
 
         if (Strings.isNullOrEmpty(config.getCategorizationFieldName()) == false) {
-            contents.append(CATEGORIZATION_FIELD_OPTION)
-            .append(quoteField(config.getCategorizationFieldName()));
+            contents.append(CATEGORIZATION_FIELD_OPTION).append(quoteField(config.getCategorizationFieldName()));
+            if (Strings.isNullOrEmpty(detector.getPartitionFieldName()) == false &&
+                config.getPerPartitionCategorizationConfig().isEnabled()) {
+                contents.append(PER_PARTITION_CATEGORIZATION_OPTION).append("true");
+            }
         }
 
         contents.append(NEW_LINE);

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/results/RestGetCategoriesAction.java

@@ -84,6 +84,7 @@ public class RestGetCategoriesAction extends BaseRestHandler {
                         restRequest.paramAsInt(Request.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)
                 ));
             }
+            request.setPartitionFieldValue(restRequest.param(Request.PARTITION_FIELD_VALUE.getPreferredName()));
         }
 
         return channel -> client.execute(GetCategoriesAction.INSTANCE, request, new RestToXContentListener<>(channel));

+ 48 - 59
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProviderTests.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.job.persistence;
 
 import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.Version;
 import org.elasticsearch.action.ActionListener;
@@ -91,10 +92,9 @@ public class JobResultsProviderTests extends ESTestCase {
 
         BucketsQueryBuilder bq = new BucketsQueryBuilder().from(from).size(size).anomalyScoreThreshold(1.0);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Bucket>[] holder = new QueryPage[1];
-        provider.buckets(jobId, bq, r -> holder[0] = r, e -> {throw new RuntimeException(e);}, client);
-        QueryPage<Bucket> buckets = holder[0];
+        SetOnce<QueryPage<Bucket>> holder = new SetOnce<>();
+        provider.buckets(jobId, bq, holder::set, e -> {throw new RuntimeException(e);}, client);
+        QueryPage<Bucket> buckets = holder.get();
         assertEquals(1L, buckets.count());
         QueryBuilder query = queryBuilderHolder[0];
         String queryString = query.toString();
@@ -125,10 +125,9 @@ public class JobResultsProviderTests extends ESTestCase {
         BucketsQueryBuilder bq = new BucketsQueryBuilder().from(from).size(size).anomalyScoreThreshold(5.1)
                 .includeInterim(true);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Bucket>[] holder = new QueryPage[1];
-        provider.buckets(jobId, bq, r -> holder[0] = r, e -> {throw new RuntimeException(e);}, client);
-        QueryPage<Bucket> buckets = holder[0];
+        SetOnce<QueryPage<Bucket>> holder = new SetOnce<>();
+        provider.buckets(jobId, bq, holder::set, e -> {throw new RuntimeException(e);}, client);
+        QueryPage<Bucket> buckets = holder.get();
         assertEquals(1L, buckets.count());
         QueryBuilder query = queryBuilderHolder[0];
         String queryString = query.toString();
@@ -161,10 +160,9 @@ public class JobResultsProviderTests extends ESTestCase {
         bq.anomalyScoreThreshold(5.1);
         bq.includeInterim(true);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Bucket>[] holder = new QueryPage[1];
-        provider.buckets(jobId, bq, r -> holder[0] = r, e -> {throw new RuntimeException(e);}, client);
-        QueryPage<Bucket> buckets = holder[0];
+        SetOnce<QueryPage<Bucket>> holder = new SetOnce<>();
+        provider.buckets(jobId, bq, holder::set, e -> {throw new RuntimeException(e);}, client);
+        QueryPage<Bucket> buckets = holder.get();
         assertEquals(1L, buckets.count());
         QueryBuilder query = queryBuilderHolder[0];
         String queryString = query.toString();
@@ -174,7 +172,7 @@ public class JobResultsProviderTests extends ESTestCase {
 
     public void testBucket_NoBucketNoExpand() throws IOException {
         String jobId = "TestJobIdentification";
-        Long timestamp = 98765432123456789L;
+        long timestamp = 98765432123456789L;
         List<Map<String, Object>> source = new ArrayList<>();
 
         SearchResponse response = createSearchResponse(source);
@@ -207,11 +205,10 @@ public class JobResultsProviderTests extends ESTestCase {
         BucketsQueryBuilder bq = new BucketsQueryBuilder();
         bq.timestamp(Long.toString(now.getTime()));
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Bucket>[] bucketHolder = new QueryPage[1];
-        provider.buckets(jobId, bq, q -> bucketHolder[0] = q, e -> {}, client);
-        assertThat(bucketHolder[0].count(), equalTo(1L));
-        Bucket b = bucketHolder[0].results().get(0);
+        SetOnce<QueryPage<Bucket>> bucketHolder = new SetOnce<>();
+        provider.buckets(jobId, bq, bucketHolder::set, e -> {}, client);
+        assertThat(bucketHolder.get().count(), equalTo(1L));
+        Bucket b = bucketHolder.get().results().get(0);
         assertEquals(now, b.getTimestamp());
     }
 
@@ -248,10 +245,9 @@ public class JobResultsProviderTests extends ESTestCase {
                 .epochEnd(String.valueOf(now.getTime())).includeInterim(true).sortField(sortfield)
                 .recordScore(2.2);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<AnomalyRecord>[] holder = new QueryPage[1];
-        provider.records(jobId, rqb, page -> holder[0] = page, RuntimeException::new, client);
-        QueryPage<AnomalyRecord> recordPage = holder[0];
+        SetOnce<QueryPage<AnomalyRecord>> holder = new SetOnce<>();
+        provider.records(jobId, rqb, holder::set, e -> { throw new RuntimeException(e); }, client);
+        QueryPage<AnomalyRecord> recordPage = holder.get();
         assertEquals(2L, recordPage.count());
         List<AnomalyRecord> records = recordPage.results();
         assertEquals(22.4, records.get(0).getTypical().get(0), 0.000001);
@@ -301,10 +297,9 @@ public class JobResultsProviderTests extends ESTestCase {
         rqb.sortField(sortfield);
         rqb.recordScore(2.2);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<AnomalyRecord>[] holder = new QueryPage[1];
-        provider.records(jobId, rqb, page -> holder[0] = page, RuntimeException::new, client);
-        QueryPage<AnomalyRecord> recordPage = holder[0];
+        SetOnce<QueryPage<AnomalyRecord>> holder = new SetOnce<>();
+        provider.records(jobId, rqb, holder::set, e -> { throw new RuntimeException(e); }, client);
+        QueryPage<AnomalyRecord> recordPage = holder.get();
         assertEquals(2L, recordPage.count());
         List<AnomalyRecord> records = recordPage.results();
         assertEquals(22.4, records.get(0).getTypical().get(0), 0.000001);
@@ -346,11 +341,10 @@ public class JobResultsProviderTests extends ESTestCase {
         Client client = getMockedClient(qb -> {}, response);
         JobResultsProvider provider = createProvider(client);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<AnomalyRecord>[] holder = new QueryPage[1];
-        provider.bucketRecords(jobId, bucket, from, size, true, sortfield, true, page -> holder[0] = page, RuntimeException::new,
-                client);
-        QueryPage<AnomalyRecord> recordPage = holder[0];
+        SetOnce<QueryPage<AnomalyRecord>> holder = new SetOnce<>();
+        provider.bucketRecords(jobId, bucket, from, size, true, sortfield, true, holder::set,
+                e -> { throw new RuntimeException(e); }, client);
+        QueryPage<AnomalyRecord> recordPage = holder.get();
         assertEquals(2L, recordPage.count());
         List<AnomalyRecord> records = recordPage.results();
 
@@ -384,7 +378,7 @@ public class JobResultsProviderTests extends ESTestCase {
         JobResultsProvider provider = createProvider(client);
 
         Integer[] holder = new Integer[1];
-        provider.expandBucket(jobId, false, bucket, records -> holder[0] = records, RuntimeException::new, client);
+        provider.expandBucket(jobId, false, bucket, records -> holder[0] = records, e -> { throw new RuntimeException(e); }, client);
         int records = holder[0];
         assertEquals(400L, records);
     }
@@ -407,11 +401,10 @@ public class JobResultsProviderTests extends ESTestCase {
         Client client = getMockedClient(q -> {}, response);
 
         JobResultsProvider provider = createProvider(client);
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<CategoryDefinition>[] holder = new QueryPage[1];
-        provider.categoryDefinitions(jobId, null, false, from, size, r -> holder[0] = r,
-                e -> {throw new RuntimeException(e);}, client);
-        QueryPage<CategoryDefinition> categoryDefinitions = holder[0];
+        SetOnce<QueryPage<CategoryDefinition>> holder = new SetOnce<>();
+        provider.categoryDefinitions(jobId, null, null, false, from, size, holder::set,
+                e -> { throw new RuntimeException(e); }, client);
+        QueryPage<CategoryDefinition> categoryDefinitions = holder.get();
         assertEquals(1L, categoryDefinitions.count());
         assertEquals(terms, categoryDefinitions.results().get(0).getTerms());
     }
@@ -429,11 +422,10 @@ public class JobResultsProviderTests extends ESTestCase {
         SearchResponse response = createSearchResponse(Collections.singletonList(source));
         Client client = getMockedClient(q -> {}, response);
         JobResultsProvider provider = createProvider(client);
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<CategoryDefinition>[] holder = new QueryPage[1];
-        provider.categoryDefinitions(jobId, categoryId, false, null, null,
-                r -> holder[0] = r, e -> {throw new RuntimeException(e);}, client);
-        QueryPage<CategoryDefinition> categoryDefinitions = holder[0];
+        SetOnce<QueryPage<CategoryDefinition>> holder = new SetOnce<>();
+        provider.categoryDefinitions(jobId, categoryId, null, false, null, null,
+            holder::set, e -> { throw new RuntimeException(e); }, client);
+        QueryPage<CategoryDefinition> categoryDefinitions = holder.get();
         assertEquals(1L, categoryDefinitions.count());
         assertEquals(terms, categoryDefinitions.results().get(0).getTerms());
     }
@@ -471,11 +463,10 @@ public class JobResultsProviderTests extends ESTestCase {
         Client client = getMockedClient(q -> qbHolder[0] = q, response);
         JobResultsProvider provider = createProvider(client);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Influencer>[] holder = new QueryPage[1];
+        SetOnce<QueryPage<Influencer>> holder = new SetOnce<>();
         InfluencersQuery query = new InfluencersQueryBuilder().from(from).size(size).includeInterim(false).build();
-        provider.influencers(jobId, query, page -> holder[0] = page, RuntimeException::new, client);
-        QueryPage<Influencer> page = holder[0];
+        provider.influencers(jobId, query, holder::set, e -> { throw new RuntimeException(e); }, client);
+        QueryPage<Influencer> page = holder.get();
         assertEquals(2L, page.count());
 
         String queryString = qbHolder[0].toString();
@@ -531,12 +522,11 @@ public class JobResultsProviderTests extends ESTestCase {
         Client client = getMockedClient(q -> qbHolder[0] = q, response);
         JobResultsProvider provider = createProvider(client);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<Influencer>[] holder = new QueryPage[1];
+        SetOnce<QueryPage<Influencer>> holder = new SetOnce<>();
         InfluencersQuery query = new InfluencersQueryBuilder().from(from).size(size).start("0").end("0").sortField("sort")
                 .sortDescending(true).influencerScoreThreshold(0.0).includeInterim(true).build();
-        provider.influencers(jobId, query, page -> holder[0] = page, RuntimeException::new, client);
-        QueryPage<Influencer> page = holder[0];
+        provider.influencers(jobId, query, holder::set, e -> { throw new RuntimeException(e); }, client);
+        QueryPage<Influencer> page = holder.get();
         assertEquals(2L, page.count());
 
         String queryString = qbHolder[0].toString();
@@ -586,10 +576,9 @@ public class JobResultsProviderTests extends ESTestCase {
         Client client = getMockedClient(qb -> {}, response);
         JobResultsProvider provider = createProvider(client);
 
-        @SuppressWarnings({"unchecked", "rawtypes"})
-        QueryPage<ModelSnapshot>[] holder = new QueryPage[1];
-        provider.modelSnapshots(jobId, from, size, r -> holder[0] = r, RuntimeException::new);
-        QueryPage<ModelSnapshot> page = holder[0];
+        SetOnce<QueryPage<ModelSnapshot>> holder = new SetOnce<>();
+        provider.modelSnapshots(jobId, from, size, holder::set, e -> { throw new RuntimeException(e); });
+        QueryPage<ModelSnapshot> page = holder.get();
         assertEquals(2L, page.count());
         List<ModelSnapshot> snapshots = page.results();
 
@@ -607,7 +596,7 @@ public class JobResultsProviderTests extends ESTestCase {
         assertEquals(6, snapshots.get(1).getSnapshotDocCount());
     }
 
-    public void testViolatedFieldCountLimit() throws Exception {
+    public void testViolatedFieldCountLimit() {
         Map<String, Object> mapping = new HashMap<>();
 
         int i = 0;
@@ -666,7 +655,7 @@ public class JobResultsProviderTests extends ESTestCase {
     public void testTimingStats_Ok() throws IOException {
         String indexName = AnomalyDetectorsIndex.jobResultsAliasedName("foo");
         List<Map<String, Object>> source =
-            Arrays.asList(
+            Collections.singletonList(
                 Map.of(
                     Job.ID.getPreferredName(), "foo",
                     TimingStats.BUCKET_COUNT.getPreferredName(), 7,
@@ -734,7 +723,7 @@ public class JobResultsProviderTests extends ESTestCase {
 
     public void testDatafeedTimingStats_MultipleDocumentsAtOnce() throws IOException {
         List<Map<String, Object>> sourceFoo =
-            Arrays.asList(
+            Collections.singletonList(
                 Map.of(
                     Job.ID.getPreferredName(), "foo",
                     DatafeedTimingStats.SEARCH_COUNT.getPreferredName(), 6,
@@ -745,7 +734,7 @@ public class JobResultsProviderTests extends ESTestCase {
                         ExponentialAverageCalculationContext.LATEST_TIMESTAMP.getPreferredName(), Instant.ofEpochMilli(100000600),
                         ExponentialAverageCalculationContext.PREVIOUS_EXPONENTIAL_AVERAGE_MS.getPreferredName(), 60.0)));
         List<Map<String, Object>> sourceBar =
-            Arrays.asList(
+            Collections.singletonList(
                 Map.of(
                     Job.ID.getPreferredName(), "bar",
                     DatafeedTimingStats.SEARCH_COUNT.getPreferredName(), 7,
@@ -811,7 +800,7 @@ public class JobResultsProviderTests extends ESTestCase {
     public void testDatafeedTimingStats_Ok() throws IOException {
         String indexName = AnomalyDetectorsIndex.jobResultsAliasedName("foo");
         List<Map<String, Object>> source =
-            Arrays.asList(
+            Collections.singletonList(
                 Map.of(
                     Job.ID.getPreferredName(), "foo",
                     DatafeedTimingStats.SEARCH_COUNT.getPreferredName(), 6,

+ 8 - 6
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/UpdateParamsTests.java

@@ -10,9 +10,9 @@ import org.elasticsearch.xpack.core.ml.job.config.DetectionRule;
 import org.elasticsearch.xpack.core.ml.job.config.JobUpdate;
 import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.config.RuleCondition;
 
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
@@ -21,14 +21,14 @@ public class UpdateParamsTests extends ESTestCase {
 
     public void testFromJobUpdate() {
         String jobId = "foo";
-        DetectionRule rule = new DetectionRule.Builder(Arrays.asList(
-            new RuleCondition(RuleCondition.AppliesTo.ACTUAL,
-                Operator.GT, 1.0))).build();
-        List<DetectionRule> rules = Arrays.asList(rule);
+        DetectionRule rule = new DetectionRule.Builder(Collections.singletonList(
+            new RuleCondition(RuleCondition.AppliesTo.ACTUAL, Operator.GT, 1.0))).build();
+        List<DetectionRule> rules = Collections.singletonList(rule);
         List<JobUpdate.DetectorUpdate> detectorUpdates = Collections.singletonList(
             new JobUpdate.DetectorUpdate(2, null, rules));
         JobUpdate.Builder updateBuilder = new JobUpdate.Builder(jobId)
             .setModelPlotConfig(new ModelPlotConfig())
+            .setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig())
             .setDetectorUpdates(detectorUpdates);
 
         UpdateParams params = UpdateParams.fromJobUpdate(updateBuilder.build());
@@ -36,10 +36,12 @@ public class UpdateParamsTests extends ESTestCase {
         assertFalse(params.isUpdateScheduledEvents());
         assertEquals(params.getDetectorUpdates(), updateBuilder.build().getDetectorUpdates());
         assertEquals(params.getModelPlotConfig(), updateBuilder.build().getModelPlotConfig());
+        assertEquals(params.getPerPartitionCategorizationConfig(), updateBuilder.build().getPerPartitionCategorizationConfig());
 
-        params = UpdateParams.fromJobUpdate(updateBuilder.setGroups(Arrays.asList("bar")).build());
+        params = UpdateParams.fromJobUpdate(updateBuilder.setGroups(Collections.singletonList("bar")).build());
 
         assertTrue(params.isUpdateScheduledEvents());
+        assertTrue(params.isJobUpdate());
     }
 
 }

+ 19 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.job.config.DetectionRule;
 import org.elasticsearch.xpack.core.ml.job.config.Detector;
 import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
 import org.elasticsearch.xpack.core.ml.job.config.Operator;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.config.RuleCondition;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.ini4j.Config;
@@ -142,6 +143,24 @@ public class FieldConfigWriterTests extends ESTestCase {
         verifyNoMoreInteractions(writer);
     }
 
+    public void testWrite_GivenConfigHasPerPartitionCategorization() throws IOException {
+        Detector.Builder d = new Detector.Builder("metric", "Integer_Value");
+        d.setByFieldName("mlcategory");
+        d.setPartitionFieldName("event.dataset");
+
+        AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(d.build()));
+        builder.setCategorizationFieldName("message");
+        builder.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(true, false));
+        analysisConfig = builder.build();
+        writer = mock(OutputStreamWriter.class);
+
+        createFieldConfigWriter().write();
+
+        verify(writer).write("detector.0.clause = metric(Integer_Value) by mlcategory partitionfield=\"event.dataset\" "
+            + "categorizationfield=message perpartitioncategorization=true\n");
+        verifyNoMoreInteractions(writer);
+    }
+
     public void testWrite_GivenConfigHasInfluencers() throws IOException {
         Detector.Builder d = new Detector.Builder("metric", "Integer_Value");
         d.setByFieldName("ts_hash");

+ 18 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/results/CategoryDefinitionTests.java

@@ -24,6 +24,10 @@ public class CategoryDefinitionTests extends AbstractBWCSerializationTestCase<Ca
     public CategoryDefinition createTestInstance(String jobId) {
         CategoryDefinition categoryDefinition = new CategoryDefinition(jobId);
         categoryDefinition.setCategoryId(randomLong());
+        if (randomBoolean()) {
+            categoryDefinition.setPartitionFieldName(randomAlphaOfLength(10));
+            categoryDefinition.setPartitionFieldValue(randomAlphaOfLength(20));
+        }
         categoryDefinition.setTerms(randomAlphaOfLength(10));
         categoryDefinition.setRegex(randomAlphaOfLength(10));
         categoryDefinition.setMaxMatchingLength(randomLong());
@@ -52,7 +56,11 @@ public class CategoryDefinitionTests extends AbstractBWCSerializationTestCase<Ca
 
     @Override
     protected CategoryDefinition doParseInstance(XContentParser parser) {
-        return CategoryDefinition.STRICT_PARSER.apply(parser, null);
+        // As a category definition contains a field named after the partition field, the parser
+        // for category definitions serialised to XContent must always ignore unknown fields.
+        // This is why the lenient parser is used in this test rather than the strict parser
+        // that most of the other tests for this package use.
+        return CategoryDefinition.LENIENT_PARSER.apply(parser, null);
     }
 
     public void testEquals_GivenSameObject() {
@@ -130,6 +138,8 @@ public class CategoryDefinitionTests extends AbstractBWCSerializationTestCase<Ca
     private static CategoryDefinition createFullyPopulatedCategoryDefinition() {
         CategoryDefinition category = new CategoryDefinition("jobName");
         category.setCategoryId(42);
+        category.setPartitionFieldName("p");
+        category.setPartitionFieldValue("v");
         category.setTerms("foo bar");
         category.setRegex(".*?foo.*?bar.*");
         category.setMaxMatchingLength(120L);
@@ -138,6 +148,9 @@ public class CategoryDefinitionTests extends AbstractBWCSerializationTestCase<Ca
         return category;
     }
 
+    /**
+     * For this class the strict parser is <em>only</em> used for parsing C++ output.
+     */
     public void testStrictParser() throws IOException {
         String json = "{\"job_id\":\"job_1\", \"foo\":\"bar\"}";
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
@@ -157,6 +170,10 @@ public class CategoryDefinitionTests extends AbstractBWCSerializationTestCase<Ca
 
     @Override
     protected CategoryDefinition mutateInstanceForVersion(CategoryDefinition instance, Version version) {
+        if (version.before(Version.V_8_0_0)) {
+            instance.setPartitionFieldName(null);
+            instance.setPartitionFieldValue(null);
+        }
         if (version.before(Version.V_7_8_0)) {
             instance.setPreferredToCategories(new long[0]);
             instance.setNumMatches(0L);

+ 4 - 0
x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_categories.json

@@ -47,6 +47,10 @@
       "size":{
         "type":"int",
         "description":"specifies a max number of categories to get"
+      },
+      "partition_field_value":{
+        "type":"string",
+        "description":"Specifies the partition to retrieve categories for. This is optional, and should never be used for jobs where per-partition categorization is disabled."
       }
     },
     "body":{

+ 52 - 6
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/jobs_get_result_categories.yml

@@ -24,7 +24,7 @@ setup:
       index:
         index:  .ml-anomalies-jobs-get-result-categories
         id:     jobs-get-result-categories-1
-        body:   { "job_id": "jobs-get-result-categories", "category_id": 1 }
+        body:   { "job_id": "jobs-get-result-categories", "category_id": 1, "partition_field_name": "p", "partition_field_value": "v1" }
   - do:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
@@ -32,14 +32,22 @@ setup:
       index:
         index:  .ml-anomalies-jobs-get-result-categories
         id:     jobs-get-result-categories-2
-        body:   { "job_id": "jobs-get-result-categories", "category_id": 2 }
+        body:   { "job_id": "jobs-get-result-categories", "category_id": 2, "partition_field_name": "p", "partition_field_value": "v2" }
   - do:
       headers:
         Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
         Content-Type: application/json
       index:
-        index:  .ml-anomalies-unrelated
+        index:  .ml-anomalies-jobs-get-result-categories
         id:     jobs-get-result-categories-3
+        body:   { "job_id": "jobs-get-result-categories", "category_id": 3, "partition_field_name": "p", "partition_field_value": "v1" }
+  - do:
+      headers:
+        Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
+        Content-Type: application/json
+      index:
+        index:  .ml-anomalies-unrelated
+        id:     unrelated-1
         body:   { "job_id": "unrelated", "category_id": 1 }
 
   - do:
@@ -54,11 +62,19 @@ setup:
       ml.get_categories:
         job_id: "jobs-get-result-categories"
 
-  - match: { count: 2 }
+  - match: { count: 3 }
   - match: { categories.0.job_id: jobs-get-result-categories }
   - match: { categories.0.category_id: 1 }
+  - match: { categories.0.partition_field_name: p }
+  - match: { categories.0.partition_field_value: v1 }
   - match: { categories.1.job_id: jobs-get-result-categories }
   - match: { categories.1.category_id: 2 }
+  - match: { categories.1.partition_field_name: p }
+  - match: { categories.1.partition_field_value: v2 }
+  - match: { categories.2.job_id: jobs-get-result-categories }
+  - match: { categories.2.category_id: 3 }
+  - match: { categories.2.partition_field_name: p }
+  - match: { categories.2.partition_field_value: v1 }
 
 ---
 "Test get categories with pagination":
@@ -74,12 +90,12 @@ setup:
   - do:
      ml.get_categories:
         job_id: "jobs-get-result-categories"
-        from: 1
+        from: 2
         size: 2
 
   - length: { categories: 1 }
   - match: { categories.0.job_id: jobs-get-result-categories }
-  - match: { categories.0.category_id: 2 }
+  - match: { categories.0.category_id: 3 }
 
 ---
 "Test post get categories with pagination":
@@ -114,8 +130,38 @@ setup:
         job_id: "jobs-get-result-categories"
         category_id: "1"
 
+  - length: { categories: 1 }
+  - match: { categories.0.job_id: jobs-get-result-categories }
+  - match: { categories.0.category_id: 1 }
+
+---
+"Test get category by partition":
+  - do:
+      ml.get_categories:
+        job_id: "jobs-get-result-categories"
+        partition_field_value: "v1"
+
+  - length: { categories: 2 }
   - match: { categories.0.job_id: jobs-get-result-categories }
   - match: { categories.0.category_id: 1 }
+  - match: { categories.1.job_id: jobs-get-result-categories }
+  - match: { categories.1.category_id: 3 }
+
+  - do:
+      ml.get_categories:
+        job_id: "jobs-get-result-categories"
+        partition_field_value: "v2"
+
+  - length: { categories: 1 }
+  - match: { categories.0.job_id: jobs-get-result-categories }
+  - match: { categories.0.category_id: 2 }
+
+  - do:
+      ml.get_categories:
+        job_id: "jobs-get-result-categories"
+        partition_field_value: "v3"
+
+  - length: { categories: 0 }
 
 ---
 "Test with invalid param combinations":