Browse Source

[ML] Pass through the stop-on-warn setting for categorization jobs (#58632)

When per_partition_categorization.stop_on_warn is set for an analysis
config it is now passed through to the autodetect C++ process.

Also adds some end-to-end tests that exercise the functionality
added in elastic/ml-cpp#1356
David Roberts 5 years ago
parent
commit
7df93562f0

+ 7 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/CategorizerStats.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.ObjectParser.ValueType;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.common.time.TimeUtils;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.results.ReservedFieldNames;
 import org.elasticsearch.xpack.core.ml.job.results.Result;
@@ -121,7 +122,12 @@ public class CategorizerStats implements ToXContentObject, Writeable {
     }
 
     public String getId() {
-        return documentIdPrefix(jobId) + logTime.toEpochMilli();
+        StringBuilder idBuilder = new StringBuilder(documentIdPrefix(jobId));
+        idBuilder.append(logTime.toEpochMilli());
+        if (partitionFieldName != null) {
+            idBuilder.append('_').append(MachineLearningField.valuesToId(partitionFieldValue));
+        }
+        return idBuilder.toString();
     }
 
     public static String documentIdPrefix(String jobId) {

+ 254 - 34
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/CategorizationIT.java

@@ -12,6 +12,11 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.xcontent.DeprecationHandler;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.XContentFactory;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@@ -19,24 +24,33 @@ import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
 import org.elasticsearch.xpack.core.ml.job.config.Detector;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
+import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.CategorizationStatus;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.CategorizerState;
+import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.CategorizerStats;
 import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition;
+import org.elasticsearch.xpack.core.ml.job.results.Result;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.junit.After;
 import org.junit.Before;
 
+import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Locale;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
 
 import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.hasKey;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
 
 /**
  * A fast integration test for categorization
@@ -59,23 +73,28 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
         IndexRequest indexRequest = new IndexRequest(DATA_INDEX);
         indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis(),
-                "msg", "Node 1 started");
+            "msg", "Node 1 started",
+            "part", "nodes");
         bulkRequestBuilder.add(indexRequest);
         indexRequest = new IndexRequest(DATA_INDEX);
         indexRequest.source("time", nowMillis - TimeValue.timeValueHours(2).millis() + 1,
-                "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused " +
-                        "by foo exception]");
+            "msg", "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]",
+            "part", "shutdowns");
         bulkRequestBuilder.add(indexRequest);
         indexRequest = new IndexRequest(DATA_INDEX);
         indexRequest.source("time", nowMillis - TimeValue.timeValueHours(1).millis(),
-                "msg", "Node 2 started");
+            "msg", "Node 2 started",
+            "part", "nodes");
         bulkRequestBuilder.add(indexRequest);
         indexRequest = new IndexRequest(DATA_INDEX);
         indexRequest.source("time", nowMillis - TimeValue.timeValueHours(1).millis() + 1,
-                "msg", "Failed to shutdown [error but this time completely different]");
+            "msg", "Failed to shutdown [error but this time completely different]",
+            "part", "shutdowns");
         bulkRequestBuilder.add(indexRequest);
         indexRequest = new IndexRequest(DATA_INDEX);
-        indexRequest.source("time", nowMillis, "msg", "Node 3 started");
+        indexRequest.source("time", nowMillis,
+            "msg", "Node 3 started",
+            "part", "nodes");
         bulkRequestBuilder.add(indexRequest);
 
         BulkResponse bulkResponse = bulkRequestBuilder
@@ -85,14 +104,12 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
     }
 
     @After
-    public void tearDownData() {
+    public void cleanup() {
         cleanUp();
-        client().admin().indices().prepareDelete(DATA_INDEX).get();
-        refresh("*");
     }
 
     public void testBasicCategorization() throws Exception {
-        Job.Builder job = newJobBuilder("categorization", Collections.emptyList());
+        Job.Builder job = newJobBuilder("categorization", Collections.emptyList(), false);
         registerJob(job);
         putJob(job);
         openJob(job.getId());
@@ -107,37 +124,129 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         waitUntilJobIsClosed(job.getId());
 
         List<CategoryDefinition> categories = getCategories(job.getId());
-        assertThat(categories.size(), equalTo(3));
+        assertThat(categories, hasSize(3));
 
         CategoryDefinition category1 = categories.get(0);
         assertThat(category1.getRegex(), equalTo(".*?Node.+?started.*"));
-        assertThat(category1.getExamples(),
-                equalTo(Arrays.asList("Node 1 started", "Node 2 started")));
+        assertThat(category1.getExamples(), equalTo(Arrays.asList("Node 1 started", "Node 2 started")));
 
         CategoryDefinition category2 = categories.get(1);
-        assertThat(category2.getRegex(), equalTo(".*?Failed.+?to.+?shutdown.+?error.+?" +
-                "org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*"));
-        assertThat(category2.getExamples(), equalTo(Collections.singletonList(
-                "Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]")));
+        assertThat(category2.getRegex(),
+            equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*"));
+        assertThat(category2.getExamples(),
+            equalTo(Collections.singletonList("Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]")));
 
         CategoryDefinition category3 = categories.get(2);
-        assertThat(category3.getRegex(), equalTo(".*?Failed.+?to.+?shutdown.+?error.+?but.+?" +
-                "this.+?time.+?completely.+?different.*"));
-        assertThat(category3.getExamples(), equalTo(Collections.singletonList(
-                "Failed to shutdown [error but this time completely different]")));
+        assertThat(category3.getRegex(),
+            equalTo(".*?Failed.+?to.+?shutdown.+?error.+?but.+?this.+?time.+?completely.+?different.*"));
+        assertThat(category3.getExamples(),
+            equalTo(Collections.singletonList("Failed to shutdown [error but this time completely different]")));
+
+        List<CategorizerStats> stats = getCategorizerStats(job.getId());
+        assertThat(stats, hasSize(1));
+        assertThat(stats.get(0).getCategorizationStatus(), equalTo(CategorizationStatus.OK));
+        assertThat(stats.get(0).getCategorizedDocCount(), equalTo(4L));
+        assertThat(stats.get(0).getTotalCategoryCount(), equalTo(3L));
+        assertThat(stats.get(0).getFrequentCategoryCount(), equalTo(1L));
+        assertThat(stats.get(0).getRareCategoryCount(), equalTo(2L));
+        assertThat(stats.get(0).getDeadCategoryCount(), equalTo(0L));
+        assertThat(stats.get(0).getFailedCategoryCount(), equalTo(0L));
+        assertThat(stats.get(0).getPartitionFieldName(), nullValue());
+        assertThat(stats.get(0).getPartitionFieldValue(), nullValue());
 
-        openJob("categorization");
+        openJob(job.getId());
         startDatafeed(datafeedId, 0, nowMillis + 1);
         waitUntilJobIsClosed(job.getId());
 
         categories = getCategories(job.getId());
-        assertThat(categories.size(), equalTo(3));
-        assertThat(categories.get(0).getExamples(),
-                equalTo(Arrays.asList("Node 1 started", "Node 2 started", "Node 3 started")));
+        assertThat(categories, hasSize(3));
+        assertThat(categories.get(0).getExamples(), equalTo(Arrays.asList("Node 1 started", "Node 2 started", "Node 3 started")));
+
+        stats = getCategorizerStats(job.getId());
+        assertThat(stats, hasSize(2));
+    }
+
+    public void testPerPartitionCategorization() throws Exception {
+        Job.Builder job = newJobBuilder("per-partition-categorization", Collections.emptyList(), true);
+        registerJob(job);
+        putJob(job);
+        openJob(job.getId());
+
+        String datafeedId = job.getId() + "-feed";
+        DatafeedConfig.Builder datafeedConfig = new DatafeedConfig.Builder(datafeedId, job.getId());
+        datafeedConfig.setIndices(Collections.singletonList(DATA_INDEX));
+        DatafeedConfig datafeed = datafeedConfig.build();
+        registerDatafeed(datafeed);
+        putDatafeed(datafeed);
+        startDatafeed(datafeedId, 0, nowMillis);
+        waitUntilJobIsClosed(job.getId());
+
+        List<CategoryDefinition> categories = getCategories(job.getId());
+        assertThat(categories, hasSize(3));
+
+        CategoryDefinition category1 = categories.get(0);
+        assertThat(category1.getRegex(), equalTo(".*?Node.+?started.*"));
+        assertThat(category1.getExamples(), equalTo(Arrays.asList("Node 1 started", "Node 2 started")));
+        assertThat(category1.getPartitionFieldName(), equalTo("part"));
+        assertThat(category1.getPartitionFieldValue(), equalTo("nodes"));
+
+        CategoryDefinition category2 = categories.get(1);
+        assertThat(category2.getRegex(),
+            equalTo(".*?Failed.+?to.+?shutdown.+?error.+?org\\.aaaa\\.bbbb\\.Cccc.+?line.+?caused.+?by.+?foo.+?exception.*"));
+        assertThat(category2.getExamples(),
+            equalTo(Collections.singletonList("Failed to shutdown [error org.aaaa.bbbb.Cccc line 54 caused by foo exception]")));
+        assertThat(category2.getPartitionFieldName(), equalTo("part"));
+        assertThat(category2.getPartitionFieldValue(), equalTo("shutdowns"));
+
+        CategoryDefinition category3 = categories.get(2);
+        assertThat(category3.getRegex(),
+            equalTo(".*?Failed.+?to.+?shutdown.+?error.+?but.+?this.+?time.+?completely.+?different.*"));
+        assertThat(category3.getExamples(),
+            equalTo(Collections.singletonList("Failed to shutdown [error but this time completely different]")));
+        assertThat(category3.getPartitionFieldName(), equalTo("part"));
+        assertThat(category3.getPartitionFieldValue(), equalTo("shutdowns"));
+
+        List<CategorizerStats> stats = getCategorizerStats(job.getId());
+        assertThat(stats, hasSize(2));
+        for (int i = 0; i < 2; ++i) {
+            if ("nodes".equals(stats.get(i).getPartitionFieldValue())) {
+                assertThat(stats.get(i).getCategorizationStatus(), equalTo(CategorizationStatus.OK));
+                assertThat(stats.get(i).getCategorizedDocCount(), equalTo(2L));
+                assertThat(stats.get(i).getTotalCategoryCount(), equalTo(1L));
+                assertThat(stats.get(i).getFrequentCategoryCount(), equalTo(1L));
+                assertThat(stats.get(i).getRareCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getDeadCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getFailedCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getPartitionFieldName(), equalTo("part"));
+            } else {
+                assertThat(stats.get(i).getCategorizationStatus(), equalTo(CategorizationStatus.OK));
+                assertThat(stats.get(i).getCategorizedDocCount(), equalTo(2L));
+                assertThat(stats.get(i).getTotalCategoryCount(), equalTo(2L));
+                assertThat(stats.get(i).getFrequentCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getRareCategoryCount(), equalTo(2L));
+                assertThat(stats.get(i).getDeadCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getFailedCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getPartitionFieldName(), equalTo("part"));
+                assertThat(stats.get(i).getPartitionFieldValue(), equalTo("shutdowns"));
+            }
+        }
+
+        openJob(job.getId());
+        startDatafeed(datafeedId, 0, nowMillis + 1);
+        waitUntilJobIsClosed(job.getId());
+
+        categories = getCategories(job.getId());
+        assertThat(categories, hasSize(3));
+        assertThat(categories.get(0).getExamples(), equalTo(Arrays.asList("Node 1 started", "Node 2 started", "Node 3 started")));
+        assertThat(categories.get(0).getPartitionFieldName(), equalTo("part"));
+        assertThat(categories.get(0).getPartitionFieldValue(), equalTo("nodes"));
+
+        stats = getCategorizerStats(job.getId());
+        assertThat(stats, hasSize(3));
     }
 
     public void testCategorizationWithFilters() throws Exception {
-        Job.Builder job = newJobBuilder("categorization-with-filters", Collections.singletonList("\\[.*\\]"));
+        Job.Builder job = newJobBuilder("categorization-with-filters", Collections.singletonList("\\[.*\\]"), false);
         registerJob(job);
         putJob(job);
         openJob(job.getId());
@@ -152,7 +261,7 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         waitUntilJobIsClosed(job.getId());
 
         List<CategoryDefinition> categories = getCategories(job.getId());
-        assertThat(categories.size(), equalTo(2));
+        assertThat(categories, hasSize(2));
 
         CategoryDefinition category1 = categories.get(0);
         assertThat(category1.getRegex(), equalTo(".*?Node.+?started.*"));
@@ -167,7 +276,7 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
     }
 
     public void testCategorizationStatePersistedOnSwitchToRealtime() throws Exception {
-        Job.Builder job = newJobBuilder("categorization-swtich-to-realtime", Collections.emptyList());
+        Job.Builder job = newJobBuilder("categorization-swtich-to-realtime", Collections.emptyList(), false);
         registerJob(job);
         putJob(job);
         openJob(job.getId());
@@ -198,7 +307,7 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         closeJob(job.getId());
 
         List<CategoryDefinition> categories = getCategories(job.getId());
-        assertThat(categories.size(), equalTo(3));
+        assertThat(categories, hasSize(3));
 
         CategoryDefinition category1 = categories.get(0);
         assertThat(category1.getRegex(), equalTo(".*?Node.+?started.*"));
@@ -244,7 +353,7 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         };
 
         String jobId = "categorization-performance";
-        Job.Builder job = newJobBuilder(jobId, Collections.emptyList());
+        Job.Builder job = newJobBuilder(jobId, Collections.emptyList(), false);
         registerJob(job);
         putJob(job);
         openJob(job.getId());
@@ -266,6 +375,95 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
                 (MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA ? "Java" : "C++") + " took " + duration + "ms");
     }
 
+    public void testStopOnWarn() throws IOException {
+
+        long testTime = System.currentTimeMillis();
+
+        String jobId = "categorization-stop-on-warn";
+        Job.Builder job = newJobBuilder(jobId, Collections.emptyList(), true);
+        registerJob(job);
+        putJob(job);
+        openJob(job.getId());
+
+        String[] messages = new String[] { "Node 1 started", "Failed to shutdown" };
+        String[] partitions = new String[] { "nodes", "shutdowns" };
+
+        StringBuilder json = new StringBuilder(1000);
+        for (int docNum = 0; docNum < 200; ++docNum) {
+            // Two thirds of our messages are "Node 1 started", the rest "Failed to shutdown"
+            int partitionNum = (docNum % 3) / 2;
+            json.append(String.format(Locale.ROOT, "{\"time\":1000000,\"part\":\"%s\",\"msg\":\"%s\"}\n",
+                partitions[partitionNum], messages[partitionNum]));
+        }
+        postData(jobId, json.toString());
+
+        flushJob(jobId, false);
+
+        Consumer<CategorizerStats> checkStatsAt1000000 = stats -> {
+            assertThat(stats.getTimestamp().toEpochMilli(), is(1000000L));
+            if ("nodes".equals(stats.getPartitionFieldValue())) {
+                assertThat(stats.getCategorizationStatus(), equalTo(CategorizationStatus.WARN));
+                // We've sent 134 messages but only 100 should have been categorized as the partition went to "warn" status after 100
+                assertThat(stats.getCategorizedDocCount(), equalTo(100L));
+                assertThat(stats.getTotalCategoryCount(), equalTo(1L));
+                assertThat(stats.getFrequentCategoryCount(), equalTo(1L));
+                assertThat(stats.getRareCategoryCount(), equalTo(0L));
+                assertThat(stats.getDeadCategoryCount(), equalTo(0L));
+                assertThat(stats.getFailedCategoryCount(), equalTo(0L));
+                assertThat(stats.getPartitionFieldName(), equalTo("part"));
+            } else {
+                assertThat(stats.getCategorizationStatus(), equalTo(CategorizationStatus.OK));
+                assertThat(stats.getCategorizedDocCount(), equalTo(66L));
+                assertThat(stats.getTotalCategoryCount(), equalTo(1L));
+                assertThat(stats.getFrequentCategoryCount(), equalTo(1L));
+                assertThat(stats.getRareCategoryCount(), equalTo(0L));
+                assertThat(stats.getDeadCategoryCount(), equalTo(0L));
+                assertThat(stats.getFailedCategoryCount(), equalTo(0L));
+                assertThat(stats.getPartitionFieldName(), equalTo("part"));
+                assertThat(stats.getPartitionFieldValue(), equalTo("shutdowns"));
+            }
+            assertThat(stats.getLogTime().toEpochMilli(), greaterThanOrEqualTo(testTime));
+        };
+
+        List<CategorizerStats> stats = getCategorizerStats(jobId);
+        assertThat(stats, hasSize(2));
+        for (int i = 0; i < 2; ++i) {
+            checkStatsAt1000000.accept(stats.get(i));
+        }
+
+        postData(jobId, json.toString().replace("1000000", "2000000"));
+        closeJob(jobId);
+
+        stats = getCategorizerStats(jobId);
+        assertThat(stats, hasSize(3));
+        int numStatsAt2000000 = 0;
+        for (int i = 0; i < 3; ++i) {
+            if (stats.get(i).getTimestamp().toEpochMilli() == 2000000L) {
+                ++numStatsAt2000000;
+                // Now the "shutdowns" partition has seen more than 100 messages and only has 1 category so that should be in "warn" status
+                assertThat(stats.get(i).getCategorizationStatus(), equalTo(CategorizationStatus.WARN));
+                // We've sent 132 messages but only 100 should have been categorized as the partition went to "warn" status after 100
+                assertThat(stats.get(i).getCategorizedDocCount(), equalTo(100L));
+                assertThat(stats.get(i).getTotalCategoryCount(), equalTo(1L));
+                assertThat(stats.get(i).getFrequentCategoryCount(), equalTo(1L));
+                assertThat(stats.get(i).getRareCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getDeadCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getFailedCategoryCount(), equalTo(0L));
+                assertThat(stats.get(i).getPartitionFieldName(), equalTo("part"));
+                assertThat(stats.get(i).getPartitionFieldValue(), equalTo("shutdowns"));
+                assertThat(stats.get(i).getLogTime().toEpochMilli(), greaterThanOrEqualTo(testTime));
+            } else {
+                // The other stats documents are left over from the flush and should not have been updated,
+                // so should be identical to how they were then
+                checkStatsAt1000000.accept(stats.get(i));
+            }
+        }
+
+        // The stats for the "nodes" partition haven't changed, so should not have been updated; all messages
+        // sent at time 2000000 for the "nodes" partition should have been ignored as it had "warn" status
+        assertThat(numStatsAt2000000, is(1));
+    }
+
     public void testNumMatchesAndCategoryPreference() throws Exception {
         String index = "hadoop_logs";
         client().admin().indices().prepareCreate(index)
@@ -326,7 +524,7 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
             .get();
         assertThat(bulkResponse.hasFailures(), is(false));
 
-        Job.Builder job = newJobBuilder("categorization-with-preferred-categories", Collections.emptyList());
+        Job.Builder job = newJobBuilder("categorization-with-preferred-categories", Collections.emptyList(), false);
         registerJob(job);
         putJob(job);
         openJob(job.getId());
@@ -350,15 +548,18 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         client().admin().indices().prepareDelete(index).get();
     }
 
-    private static Job.Builder newJobBuilder(String id, List<String> categorizationFilters) {
+    private static Job.Builder newJobBuilder(String id, List<String> categorizationFilters, boolean isPerPartition) {
         Detector.Builder detector = new Detector.Builder();
         detector.setFunction("count");
         detector.setByFieldName("mlcategory");
-        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(
-                Collections.singletonList(detector.build()));
+        if (isPerPartition) {
+            detector.setPartitionFieldName("part");
+        }
+        AnalysisConfig.Builder analysisConfig = new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
         analysisConfig.setBucketSpan(TimeValue.timeValueHours(1));
         analysisConfig.setCategorizationFieldName("msg");
         analysisConfig.setCategorizationFilters(categorizationFilters);
+        analysisConfig.setPerPartitionCategorizationConfig(new PerPartitionCategorizationConfig(isPerPartition, isPerPartition));
         DataDescription.Builder dataDescription = new DataDescription.Builder();
         dataDescription.setTimeField("time");
         Job.Builder jobBuilder = new Job.Builder(id);
@@ -366,4 +567,23 @@ public class CategorizationIT extends MlNativeAutodetectIntegTestCase {
         jobBuilder.setDataDescription(dataDescription);
         return jobBuilder;
     }
+
+    private List<CategorizerStats> getCategorizerStats(String jobId) throws IOException {
+
+        SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobResultsAliasedName(jobId))
+            .setQuery(QueryBuilders.boolQuery()
+                .filter(QueryBuilders.termQuery(Result.RESULT_TYPE.getPreferredName(), CategorizerStats.RESULT_TYPE_VALUE))
+                .filter(QueryBuilders.termQuery(Job.ID.getPreferredName(), jobId)))
+            .setSize(1000)
+            .get();
+
+        List<CategorizerStats> stats = new ArrayList<>();
+        for (SearchHit hit : searchResponse.getHits().getHits()) {
+            try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(
+                NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, hit.getSourceRef().streamInput())) {
+                stats.add(CategorizerStats.LENIENT_PARSER.apply(parser, null).build());
+            }
+        }
+        return stats;
+    }
 }

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectBuilder.java

@@ -73,6 +73,7 @@ public class AutodetectBuilder {
     static final String MAX_QUANTILE_INTERVAL_ARG = "--maxQuantileInterval=";
     static final String SUMMARY_COUNT_FIELD_ARG = "--summarycountfield=";
     static final String TIME_FIELD_ARG = "--timefield=";
+    static final String STOP_CATEGORIZATION_ON_WARN_ARG = "--stopCategorizationOnWarnStatus";
 
     /**
      * Name of the config setting containing the path to the logs directory
@@ -198,6 +199,9 @@ public class AutodetectBuilder {
             if (Boolean.TRUE.equals(analysisConfig.getMultivariateByFields())) {
                 command.add(MULTIVARIATE_BY_FIELDS_ARG);
             }
+            if (Boolean.TRUE.equals(analysisConfig.getPerPartitionCategorizationConfig().isStopOnWarn())) {
+                command.add(STOP_CATEGORIZATION_ON_WARN_ARG);
+            }
         }
 
         // Input is always length encoded

+ 1 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java

@@ -66,8 +66,7 @@ public interface AutodetectProcess extends NativeProcess {
      * @param rules Detector rules
      * @throws IOException If the write fails
      */
-    void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules)
-            throws IOException;
+    void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException;
 
     /**
      * Write message to update the filters

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java

@@ -361,7 +361,7 @@ public class AutodetectProcessManager implements ClusterStateListener {
             GetFiltersAction.Request getFilterRequest = new GetFiltersAction.Request(updateParams.getFilter().getId());
             executeAsyncWithOrigin(client, ML_ORIGIN, GetFiltersAction.INSTANCE, getFilterRequest, ActionListener.wrap(
                 getFilterResponse -> filterListener.onResponse(getFilterResponse.getFilters().results().get(0)),
-                handler::accept
+                handler
             ));
         }
     }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java

@@ -84,7 +84,7 @@ class NativeAutodetectProcess extends AbstractNativeProcess implements Autodetec
     @Override
     public void writeUpdatePerPartitionCategorizationMessage(PerPartitionCategorizationConfig perPartitionCategorizationConfig)
         throws IOException {
-        // TODO: write the control message once it's been implemented on the C++ side
+        newMessageWriter().writeCategorizationStopOnWarnMessage(perPartitionCategorizationConfig.isStopOnWarn());
     }
 
     @Override

+ 11 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AutodetectControlMsgWriter.java

@@ -33,6 +33,11 @@ import java.util.concurrent.atomic.AtomicLong;
  */
 public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {
 
+    /**
+     * This must match the code defined in the api::CFieldDataCategorizer C++ class.
+     */
+    private static final String CATEGORIZATION_STOP_ON_WARN_MESSAGE_CODE = "c";
+
     /**
      * This must match the code defined in the api::CAnomalyJob C++ class.
      */
@@ -41,12 +46,12 @@ public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {
     /**
      * This must match the code defined in the api::CAnomalyJob C++ class.
      */
-    private static final String FORECAST_MESSAGE_CODE = "p";
+    private static final String INTERIM_MESSAGE_CODE = "i";
 
     /**
      * This must match the code defined in the api::CAnomalyJob C++ class.
      */
-    private static final String INTERIM_MESSAGE_CODE = "i";
+    private static final String FORECAST_MESSAGE_CODE = "p";
 
     /**
      * This must match the code defined in the api::CAnomalyJob C++ class.
@@ -190,6 +195,10 @@ public class AutodetectControlMsgWriter extends AbstractControlMsgWriter {
         writeMessage(configWriter.toString());
     }
 
+    public void writeCategorizationStopOnWarnMessage(boolean isStopOnWarn) throws IOException {
+        writeMessage(CATEGORIZATION_STOP_ON_WARN_MESSAGE_CODE + isStopOnWarn);
+    }
+
     public void writeUpdateDetectorRulesMessage(int detectorIndex, List<DetectionRule> rules) throws IOException {
         StringBuilder stringBuilder = new StringBuilder();
         stringBuilder.append(UPDATE_MESSAGE_CODE).append("[detectorRules]\n");

+ 2 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java

@@ -48,7 +48,7 @@ public class ProcessResultsParser<T> {
             if (token != XContentParser.Token.START_ARRAY) {
                 throw new ElasticsearchParseException("unexpected token [" + token + "]");
             }
-            return new ResultIterator(in, parser);
+            return new ResultIterator(parser);
         } catch (IOException e) {
             throw new ElasticsearchParseException(e.getMessage(), e);
         }
@@ -56,12 +56,10 @@ public class ProcessResultsParser<T> {
 
     private class ResultIterator implements Iterator<T> {
 
-        private final InputStream in;
         private final XContentParser parser;
         private XContentParser.Token token;
 
-        private ResultIterator(InputStream in, XContentParser parser) {
-            this.in = in;
+        private ResultIterator(XContentParser parser) {
             this.parser = parser;
             token = parser.currentToken();
         }

+ 15 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectBuilderTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
 import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
 import org.elasticsearch.xpack.core.ml.job.config.Detector;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.core.ml.job.config.PerPartitionCategorizationConfig;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
 import org.junit.Before;
@@ -24,6 +25,7 @@ import java.util.Collections;
 import java.util.List;
 
 import static org.elasticsearch.xpack.core.ml.job.config.JobTests.buildJobBuilder;
+import static org.hamcrest.Matchers.is;
 import static org.mockito.Mockito.mock;
 
 public class AutodetectBuilderTests extends ESTestCase {
@@ -46,15 +48,25 @@ public class AutodetectBuilderTests extends ESTestCase {
     }
 
     public void testBuildAutodetectCommand() {
+        boolean isPerPartitionCategorization = randomBoolean();
+
         Job.Builder job = buildJobBuilder("unit-test-job");
 
         Detector.Builder detectorBuilder = new Detector.Builder("mean", "value");
+        if (isPerPartitionCategorization) {
+            detectorBuilder.setByFieldName("mlcategory");
+        }
         detectorBuilder.setPartitionFieldName("foo");
         AnalysisConfig.Builder acBuilder = new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()));
         acBuilder.setBucketSpan(TimeValue.timeValueSeconds(120));
         acBuilder.setLatency(TimeValue.timeValueSeconds(360));
         acBuilder.setSummaryCountFieldName("summaryField");
         acBuilder.setMultivariateByFields(true);
+        if (isPerPartitionCategorization) {
+            acBuilder.setCategorizationFieldName("bar");
+        }
+        acBuilder.setPerPartitionCategorizationConfig(
+            new PerPartitionCategorizationConfig(isPerPartitionCategorization, isPerPartitionCategorization));
 
         job.setAnalysisConfig(acBuilder);
 
@@ -65,12 +77,12 @@ public class AutodetectBuilderTests extends ESTestCase {
         job.setDataDescription(dd);
 
         List<String> command = autodetectBuilder(job.build()).buildAutodetectCommand();
-        assertEquals(11, command.size());
         assertTrue(command.contains(AutodetectBuilder.AUTODETECT_PATH));
         assertTrue(command.contains(AutodetectBuilder.BUCKET_SPAN_ARG + "120"));
         assertTrue(command.contains(AutodetectBuilder.LATENCY_ARG + "360"));
         assertTrue(command.contains(AutodetectBuilder.SUMMARY_COUNT_FIELD_ARG + "summaryField"));
         assertTrue(command.contains(AutodetectBuilder.MULTIVARIATE_BY_FIELDS_ARG));
+        assertThat(command.contains(AutodetectBuilder.STOP_CATEGORIZATION_ON_WARN_ARG), is(isPerPartitionCategorization));
 
         assertTrue(command.contains(AutodetectBuilder.LENGTH_ENCODED_INPUT_ARG));
         assertTrue(command.contains(AutodetectBuilder.maxAnomalyRecordsArg(settings)));
@@ -82,6 +94,8 @@ public class AutodetectBuilderTests extends ESTestCase {
         assertTrue(command.contains(AutodetectBuilder.PERSIST_INTERVAL_ARG + expectedPersistInterval));
         int expectedMaxQuantileInterval = 21600 + AutodetectBuilder.calculateStaggeringInterval(job.getId());
         assertTrue(command.contains(AutodetectBuilder.MAX_QUANTILE_INTERVAL_ARG + expectedMaxQuantileInterval));
+
+        assertEquals(isPerPartitionCategorization ? 12 : 11, command.size());
     }
 
     public void testBuildAutodetectCommand_defaultTimeField() {