|
@@ -138,13 +138,13 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
|
|
|
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
|
|
|
import org.elasticsearch.client.ml.dataframe.QueryConfig;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.AccuracyMetric;
|
|
|
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.AucRocMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
|
|
|
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric;
|
|
|
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric;
|
|
|
+import org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.ConfusionMatrixMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
|
|
|
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric;
|
|
|
-import org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.HuberMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
|
|
|
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredLogarithmicErrorMetric;
|
|
@@ -1744,15 +1744,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
new OutlierDetection(
|
|
|
actualField,
|
|
|
probabilityField,
|
|
|
- PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve()));
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.at(0.4, 0.5, 0.6),
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.at(0.5, 0.7),
|
|
|
+ ConfusionMatrixMetric.at(0.5),
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.withCurve()));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(OutlierDetection.NAME));
|
|
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4));
|
|
|
|
|
|
- PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
|
|
- assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.Result precisionResult =
|
|
|
+ evaluateDataFrameResponse.getMetricByName(
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME);
|
|
|
+ assertThat(
|
|
|
+ precisionResult.getMetricName(),
|
|
|
+ equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.PrecisionMetric.NAME));
|
|
|
// Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9)
|
|
|
assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9));
|
|
|
// Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9)
|
|
@@ -1761,8 +1768,11 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9));
|
|
|
assertNull(precisionResult.getScoreByThreshold("0.1"));
|
|
|
|
|
|
- RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
|
|
- assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.Result recallResult =
|
|
|
+ evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME);
|
|
|
+ assertThat(
|
|
|
+ recallResult.getMetricName(),
|
|
|
+ equalTo(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.RecallMetric.NAME));
|
|
|
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
|
|
assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9));
|
|
|
// Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9)
|
|
@@ -1778,7 +1788,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7
|
|
|
assertNull(confusionMatrixResult.getScoreByThreshold("0.1"));
|
|
|
|
|
|
- AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
|
|
+ // Fully qualified: the imported AucRocMetric is the classification one, but this response holds outlier-detection results.
|
|
|
+ org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.Result aucRocResult =
|
|
|
+ evaluateDataFrameResponse.getMetricByName(org.elasticsearch.client.ml.dataframe.evaluation.outlierdetection.AucRocMetric.NAME);
|
|
|
assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
|
|
assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9));
|
|
|
assertNotNull(aucRocResult.getCurve());
|
|
@@ -1890,24 +1901,40 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
createIndex(indexName, mappingForClassification());
|
|
|
BulkRequest regressionBulk = new BulkRequest()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
|
|
- .add(docForClassification(indexName, "cat", "cat"))
|
|
|
- .add(docForClassification(indexName, "cat", "cat"))
|
|
|
- .add(docForClassification(indexName, "cat", "cat"))
|
|
|
- .add(docForClassification(indexName, "cat", "dog"))
|
|
|
- .add(docForClassification(indexName, "cat", "fish"))
|
|
|
- .add(docForClassification(indexName, "dog", "cat"))
|
|
|
- .add(docForClassification(indexName, "dog", "dog"))
|
|
|
- .add(docForClassification(indexName, "dog", "dog"))
|
|
|
- .add(docForClassification(indexName, "dog", "dog"))
|
|
|
- .add(docForClassification(indexName, "ant", "cat"));
|
|
|
+ .add(docForClassification(indexName, "cat", "cat", 0.9))
|
|
|
+ .add(docForClassification(indexName, "cat", "cat", 0.85))
|
|
|
+ .add(docForClassification(indexName, "cat", "cat", 0.95))
|
|
|
+ .add(docForClassification(indexName, "cat", "dog", 0.4))
|
|
|
+ .add(docForClassification(indexName, "cat", "fish", 0.35))
|
|
|
+ .add(docForClassification(indexName, "dog", "cat", 0.5))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog", 0.4))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog", 0.35))
|
|
|
+ .add(docForClassification(indexName, "dog", "dog", 0.6))
|
|
|
+ .add(docForClassification(indexName, "ant", "cat", 0.1));
|
|
|
highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT);
|
|
|
|
|
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
|
|
|
|
|
+ { // AucRoc
|
|
|
+ EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
+ new EvaluateDataFrameRequest(
|
|
|
+ indexName, null, new Classification(actualClassField, null, topClassesField, AucRocMetric.forClassWithCurve("cat")));
|
|
|
+
|
|
|
+ EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
+ execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
+ assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
|
|
+ assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
+
|
|
|
+ AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME);
|
|
|
+ assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME));
|
|
|
+ assertThat(aucRocResult.getScore(), closeTo(0.99995, 1e-9));
|
|
|
+ assertThat(aucRocResult.getDocCount(), equalTo(5L));
|
|
|
+ assertNotNull(aucRocResult.getCurve());
|
|
|
+ }
|
|
|
{ // Accuracy
|
|
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
new EvaluateDataFrameRequest(
|
|
|
- indexName, null, new Classification(actualClassField, predictedClassField, new AccuracyMetric()));
|
|
|
+ indexName, null, new Classification(actualClassField, predictedClassField, null, new AccuracyMetric()));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
@@ -1931,65 +1958,47 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
{ // Precision
|
|
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
new EvaluateDataFrameRequest(
|
|
|
- indexName,
|
|
|
- null,
|
|
|
- new Classification(
|
|
|
- actualClassField,
|
|
|
- predictedClassField,
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric()));
|
|
|
+ indexName, null, new Classification(actualClassField, predictedClassField, null, new PrecisionMetric()));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
|
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
|
|
|
- org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.Result precisionResult =
|
|
|
- evaluateDataFrameResponse.getMetricByName(
|
|
|
- org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME);
|
|
|
- assertThat(
|
|
|
- precisionResult.getMetricName(),
|
|
|
- equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.NAME));
|
|
|
+ PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME);
|
|
|
+ assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME));
|
|
|
assertThat(
|
|
|
precisionResult.getClasses(),
|
|
|
equalTo(
|
|
|
List.of(
|
|
|
// 3 out of 5 examples labeled as "cat" were classified correctly
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("cat", 0.6),
|
|
|
+ new PrecisionMetric.PerClassResult("cat", 0.6),
|
|
|
// 3 out of 4 examples labeled as "dog" were classified correctly
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.PrecisionMetric.PerClassResult("dog", 0.75))));
|
|
|
+ new PrecisionMetric.PerClassResult("dog", 0.75))));
|
|
|
assertThat(precisionResult.getAvgPrecision(), equalTo(0.675));
|
|
|
}
|
|
|
{ // Recall
|
|
|
EvaluateDataFrameRequest evaluateDataFrameRequest =
|
|
|
new EvaluateDataFrameRequest(
|
|
|
- indexName,
|
|
|
- null,
|
|
|
- new Classification(
|
|
|
- actualClassField,
|
|
|
- predictedClassField,
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric()));
|
|
|
+ indexName, null, new Classification(actualClassField, predictedClassField, null, new RecallMetric()));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
|
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Classification.NAME));
|
|
|
assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1));
|
|
|
|
|
|
- org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.Result recallResult =
|
|
|
- evaluateDataFrameResponse.getMetricByName(
|
|
|
- org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME);
|
|
|
- assertThat(
|
|
|
- recallResult.getMetricName(),
|
|
|
- equalTo(org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.NAME));
|
|
|
+ RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME);
|
|
|
+ assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME));
|
|
|
assertThat(
|
|
|
recallResult.getClasses(),
|
|
|
equalTo(
|
|
|
List.of(
|
|
|
// 3 out of 5 examples labeled as "cat" were classified correctly
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("cat", 0.6),
|
|
|
+ new RecallMetric.PerClassResult("cat", 0.6),
|
|
|
// 3 out of 4 examples labeled as "dog" were classified correctly
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("dog", 0.75),
|
|
|
+ new RecallMetric.PerClassResult("dog", 0.75),
|
|
|
// no examples labeled as "ant" were classified correctly
|
|
|
- new org.elasticsearch.client.ml.dataframe.evaluation.classification.RecallMetric.PerClassResult("ant", 0.0))));
|
|
|
+ new RecallMetric.PerClassResult("ant", 0.0))));
|
|
|
assertThat(recallResult.getAvgRecall(), equalTo(0.45));
|
|
|
}
|
|
|
{ // No size provided for MulticlassConfusionMatrixMetric, default used instead
|
|
@@ -1997,7 +2006,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
new EvaluateDataFrameRequest(
|
|
|
indexName,
|
|
|
null,
|
|
|
- new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric()));
|
|
|
+ new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric()));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
@@ -2042,7 +2051,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
new EvaluateDataFrameRequest(
|
|
|
indexName,
|
|
|
null,
|
|
|
- new Classification(actualClassField, predictedClassField, new MulticlassConfusionMatrixMetric(2)));
|
|
|
+ new Classification(actualClassField, predictedClassField, null, new MulticlassConfusionMatrixMetric(2)));
|
|
|
|
|
|
EvaluateDataFrameResponse evaluateDataFrameResponse =
|
|
|
execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync);
|
|
@@ -2116,6 +2125,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
|
|
|
private static final String actualClassField = "actual_class";
|
|
|
private static final String predictedClassField = "predicted_class";
|
|
|
+ private static final String topClassesField = "top_classes";
|
|
|
|
|
|
private static XContentBuilder mappingForClassification() throws IOException {
|
|
|
return XContentFactory.jsonBuilder().startObject()
|
|
@@ -2126,14 +2136,22 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|
|
.startObject(predictedClassField)
|
|
|
.field("type", "keyword")
|
|
|
.endObject()
|
|
|
+ .startObject(topClassesField)
|
|
|
+ .field("type", "nested")
|
|
|
+ .endObject()
|
|
|
.endObject()
|
|
|
.endObject();
|
|
|
}
|
|
|
|
|
|
- private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass) {
|
|
|
+ private static IndexRequest docForClassification(String indexName, String actualClass, String predictedClass, double p) { // p: probability stored for predictedClass
|
|
|
return new IndexRequest()
|
|
|
.index(indexName)
|
|
|
- .source(XContentType.JSON, actualClassField, actualClass, predictedClassField, predictedClass);
|
|
|
+ .source(XContentType.JSON, // followed by alternating field-name / value arguments
|
|
|
+ actualClassField, actualClass,
|
|
|
+ predictedClassField, predictedClass,
|
|
|
+ topClassesField, List.of( // two entries per document: the predicted class and a placeholder
|
|
|
+ Map.of("class_name", predictedClass, "class_probability", p), // predicted class carries probability p
|
|
|
+ Map.of("class_name", "other", "class_probability", 1 - p))); // literal "other" class carries the complement, 1 - p
|
|
|
}
|
|
|
|
|
|
private static final String actualRegression = "regression_actual";
|