|
@@ -5,7 +5,6 @@
|
|
|
*/
|
|
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
|
|
|
|
|
|
-import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.search.SearchResponse;
|
|
|
import org.elasticsearch.common.Nullable;
|
|
|
import org.elasticsearch.common.ParseField;
|
|
@@ -14,18 +13,14 @@ import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
|
|
import org.elasticsearch.common.xcontent.XContentParser;
|
|
|
-import org.elasticsearch.index.query.BoolQueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|
|
-import org.elasticsearch.search.aggregations.Aggregations;
|
|
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation;
|
|
|
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
|
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.Comparator;
|
|
@@ -87,17 +82,16 @@ public class BinarySoftClassification implements Evaluation {
|
|
|
if (metrics.isEmpty()) {
|
|
|
throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName());
|
|
|
}
|
|
|
- Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName));
|
|
|
+ Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getName));
|
|
|
return metrics;
|
|
|
}
|
|
|
|
|
|
private static List<SoftClassificationMetric> defaultMetrics() {
|
|
|
- List<SoftClassificationMetric> defaultMetrics = new ArrayList<>(4);
|
|
|
- defaultMetrics.add(new AucRoc(false));
|
|
|
- defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75)));
|
|
|
- defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75)));
|
|
|
- defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
|
|
|
- return defaultMetrics;
|
|
|
+ return Arrays.asList(
|
|
|
+ new AucRoc(false),
|
|
|
+ new Precision(Arrays.asList(0.25, 0.5, 0.75)),
|
|
|
+ new Recall(Arrays.asList(0.25, 0.5, 0.75)),
|
|
|
+ new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75)));
|
|
|
}
|
|
|
|
|
|
public BinarySoftClassification(StreamInput in) throws IOException {
|
|
@@ -126,7 +120,7 @@ public class BinarySoftClassification implements Evaluation {
|
|
|
|
|
|
builder.startObject(METRICS.getPreferredName());
|
|
|
for (SoftClassificationMetric metric : metrics) {
|
|
|
- builder.field(metric.getMetricName(), metric);
|
|
|
+ builder.field(metric.getName(), metric);
|
|
|
}
|
|
|
builder.endObject();
|
|
|
|
|
@@ -155,34 +149,34 @@ public class BinarySoftClassification implements Evaluation {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public SearchSourceBuilder buildSearch(QueryBuilder queryBuilder) {
|
|
|
- BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
|
|
- .filter(QueryBuilders.existsQuery(actualField))
|
|
|
- .filter(QueryBuilders.existsQuery(predictedProbabilityField))
|
|
|
- .filter(queryBuilder);
|
|
|
- SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(boolQuery);
|
|
|
+ public List<SoftClassificationMetric> getMetrics() {
|
|
|
+ return metrics;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public SearchSourceBuilder buildSearch(QueryBuilder userProvidedQueryBuilder) {
|
|
|
+ ExceptionsHelper.requireNonNull(userProvidedQueryBuilder, "userProvidedQueryBuilder");
|
|
|
+ SearchSourceBuilder searchSourceBuilder =
|
|
|
+ newSearchSourceBuilder(List.of(actualField, predictedProbabilityField), userProvidedQueryBuilder);
|
|
|
+ BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
|
|
for (SoftClassificationMetric metric : metrics) {
|
|
|
- List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo()));
|
|
|
+ List<AggregationBuilder> aggs = metric.aggs(actualField, Collections.singletonList(binaryClassInfo));
|
|
|
aggs.forEach(searchSourceBuilder::aggregation);
|
|
|
}
|
|
|
return searchSourceBuilder;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
|
|
|
+ public void process(SearchResponse searchResponse) {
|
|
|
+ ExceptionsHelper.requireNonNull(searchResponse, "searchResponse");
|
|
|
if (searchResponse.getHits().getTotalHits().value == 0) {
|
|
|
- listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField,
|
|
|
- predictedProbabilityField));
|
|
|
- return;
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField);
|
|
|
}
|
|
|
-
|
|
|
- List<EvaluationMetricResult> results = new ArrayList<>();
|
|
|
- Aggregations aggs = searchResponse.getAggregations();
|
|
|
BinaryClassInfo binaryClassInfo = new BinaryClassInfo();
|
|
|
for (SoftClassificationMetric metric : metrics) {
|
|
|
- results.add(metric.evaluate(binaryClassInfo, aggs));
|
|
|
+ metric.process(binaryClassInfo, searchResponse.getAggregations());
|
|
|
}
|
|
|
- listener.onResponse(results);
|
|
|
}
|
|
|
|
|
|
private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo {
|