|
@@ -23,8 +23,11 @@ import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
import org.elasticsearch.client.ml.GetBucketsRequest;
|
|
|
import org.elasticsearch.client.ml.GetBucketsResponse;
|
|
|
+import org.elasticsearch.client.ml.GetRecordsRequest;
|
|
|
+import org.elasticsearch.client.ml.GetRecordsResponse;
|
|
|
import org.elasticsearch.client.ml.PutJobRequest;
|
|
|
import org.elasticsearch.client.ml.job.config.Job;
|
|
|
+import org.elasticsearch.client.ml.job.results.AnomalyRecord;
|
|
|
import org.elasticsearch.client.ml.job.results.Bucket;
|
|
|
import org.elasticsearch.client.ml.job.util.PageParams;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
@@ -34,7 +37,10 @@ import org.junit.Before;
|
|
|
import java.io.IOException;
|
|
|
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.greaterThan;
|
|
|
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
+import static org.hamcrest.Matchers.lessThan;
|
|
|
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
|
|
|
|
|
public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
@@ -47,7 +53,8 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
// 2018-08-01T00:00:00Z
|
|
|
private static final long START_TIME_EPOCH_MS = 1533081600000L;
|
|
|
|
|
|
- private BucketStats bucketStats = new BucketStats();
|
|
|
+ private Stats bucketStats = new Stats();
|
|
|
+ private Stats recordStats = new Stats();
|
|
|
|
|
|
@Before
|
|
|
public void createJobAndIndexResults() throws IOException {
|
|
@@ -68,7 +75,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
|
|
|
// Also index an interim bucket
|
|
|
addBucketIndexRequest(time, true, bulkRequest);
|
|
|
- addRecordIndexRequests(time, true, bulkRequest);
|
|
|
+ addRecordIndexRequest(time, true, bulkRequest);
|
|
|
|
|
|
highLevelClient().bulk(bulkRequest, RequestOptions.DEFAULT);
|
|
|
}
|
|
@@ -91,16 +98,21 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
int recordCount = randomIntBetween(1, 3);
|
|
|
for (int i = 0; i < recordCount; ++i) {
|
|
|
- IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX, DOC);
|
|
|
- double recordScore = randomDoubleBetween(0.0, 100.0, true);
|
|
|
- double p = randomDoubleBetween(0.0, 0.05, false);
|
|
|
- indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"result_type\":\"record\", \"timestamp\": " + timestamp + "," +
|
|
|
- "\"bucket_span\": 3600,\"is_interim\": " + isInterim + ", \"record_score\": " + recordScore + ", \"probability\": "
|
|
|
- + p + "}", XContentType.JSON);
|
|
|
- bulkRequest.add(indexRequest);
|
|
|
+ addRecordIndexRequest(timestamp, isInterim, bulkRequest);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void addRecordIndexRequest(long timestamp, boolean isInterim, BulkRequest bulkRequest) {
|
|
|
+ IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX, DOC);
|
|
|
+ double recordScore = randomDoubleBetween(0.0, 100.0, true);
|
|
|
+ recordStats.report(recordScore);
|
|
|
+ double p = randomDoubleBetween(0.0, 0.05, false);
|
|
|
+ indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"result_type\":\"record\", \"timestamp\": " + timestamp + "," +
|
|
|
+ "\"bucket_span\": 3600,\"is_interim\": " + isInterim + ", \"record_score\": " + recordScore + ", \"probability\": "
|
|
|
+ + p + "}", XContentType.JSON);
|
|
|
+ bulkRequest.add(indexRequest);
|
|
|
+ }
|
|
|
+
|
|
|
@After
|
|
|
public void deleteJob() throws IOException {
|
|
|
new MlRestTestStateCleaner(logger, client()).clearMlMetadata();
|
|
@@ -194,7 +206,73 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static class BucketStats {
|
|
|
+ public void testGetRecords() throws IOException {
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+
|
|
|
+ {
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), greaterThan(0L));
|
|
|
+ assertThat(response.count(), equalTo(recordStats.totalCount()));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+ request.setRecordScore(50.0);
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ long majorAndCriticalCount = recordStats.majorCount + recordStats.criticalCount;
|
|
|
+ assertThat(response.count(), equalTo(majorAndCriticalCount));
|
|
|
+ assertThat(response.records().size(), equalTo((int) Math.min(100, majorAndCriticalCount)));
|
|
|
+ assertThat(response.records().stream().anyMatch(r -> r.getRecordScore() < 50.0), is(false));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+ request.setExcludeInterim(true);
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ assertThat(response.count(), equalTo(recordStats.totalCount() - 1));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ long end = START_TIME_EPOCH_MS + 10 * 3600000;
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+ request.setStart(String.valueOf(START_TIME_EPOCH_MS));
|
|
|
+ request.setEnd(String.valueOf(end));
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ for (AnomalyRecord record : response.records()) {
|
|
|
+ assertThat(record.getTimestamp().getTime(), greaterThanOrEqualTo(START_TIME_EPOCH_MS));
|
|
|
+ assertThat(record.getTimestamp().getTime(), lessThan(end));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+ request.setPageParams(new PageParams(3, 3));
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ assertThat(response.records().size(), equalTo(3));
|
|
|
+ }
|
|
|
+ {
|
|
|
+ GetRecordsRequest request = new GetRecordsRequest(JOB_ID);
|
|
|
+ request.setSort("probability");
|
|
|
+ request.setDescending(true);
|
|
|
+
|
|
|
+ GetRecordsResponse response = execute(request, machineLearningClient::getRecords, machineLearningClient::getRecordsAsync);
|
|
|
+
|
|
|
+ double previousProb = 1.0;
|
|
|
+ for (AnomalyRecord record : response.records()) {
|
|
|
+ assertThat(record.getProbability(), lessThanOrEqualTo(previousProb));
|
|
|
+ previousProb = record.getProbability();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private static class Stats {
|
|
|
// score < 50.0
|
|
|
private long minorCount;
|
|
|
|
|
@@ -204,14 +282,18 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
|
|
|
// score > 75.0
|
|
|
private long criticalCount;
|
|
|
|
|
|
- private void report(double anomalyScore) {
|
|
|
- if (anomalyScore < 50.0) {
|
|
|
+ private void report(double score) {
|
|
|
+ if (score < 50.0) {
|
|
|
minorCount++;
|
|
|
- } else if (anomalyScore < 75.0) {
|
|
|
+ } else if (score < 75.0) {
|
|
|
majorCount++;
|
|
|
} else {
|
|
|
criticalCount++;
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ private long totalCount() {
|
|
|
+ return minorCount + majorCount + criticalCount;
|
|
|
+ }
|
|
|
}
|
|
|
}
|