|
@@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.index.query.BoolQueryBuilder;
|
|
|
+import org.elasticsearch.index.query.QueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|
|
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
|
@@ -18,10 +19,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
import java.util.Optional;
|
|
|
|
|
|
+import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
|
|
|
+
|
|
|
abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
|
|
|
|
|
|
public static final ParseField AT = new ParseField("at");
|
|
@@ -29,8 +31,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
|
|
protected final double[] thresholds;
|
|
|
private EvaluationMetricResult result;
|
|
|
|
|
|
- protected AbstractConfusionMatrixMetric(double[] thresholds) {
|
|
|
- this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
|
|
|
+ protected AbstractConfusionMatrixMetric(List<Double> at) {
|
|
|
+ this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
|
|
|
if (thresholds.length == 0) {
|
|
|
throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
|
|
|
}
|
|
@@ -60,20 +62,16 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
|
|
|
+ public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
|
|
|
if (result != null) {
|
|
|
return List.of();
|
|
|
}
|
|
|
- List<AggregationBuilder> aggs = new ArrayList<>();
|
|
|
- for (double threshold : thresholds) {
|
|
|
- aggs.addAll(aggsAt(actualField, classInfos, threshold));
|
|
|
- }
|
|
|
- return aggs;
|
|
|
+ return aggsAt(actualField, predictedProbabilityField);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void process(ClassInfo classInfo, Aggregations aggs) {
|
|
|
- result = evaluate(classInfo, aggs);
|
|
|
+ public void process(Aggregations aggs) {
|
|
|
+ result = evaluate(aggs);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -81,40 +79,43 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
|
|
|
return Optional.ofNullable(result);
|
|
|
}
|
|
|
|
|
|
- protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
|
|
|
+ protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);
|
|
|
+
|
|
|
+ protected abstract EvaluationMetricResult evaluate(Aggregations aggs);
|
|
|
|
|
|
- protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
|
|
|
+ enum Condition {
|
|
|
+ TP(true, true),
|
|
|
+ FP(false, true),
|
|
|
+ TN(false, false),
|
|
|
+ FN(true, false);
|
|
|
|
|
|
- protected enum Condition {
|
|
|
- TP, FP, TN, FN;
|
|
|
+ final boolean actual;
|
|
|
+ final boolean predicted;
|
|
|
+
|
|
|
+ Condition(boolean actual, boolean predicted) {
|
|
|
+ this.actual = actual;
|
|
|
+ this.predicted = predicted;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
- protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
|
|
|
- return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
|
|
|
+ protected String aggName(double threshold, Condition condition) {
|
|
|
+ return getName() + "_at_" + threshold + "_" + condition.name();
|
|
|
}
|
|
|
|
|
|
- protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
|
|
|
+ protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
|
|
|
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
|
|
|
- switch (condition) {
|
|
|
- case TP:
|
|
|
- boolQuery.must(classInfo.matchingQuery());
|
|
|
- boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
|
|
|
- break;
|
|
|
- case FP:
|
|
|
- boolQuery.mustNot(classInfo.matchingQuery());
|
|
|
- boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
|
|
|
- break;
|
|
|
- case TN:
|
|
|
- boolQuery.mustNot(classInfo.matchingQuery());
|
|
|
- boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
|
|
|
- break;
|
|
|
- case FN:
|
|
|
- boolQuery.must(classInfo.matchingQuery());
|
|
|
- boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
|
|
|
- break;
|
|
|
- default:
|
|
|
- throw new IllegalArgumentException("Unknown enum value: " + condition);
|
|
|
+ QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField);
|
|
|
+ QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);
|
|
|
+ if (condition.actual) {
|
|
|
+ boolQuery.must(actualIsTrueQuery);
|
|
|
+ } else {
|
|
|
+ boolQuery.mustNot(actualIsTrueQuery);
|
|
|
+ }
|
|
|
+ if (condition.predicted) {
|
|
|
+ boolQuery.must(predictedIsTrueQuery);
|
|
|
+ } else {
|
|
|
+ boolQuery.mustNot(predictedIsTrueQuery);
|
|
|
}
|
|
|
- return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery);
|
|
|
+ return AggregationBuilders.filter(aggName(threshold, condition), boolQuery);
|
|
|
}
|
|
|
}
|