|
@@ -125,7 +125,9 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
|
|
import org.elasticsearch.client.ml.dataframe.OutlierDetection;
|
|
|
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
|
|
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
|
|
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
|
|
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
|
|
@@ -1573,19 +1575,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
|
|
|
public void testEvaluateDataFrame_BinarySoftClassification() throws IOException {
|
|
|
String indexName = "evaluate-test-index";
|
|
|
- createIndex(indexName, mappingForClassification());
|
|
|
+ createIndex(indexName, mappingForSoftClassification());
|
|
|
BulkRequest bulk = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .add(docForClassification(indexName, "blue", false, 0.1)) // #0
|
|
|
- .add(docForClassification(indexName, "blue", false, 0.2)) // #1
|
|
|
- .add(docForClassification(indexName, "blue", false, 0.3)) // #2
|
|
|
- .add(docForClassification(indexName, "blue", false, 0.4)) // #3
|
|
|
- .add(docForClassification(indexName, "blue", false, 0.7)) // #4
|
|
|
- .add(docForClassification(indexName, "blue", true, 0.2)) // #5
|
|
|
- .add(docForClassification(indexName, "green", true, 0.3)) // #6
|
|
|
- .add(docForClassification(indexName, "green", true, 0.4)) // #7
|
|
|
- .add(docForClassification(indexName, "green", true, 0.8)) // #8
|
|
|
- .add(docForClassification(indexName, "green", true, 0.9)); // #9
|
|
|
+ .add(docForSoftClassification(indexName, "blue", false, 0.1)) // #0
|
|
|
+ .add(docForSoftClassification(indexName, "blue", false, 0.2)) // #1
|
|
|
+ .add(docForSoftClassification(indexName, "blue", false, 0.3)) // #2
|
|
|
+ .add(docForSoftClassification(indexName, "blue", false, 0.4)) // #3
|
|
|
+ .add(docForSoftClassification(indexName, "blue", false, 0.7)) // #4
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 0.2)) // #5
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.3)) // #6
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.4)) // #7
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.8)) // #8
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.9)); // #9
|
|
|
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
@@ -1647,19 +1649,19 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
|
|
|
public void testEvaluateDataFrame_BinarySoftClassification_WithQuery() throws IOException {
|
|
|
String indexName = "evaluate-with-query-test-index";
|
|
|
- createIndex(indexName, mappingForClassification());
|
|
|
+ createIndex(indexName, mappingForSoftClassification());
|
|
|
BulkRequest bulk = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .add(docForClassification(indexName, "blue", true, 1.0)) // #0
|
|
|
- .add(docForClassification(indexName, "blue", true, 1.0)) // #1
|
|
|
- .add(docForClassification(indexName, "blue", true, 1.0)) // #2
|
|
|
- .add(docForClassification(indexName, "blue", true, 1.0)) // #3
|
|
|
- .add(docForClassification(indexName, "blue", true, 0.0)) // #4
|
|
|
- .add(docForClassification(indexName, "blue", true, 0.0)) // #5
|
|
|
- .add(docForClassification(indexName, "green", true, 0.0)) // #6
|
|
|
- .add(docForClassification(indexName, "green", true, 0.0)) // #7
|
|
|
- .add(docForClassification(indexName, "green", true, 0.0)) // #8
|
|
|
- .add(docForClassification(indexName, "green", true, 1.0)); // #9
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #0
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #1
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #2
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 1.0)) // #3
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #4
|
|
|
+ .add(docForSoftClassification(indexName, "blue", true, 0.0)) // #5
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.0)) // #6
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.0)) // #7
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 0.0)) // #8
|
|
|
+ .add(docForSoftClassification(indexName, "green", true, 1.0)); // #9
|
|
|
highLevelClient().bulk(bulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
@@ -1722,6 +1724,74 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9));
|
|
|
}
|
|
|
|
|
|
+ public void testEvaluateDataFrame_Classification() throws IOException {
|
|
|
+ String indexName = "evaluate-classification-test-index";
|
|
|
+ createIndex(indexName, mappingForClassification());
|
|
|
+ BulkRequest regressionBulk = new BulkRequest()
|
|
|
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
+ .add(docForClassification(indexName, "cat", "cat"))
|
|
|
+ .add(docForClassification(indexName, "cat", "cat"))
|
|
|
+ .add(docForClassification(indexName, "cat", "cat"))
|
|
|
+ .add(docForClassification(indexName, "cat", "dog"))
|
|
|
+ .add(docForClassification(indexName, "cat", "fish"))
|
|
|
+ .add(docForClassification(indexName, "dog", "cat"))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog"))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog"))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog"))
|
|
|
+ .add(docForClassification(indexName, "horse", "cat"));
|
|
|
+ highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
|
|
+
|
|
|
+ MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
+
|
|
|
+ { // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
|
|
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
+ new EvaluateDataFrameRequest(
|
|
|
+ indexName,
|
|
|
+ null,
|
|
|
+ new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
|
|
|
+
|
|
|
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
+ execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
+ assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
|
|
+ assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
+
|
|
|
+ MulticlassConfusionMatrixMetric.Result mcmResult =
|
|
|
+ evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
|
|
+ assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
|
|
+ assertThat(
|
|
|
+ mcmResult.getConfusionMatrix(),
|
|
|
+ equalTo(
|
|
|
+ Map.of(
|
|
|
+ "cat", Map.of("cat", 3L, "dog", 1L, "horse", 0L, "_other_", 1L),
|
|
|
+ "dog", Map.of("cat", 1L, "dog", 3L, "horse", 0L),
|
|
|
+ "horse", Map.of("cat", 1L, "dog", 0L, "horse", 0L))));
|
|
|
+ assertThat(mcmResult.getOtherClassesCount(), equalTo(0L));
|
|
|
+ }
|
|
|
+ { // Explicit size provided for MulticlassConfusionMatrixMetric metric
|
|
|
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
+ new EvaluateDataFrameRequest(
|
|
|
+ indexName,
|
|
|
+ null,
|
|
|
+ new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
|
|
|
+
|
|
|
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
+ execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
+ assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
|
|
+ assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
+
|
|
|
+ MulticlassConfusionMatrixMetric.Result mcmResult =
|
|
|
+ evaluateDataFrameResponse.getMetricByName(MulticlassConfusionMatrixMetric.NAME);
|
|
|
+ assertThat(mcmResult.getMetricName(), equalTo(MulticlassConfusionMatrixMetric.NAME));
|
|
|
+ assertThat(
|
|
|
+ mcmResult.getConfusionMatrix(),
|
|
|
+ equalTo(
|
|
|
+ Map.of(
|
|
|
+ "cat", Map.of("cat", 3L, "dog", 1L, "_other_", 1L),
|
|
|
+ "dog", Map.of("cat", 1L, "dog", 3L))));
|
|
|
+ assertThat(mcmResult.getOtherClassesCount(), equalTo(1L));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private static XContentBuilder defaultMappingForTest() throws IOException {
|
|
|
return XContentFactory.jsonBuilder().startObject()
|
|
|
.startObject("properties")
|
|
@@ -1739,7 +1809,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
private static final String actualField = "label";
|
|
|
private static final String probabilityField = "p";
|
|
|
|
|
|
- private static XContentBuilder mappingForClassification() throws IOException {
|
|
|
+ private static XContentBuilder mappingForSoftClassification() throws IOException {
|
|
|
return XContentFactory.jsonBuilder().startObject()
|
|
|
.startObject("properties")
|
|
|
.startObject(datasetField)
|
|
@@ -1755,26 +1825,48 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
.endObject();
|
|
|
}
|
|
|
|
|
|
- private static IndexRequest docForClassification(String indexName, String dataset, boolean isTrue, double p) {
|
|
|
+ private static IndexRequest docForSoftClassification(String indexName, String dataset, boolean isTrue, double p) {
|
|
|
return new IndexRequest()
|
|
|
.index(indexName)
|
|
|
.source(XContentType.JSON, datasetField, dataset, actualField, Boolean.toString(isTrue), probabilityField, p);
|
|
|
}
|
|
|
|
|
|
+ private static final String actualClassField = "actual_class";
|
|
|
+ private static final String predictedClassField = "predicted_class";
|
|
|
+
|
|
|
+ private static XContentBuilder mappingForClassification() throws IOException {
|
|
|
+ return XContentFactory.jsonBuilder().startObject()
|
|
|
+ .startObject("properties")
|
|
|
+ .startObject(actualClassField)
|
|
|
+ .field("type", "keyword")
|
|
|
+ .endObject()
|
|
|
+ .startObject(predictedClassField)
|
|
|
+ .field("type", "keyword")
|
|
|
+ .endObject()
|
|
|
+ .endObject()
|
|
|
+ .endObject();
|
|
|
+ }
|
|
|
+
|
|
|
+ private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
|
|
|
+ return new IndexRequest()
|
|
|
+ .index(indexName)
|
|
|
+ .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
|
|
|
+ }
|
|
|
+
|
|
|
private static final String actualRegression = "regression_actual";
|
|
|
private static final String probabilityRegression = "regression_prob";
|
|
|
|
|
|
private static XContentBuilder mappingForRegression() throws IOException {
|
|
|
return XContentFactory.jsonBuilder().startObject()
|
|
|
.startObject("properties")
|
|
|
- .startObject(actualRegression)
|
|
|
- .field("type", "double")
|
|
|
- .endObject()
|
|
|
- .startObject(probabilityRegression)
|
|
|
- .field("type", "double")
|
|
|
- .endObject()
|
|
|
+ .startObject(actualRegression)
|
|
|
+ .field("type", "double")
|
|
|
+ .endObject()
|
|
|
+ .startObject(probabilityRegression)
|
|
|
+ .field("type", "double")
|
|
|
+ .endObject()
|
|
|
.endObject()
|
|
|
- .endObject();
|
|
|
+ .endObject();
|
|
|
}
|
|
|
|
|
|
private static IndexRequest docForRegression(String indexName, double act, double p) {
|
|
@@ -1789,11 +1881,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
|
|
|
public void testEstimateMemoryUsage() throws IOException {
|
|
|
String indexName = "estimate-test-index";
|
|
|
- createIndex(indexName, mappingForClassification());
|
|
|
+ createIndex(indexName, mappingForSoftClassification());
|
|
|
BulkRequest bulk1 = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
for (int i = 0; i < 10; ++i) {
|
|
|
- bulk1.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
+ bulk1.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
}
|
|
|
highLevelClient().bulk(bulk1, RequestOptions.DEFAULT);
|
|
|
|
|
@@ -1819,7 +1911,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
BulkRequest bulk2 = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
for (int i = 10; i < 100; ++i) {
|
|
|
- bulk2.add(docForClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
+ bulk2.add(docForSoftClassification(indexName, randomAlphaOfLength(10), randomBoolean(), randomDoubleBetween(0.0, 1.0, true)));
|
|
|
}
|
|
|
highLevelClient().bulk(bulk2, RequestOptions.DEFAULT);
|
|
|
|