Browse Source

[ML] Modify test case to update running job (#124287) (#124523)

This PR makes a change to the existing Java REST test DetectionRulesIT.testCondition such that it updates detection rules for a running job. Previously it had relied on closing and re-opening the job for the update to take effect.

Relates elastic/ml-cpp#2821
Ed Savage 7 months ago
parent
commit
6fe817f5ce

+ 124 - 5
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java

@@ -11,6 +11,7 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
+import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
 import org.elasticsearch.xpack.core.ml.action.UpdateFilterAction;
 import org.elasticsearch.xpack.core.ml.annotations.Annotation;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@@ -98,7 +99,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
 
         // push the data for the first half buckets
         postData(job.getId(), joinBetween(0, data.size() / 2, data));
-        closeJob(job.getId());
+        flushJob(job.getId(), true);
 
         List<AnomalyRecord> records = getRecords(job.getId());
         // remove records that are not anomalies
@@ -116,18 +117,35 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
             JobUpdate.Builder update = new JobUpdate.Builder(job.getId());
             update.setDetectorUpdates(Arrays.asList(new JobUpdate.DetectorUpdate(0, null, Arrays.asList(newRule))));
             updateJob(job.getId(), update.build());
+            // Wait until the notification that the job was updated is indexed
+            assertBusy(
+                () -> assertResponse(
+                    prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).setSize(1)
+                        .addSort("timestamp", SortOrder.DESC)
+                        .setQuery(
+                            QueryBuilders.boolQuery()
+                                .filter(QueryBuilders.termQuery("job_id", job.getId()))
+                                .filter(QueryBuilders.termQuery("level", "info"))
+                        ),
+                    searchResponse -> {
+                        SearchHit[] hits = searchResponse.getHits().getHits();
+                        assertThat(hits.length, equalTo(1));
+                        assertThat((String) hits[0].getSourceAsMap().get("message"), containsString("Job updated: [detectors]"));
+                    }
+                )
+            );
         }
 
         // push second half
-        openJob(job.getId());
         postData(job.getId(), joinBetween(data.size() / 2, data.size(), data));
-        closeJob(job.getId());
+        flushJob(job.getId(), true);
 
         GetRecordsAction.Request recordsAfterFirstHalf = new GetRecordsAction.Request(job.getId());
         recordsAfterFirstHalf.setStart(String.valueOf(firstRecordTimestamp + 1));
         records = getRecords(recordsAfterFirstHalf);
         assertThat("records were " + records, (int) (records.stream().filter(r -> r.getProbability() < 0.01).count()), equalTo(1));
         assertThat(records.get(0).getByFieldValue(), equalTo("low"));
+        closeJob(job.getId());
     }
 
     public void testScope() throws Exception {
@@ -242,7 +260,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
         closeJob(job.getId());
     }
 
