|  | @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.io.stream.StreamOutput;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.xcontent.XContentBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.index.query.BoolQueryBuilder;
 | 
	
		
			
				|  |  | +import org.elasticsearch.index.query.QueryBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.index.query.QueryBuilders;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.AggregationBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.AggregationBuilders;
 | 
	
	
		
			
				|  | @@ -18,10 +19,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResu
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.io.IOException;
 | 
	
		
			
				|  |  | -import java.util.ArrayList;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Optional;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric.actualIsTrueQuery;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public static final ParseField AT = new ParseField("at");
 | 
	
	
		
			
				|  | @@ -29,8 +31,8 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
 | 
	
		
			
				|  |  |      protected final double[] thresholds;
 | 
	
		
			
				|  |  |      private EvaluationMetricResult result;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected AbstractConfusionMatrixMetric(double[] thresholds) {
 | 
	
		
			
				|  |  | -        this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT);
 | 
	
		
			
				|  |  | +    protected AbstractConfusionMatrixMetric(List<Double> at) {
 | 
	
		
			
				|  |  | +        this.thresholds = ExceptionsHelper.requireNonNull(at, AT).stream().mapToDouble(Double::doubleValue).toArray();
 | 
	
		
			
				|  |  |          if (thresholds.length == 0) {
 | 
	
		
			
				|  |  |              throw ExceptionsHelper.badRequestException("[" + getName() + "." + AT.getPreferredName() + "] must have at least one value");
 | 
	
		
			
				|  |  |          }
 | 
	
	
		
			
				|  | @@ -60,20 +62,16 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  | -    public final List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos) {
 | 
	
		
			
				|  |  | +    public final List<AggregationBuilder> aggs(String actualField, String predictedProbabilityField) {
 | 
	
		
			
				|  |  |          if (result != null) {
 | 
	
		
			
				|  |  |              return List.of();
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -        List<AggregationBuilder> aggs = new ArrayList<>();
 | 
	
		
			
				|  |  | -        for (double threshold : thresholds) {
 | 
	
		
			
				|  |  | -            aggs.addAll(aggsAt(actualField, classInfos, threshold));
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -        return aggs;
 | 
	
		
			
				|  |  | +        return aggsAt(actualField, predictedProbabilityField);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  | -    public void process(ClassInfo classInfo, Aggregations aggs) {
 | 
	
		
			
				|  |  | -        result = evaluate(classInfo, aggs);
 | 
	
		
			
				|  |  | +    public void process(Aggregations aggs) {
 | 
	
		
			
				|  |  | +        result = evaluate(aggs);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
	
		
			
				|  | @@ -81,40 +79,43 @@ abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric
 | 
	
		
			
				|  |  |          return Optional.ofNullable(result);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected abstract List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold);
 | 
	
		
			
				|  |  | +    protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    protected abstract EvaluationMetricResult evaluate(Aggregations aggs);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected abstract EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
 | 
	
		
			
				|  |  | +    enum Condition {
 | 
	
		
			
				|  |  | +        TP(true, true),
 | 
	
		
			
				|  |  | +        FP(false, true),
 | 
	
		
			
				|  |  | +        TN(false, false),
 | 
	
		
			
				|  |  | +        FN(true, false);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected enum Condition {
 | 
	
		
			
				|  |  | -        TP, FP, TN, FN;
 | 
	
		
			
				|  |  | +        final boolean actual;
 | 
	
		
			
				|  |  | +        final boolean predicted;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Condition(boolean actual, boolean predicted) {
 | 
	
		
			
				|  |  | +            this.actual = actual;
 | 
	
		
			
				|  |  | +            this.predicted = predicted;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected String aggName(ClassInfo classInfo, double threshold, Condition condition) {
 | 
	
		
			
				|  |  | -        return getName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name();
 | 
	
		
			
				|  |  | +    protected String aggName(double threshold, Condition condition) {
 | 
	
		
			
				|  |  | +        return getName() + "_at_" + threshold + "_" + condition.name();
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) {
 | 
	
		
			
				|  |  | +    protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
 | 
	
		
			
				|  |  |          BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
 | 
	
		
			
				|  |  | -        switch (condition) {
 | 
	
		
			
				|  |  | -            case TP:
 | 
	
		
			
				|  |  | -                boolQuery.must(classInfo.matchingQuery());
 | 
	
		
			
				|  |  | -                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
 | 
	
		
			
				|  |  | -                break;
 | 
	
		
			
				|  |  | -            case FP:
 | 
	
		
			
				|  |  | -                boolQuery.mustNot(classInfo.matchingQuery());
 | 
	
		
			
				|  |  | -                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold));
 | 
	
		
			
				|  |  | -                break;
 | 
	
		
			
				|  |  | -            case TN:
 | 
	
		
			
				|  |  | -                boolQuery.mustNot(classInfo.matchingQuery());
 | 
	
		
			
				|  |  | -                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
 | 
	
		
			
				|  |  | -                break;
 | 
	
		
			
				|  |  | -            case FN:
 | 
	
		
			
				|  |  | -                boolQuery.must(classInfo.matchingQuery());
 | 
	
		
			
				|  |  | -                boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold));
 | 
	
		
			
				|  |  | -                break;
 | 
	
		
			
				|  |  | -            default:
 | 
	
		
			
				|  |  | -                throw new IllegalArgumentException("Unknown enum value: " + condition);
 | 
	
		
			
				|  |  | +        QueryBuilder actualIsTrueQuery = actualIsTrueQuery(actualField);
 | 
	
		
			
				|  |  | +        QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);
 | 
	
		
			
				|  |  | +        if (condition.actual) {
 | 
	
		
			
				|  |  | +            boolQuery.must(actualIsTrueQuery);
 | 
	
		
			
				|  |  | +        } else {
 | 
	
		
			
				|  |  | +            boolQuery.mustNot(actualIsTrueQuery);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        if (condition.predicted) {
 | 
	
		
			
				|  |  | +            boolQuery.must(predictedIsTrueQuery);
 | 
	
		
			
				|  |  | +        } else {
 | 
	
		
			
				|  |  | +            boolQuery.mustNot(predictedIsTrueQuery);
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  | -        return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery);
 | 
	
		
			
				|  |  | +        return AggregationBuilders.filter(aggName(threshold, condition), boolQuery);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  }
 |