|
@@ -149,6 +149,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
|
|
import org.elasticsearch.common.xcontent.XContentType;
|
|
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
|
|
+import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.search.SearchHit;
|
|
|
import org.junit.After;
|
|
@@ -1427,7 +1428,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
public void testStopDataFrameAnalyticsConfig() throws Exception {
|
|
|
String sourceIndex = "stop-test-source-index";
|
|
|
String destIndex = "stop-test-dest-index";
|
|
|
- createIndex(sourceIndex, mappingForClassification());
|
|
|
+ createIndex(sourceIndex, defaultMappingForTest());
|
|
|
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
|
|
|
|
|
@@ -1525,27 +1526,28 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
assertThat(exception.status().getStatus(), equalTo(404));
|
|
|
}
|
|
|
|
|
|
- public void testEvaluateDataFrame() throws IOException {
|
|
|
+ public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
|
|
|
String indexName = "evaluate-test-index";
|
|
|
createIndex(indexName, mappingForClassification());
|
|
|
BulkRequest bulk = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .add(docForClassification(indexName, false, 0.1)) // #0
|
|
|
- .add(docForClassification(indexName, false, 0.2)) // #1
|
|
|
- .add(docForClassification(indexName, false, 0.3)) // #2
|
|
|
- .add(docForClassification(indexName, false, 0.4)) // #3
|
|
|
- .add(docForClassification(indexName, false, 0.7)) // #4
|
|
|
- .add(docForClassification(indexName, true, 0.2)) // #5
|
|
|
- .add(docForClassification(indexName, true, 0.3)) // #6
|
|
|
- .add(docForClassification(indexName, true, 0.4)) // #7
|
|
|
- .add(docForClassification(indexName, true, 0.8)) // #8
|
|
|
- .add(docForClassification(indexName, true, 0.9)); // #9
|
|
|
+ .add(docForClassification(indexName, "blue", false, 0.1)) // #0
|
|
|
+ .add(docForClassification(indexName, "blue", false, 0.2)) // #1
|
|
|
+ .add(docForClassification(indexName, "blue", false, 0.3)) // #2
|
|
|
+ .add(docForClassification(indexName, "blue", false, 0.4)) // #3
|
|
|
+ .add(docForClassification(indexName, "blue", false, 0.7)) // #4
|
|
|
+ .add(docForClassification(indexName, "blue", true, 0.2)) // #5
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.3)) // #6
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.4)) // #7
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.8)) // #8
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.9)); // #9
|
|
|
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
new EvaluateDataFrameRequest(
|
|
|
indexName,
|
|
|
+ null,
|
|
|
new BinarySoftClassification(
|
|
|
actualField,
|
|
|
probabilityField,
|
|
@@ -1596,7 +1598,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0));
|
|
|
assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0));
|
|
|
assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
|
|
|
+ String indexName = "evaluate-with-query-test-index";
|
|
|
+ createIndex(indexName, mappingForClassification());
|
|
|
+ BulkRequest bulk = new BulkRequest()
|
|
|
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
+ .add(docForClassification(indexName, "blue", true, 1.0)) // #0
|
|
|
+ .add(docForClassification(indexName, "blue", true, 1.0)) // #1
|
|
|
+ .add(docForClassification(indexName, "blue", true, 1.0)) // #2
|
|
|
+ .add(docForClassification(indexName, "blue", true, 1.0)) // #3
|
|
|
+ .add(docForClassification(indexName, "blue", true, 0.0)) // #4
|
|
|
+ .add(docForClassification(indexName, "blue", true, 0.0)) // #5
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.0)) // #6
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.0)) // #7
|
|
|
+ .add(docForClassification(indexName, "green", true, 0.0)) // #8
|
|
|
+ .add(docForClassification(indexName, "green", true, 1.0)); // #9
|
|
|
+ highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
+ new EvaluateDataFrameRequest(
|
|
|
+ indexName,
|
|
|
+ // Request only "blue" subset to be evaluated
|
|
|
+ new QueryConfig(QueryBuilders.termQuery(datasetField, "blue")),
|
|
|
+ new BinarySoftClassification(actualField, probabilityField, ConfusionMatrixMetric.at(0.5)));
|
|
|
+
|
|
|
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
+ execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
+ assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME));
|
|
|
+ assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
+
|
|
|
+ ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME);
|
|
|
+ assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME));
|
|
|
+ ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5");
|
|
|
+ assertThat(confusionMatrix.getTruePositives(), equalTo(4L)); // docs #0, #1, #2 and #3
|
|
|
+ assertThat(confusionMatrix.getFalsePositives(), equalTo(0L));
|
|
|
+ assertThat(confusionMatrix.getTrueNegatives(), equalTo(0L));
|
|
|
+ assertThat(confusionMatrix.getFalseNegatives(), equalTo(2L)); // docs #4 and #5
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testEvaluateDataFrame_Regression() throws IOException {
|
|
|
String regressionIndex = "evaluate-regression-test-index";
|
|
|
createIndex(regressionIndex, mappingForRegression());
|
|
|
BulkRequest regressionBulk = new BulkRequest()
|
|
@@ -1613,10 +1656,14 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
.add(docForRegression(regressionIndex, 0.5, 0.9)); // #9
|
|
|
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
- evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex,
|
|
|
- new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
+ new EvaluateDataFrameRequest(
|
|
|
+ regressionIndex,
|
|
|
+ null,
|
|
|
+ new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric()));
|
|
|
|
|
|
- evaluateDataFrameResponse =
|
|
|
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME));
|
|
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2));
|
|
@@ -1643,12 +1690,16 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
.endObject();
|
|
|
}
|
|
|
|
|
|
+ private static final String datasetField = "dataset";
|
|
|
private static final String actualField = "label";
|
|
|
private static final String probabilityField = "p";
|
|
|
|
|
|
private static XContentBuilder mappingForClassification() throws IOException {
|
|
|
return XContentFactory.jsonBuilder().startObject()
|
|
|
.startObject("properties")
|
|
|
+ .startObject(datasetField)
|
|
|
+ .field("type", "keyword")
|
|
|
+ .endObject()
|
|
|
.startObject(actualField)
|
|
|
.field("type", "keyword")
|
|
|
.endObject()
|
|
@@ -1659,10 +1710,10 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
.endObject();
|
|
|
}
|
|
|
|
|
|
- private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) {
|
|
|
+ private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
|
|
|
return new IndexRequest()
|
|
|
.index(indexName)
|
|
|
- .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p);
|
|
|
+ .source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
|
|
|
}
|
|
|
|
|
|
private static final String actualRegression = "regression_actual";
|
|
@@ -1697,7 +1748,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
BulkRequest bulk1 = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
for (int i = 0; i < 10; ++i) {
|
|
|
- bulk1.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
+ bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
}
|
|
|
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
|
|
|
|
|
@@ -1723,7 +1774,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
BulkRequest bulk2 = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
for (int i = 10; i < 100; ++i) {
|
|
|
- bulk2.add(docForClassification(indexName, randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
+ bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
}
|
|
|
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
|
|
|
|