-    public void testScopeAndCondition() throws IOException {
+    public void testScopeAndCondition() throws Exception {
         // We have 2 IPs and they're both safe-listed.
         List<String> ips = Arrays.asList("111.111.111.111", "222.222.222.222");
         MlFilter safeIps = MlFilter.builder("safe_ips").setItems(ips).build();
@@ -298,11 +316,112 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
         }
 
         postData(job.getId(), joinBetween(0, data.size(), data));
-        closeJob(job.getId());
+        flushJob(job.getId(), true);
 
         List<AnomalyRecord> records = getRecords(job.getId());
         assertThat(records.size(), equalTo(1));
         assertThat(records.get(0).getOverFieldValue(), equalTo("222.222.222.222"));
+
+        // Remove "111.111.111.111" from the "safe_ips" filter
+        List<String> addedIps = Arrays.asList();
+        List<String> removedIps = Arrays.asList("111.111.111.111");
+        PutFilterAction.Response updatedFilter = updateMlFilter("safe_ips", addedIps, removedIps);
+        // Wait until the notification that the filter was updated is indexed
+        assertBusy(
+            () -> assertResponse(
+                prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).setSize(1)
+                    .addSort("timestamp", SortOrder.DESC)
+                    .setQuery(
+                        QueryBuilders.boolQuery()
+                            .filter(QueryBuilders.termQuery("job_id", job.getId()))
+                            .filter(QueryBuilders.termQuery("level", "info"))
+                    ),
+                searchResponse -> {
+                    SearchHit[] hits = searchResponse.getHits().getHits();
+                    assertThat(hits.length, equalTo(1));
+                    assertThat(
+                        (String) hits[0].getSourceAsMap().get("message"),
+                        containsString("Filter [safe_ips] has been modified; removed items: ['111.111.111.111']")
+                    );
+                }
+            )
+        );
+        MlFilter updatedSafeIps = MlFilter.builder("safe_ips").setItems(Arrays.asList("222.222.222.222")).build();
+        assertThat(updatedFilter.getFilter(), equalTo(updatedSafeIps));
+
+        data.clear();
+        // Now send anomalous count of 9 for 111.111.111.111
+        for (int i = 0; i < 9; i++) {
+            data.add(createIpRecord(timestamp, "111.111.111.111"));
+        }
+
+        // Some more normal buckets
+        for (int bucket = 0; bucket < 3; bucket++) {
+            for (String ip : ips) {
+                data.add(createIpRecord(timestamp, ip));
+            }
+            timestamp += TimeValue.timeValueHours(1).getMillis();
+        }
+
+        postData(job.getId(), joinBetween(0, data.size(), data));
+        flushJob(job.getId(), true);
+
+        records = getRecords(job.getId());
+        assertThat(records.size(), equalTo(2));
+        assertThat(records.get(0).getOverFieldValue(), equalTo("222.222.222.222"));
+        assertThat(records.get(1).getOverFieldValue(), equalTo("111.111.111.111"));
+
+        {
+            // Update detection rules such that it now applies only to actual values > 10.0
+            DetectionRule newRule = new DetectionRule.Builder(
+                Arrays.asList(new RuleCondition(RuleCondition.AppliesTo.ACTUAL, Operator.GT, 10.0))
+            ).build();
+            JobUpdate.Builder update = new JobUpdate.Builder(job.getId());
+            update.setDetectorUpdates(Arrays.asList(new JobUpdate.DetectorUpdate(0, null, Arrays.asList(newRule))));
+            updateJob(job.getId(), update.build());
+            // Wait until the notification that the job was updated is indexed
+            assertBusy(
+                () -> assertResponse(
+                    prepareSearch(NotificationsIndex.NOTIFICATIONS_INDEX).setSize(1)
+                        .addSort("timestamp", SortOrder.DESC)
+                        .setQuery(
+                            QueryBuilders.boolQuery()
+                                .filter(QueryBuilders.termQuery("job_id", job.getId()))
+                                .filter(QueryBuilders.termQuery("level", "info"))
+                        ),
+                    searchResponse -> {
+                        SearchHit[] hits = searchResponse.getHits().getHits();
+                        assertThat(hits.length, equalTo(1));
+                        assertThat((String) hits[0].getSourceAsMap().get("message"), containsString("Job updated: [detectors]"));
+                    }
+                )
+            );
+        }
+
+        data.clear();
+        // Now send anomalous count of 10 for 222.222.222.222
+        for (int i = 0; i < 10; i++) {
+            data.add(createIpRecord(timestamp, "222.222.222.222"));
+        }
+
+        // Some more normal buckets
+        for (int bucket = 0; bucket < 3; bucket++) {
+            for (String ip : ips) {
+                data.add(createIpRecord(timestamp, ip));
+            }
+            timestamp += TimeValue.timeValueHours(1).getMillis();
+        }
+
+        postData(job.getId(), joinBetween(0, data.size(), data));
+
+        closeJob(job.getId());
+
+        // The anomalous records should not have changed.
+        records = getRecords(job.getId());
+        assertThat(records.size(), equalTo(2));
+        assertThat(records.get(0).getOverFieldValue(), equalTo("222.222.222.222"));
+        assertThat(records.get(1).getOverFieldValue(), equalTo("111.111.111.111"));
+
     }
 
     public void testForceTimeShiftAction() throws Exception {

+ 8 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java

@@ -79,6 +79,7 @@ import org.elasticsearch.xpack.core.ml.action.PutFilterAction;
 import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction;
 import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
 import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
+import org.elasticsearch.xpack.core.ml.action.UpdateFilterAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
 import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
@@ -311,6 +312,13 @@ abstract class MlNativeIntegTestCase extends ESIntegTestCase {
         return client().execute(PutFilterAction.INSTANCE, new PutFilterAction.Request(filter)).actionGet();
     }
 
+    protected PutFilterAction.Response updateMlFilter(String filterId, List<String> addItems, List<String> removeItems) {
+        UpdateFilterAction.Request request = new UpdateFilterAction.Request(filterId);
+        request.setAddItems(addItems);
+        request.setRemoveItems(removeItems);
+        return client().execute(UpdateFilterAction.INSTANCE, request).actionGet();
+    }
+
     protected static List<String> fetchAllAuditMessages(String jobId) throws Exception {
         RefreshRequest refreshRequest = new RefreshRequest(NotificationsIndex.NOTIFICATIONS_INDEX);
         BroadcastResponse refreshResponse = client().execute(RefreshAction.INSTANCE, refreshRequest).actionGet();