|
@@ -5,6 +5,7 @@
|
|
|
*/
|
|
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;
|
|
|
|
|
|
+import org.apache.lucene.util.SetOnce;
|
|
|
import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.collect.Tuple;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
@@ -20,7 +21,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
|
|
|
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
|
|
import org.elasticsearch.search.aggregations.Aggregations;
|
|
|
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
|
|
|
-import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
|
|
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
|
|
@@ -39,22 +39,36 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constru
|
|
|
import static org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider.registeredMetricName;
|
|
|
|
|
|
/**
|
|
|
- * {@link Accuracy} is a metric that answers the question:
|
|
|
- * "What fraction of examples have been classified correctly by the classifier?"
|
|
|
+ * {@link Accuracy} is a metric that answers the following two questions:
|
|
|
*
|
|
|
- * equation: accuracy = 1/n * Σ(y == y´)
|
|
|
+ * 1. What is the fraction of documents for which predicted class equals the actual class?
|
|
|
+ *
|
|
|
+ * equation: overall_accuracy = 1/n * Σ(y == y')
|
|
|
+ * where: n = total number of documents
|
|
|
+ * y = document's actual class
|
|
|
+ * y' = document's predicted class
|
|
|
+ *
|
|
|
+ * 2. For any given class X, what is the fraction of documents for which either
|
|
|
+ * a) both actual and predicted class are equal to X (true positives)
|
|
|
+ * or
|
|
|
+ * b) both actual and predicted class are not equal to X (true negatives)
|
|
|
+ *
|
|
|
+ * equation: accuracy(X) = 1/n * (TP(X) + TN(X))
|
|
|
+ * where: X = class being examined
|
|
|
+ * n = total number of documents
|
|
|
+ * TP(X) = number of true positives wrt X
|
|
|
+ * TN(X) = number of true negatives wrt X
|
|
|
*/
|
|
|
public class Accuracy implements EvaluationMetric {
|
|
|
|
|
|
public static final ParseField NAME = new ParseField("accuracy");
|
|
|
|
|
|
+ static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
|
|
+
|
|
|
private static final String PAINLESS_TEMPLATE = "doc[''{0}''].value == doc[''{1}''].value";
|
|
|
- private static final String CLASSES_AGG_NAME = "classification_classes";
|
|
|
- private static final String PER_CLASS_ACCURACY_AGG_NAME = "classification_per_class_accuracy";
|
|
|
- private static final String OVERALL_ACCURACY_AGG_NAME = "classification_overall_accuracy";
|
|
|
|
|
|
- private static String buildScript(Object...args) {
|
|
|
- return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
|
|
|
+ private static Script buildScript(Object...args) {
|
|
|
+ return new Script(new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args));
|
|
|
}
|
|
|
|
|
|
private static final ObjectParser<Accuracy, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Accuracy::new);
|
|
@@ -63,11 +77,20 @@ public class Accuracy implements EvaluationMetric {
|
|
|
return PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- private EvaluationMetricResult result;
|
|
|
+ private static final int MAX_CLASSES_CARDINALITY = 1000;
|
|
|
|
|
|
- public Accuracy() {}
|
|
|
+ private final MulticlassConfusionMatrix matrix;
|
|
|
+ private final SetOnce<String> actualField = new SetOnce<>();
|
|
|
+ private final SetOnce<Double> overallAccuracy = new SetOnce<>();
|
|
|
+ private final SetOnce<Result> result = new SetOnce<>();
|
|
|
|
|
|
- public Accuracy(StreamInput in) throws IOException {}
|
|
|
+ public Accuracy() {
|
|
|
+ this.matrix = new MulticlassConfusionMatrix(MAX_CLASSES_CARDINALITY, NAME.getPreferredName() + "_");
|
|
|
+ }
|
|
|
+
|
|
|
+ public Accuracy(StreamInput in) throws IOException {
|
|
|
+ this.matrix = new MulticlassConfusionMatrix(in);
|
|
|
+ }
|
|
|
|
|
|
@Override
|
|
|
public String getWriteableName() {
|
|
@@ -81,43 +104,79 @@ public class Accuracy implements EvaluationMetric {
|
|
|
|
|
|
@Override
|
|
|
public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String actualField, String predictedField) {
|
|
|
- if (result != null) {
|
|
|
- return Tuple.tuple(List.of(), List.of());
|
|
|
+ // Store the given {@code actualField} so that {@code process} can include it in its error message.
|
|
|
+ this.actualField.trySet(actualField);
|
|
|
+ List<AggregationBuilder> aggs = new ArrayList<>();
|
|
|
+ List<PipelineAggregationBuilder> pipelineAggs = new ArrayList<>();
|
|
|
+ if (overallAccuracy.get() == null) {
|
|
|
+ aggs.add(AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(buildScript(actualField, predictedField)));
|
|
|
+ }
|
|
|
+ if (result.get() == null) {
|
|
|
+ Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> matrixAggs = matrix.aggs(actualField, predictedField);
|
|
|
+ aggs.addAll(matrixAggs.v1());
|
|
|
+ pipelineAggs.addAll(matrixAggs.v2());
|
|
|
}
|
|
|
- Script accuracyScript = new Script(buildScript(actualField, predictedField));
|
|
|
- return Tuple.tuple(
|
|
|
- List.of(
|
|
|
- AggregationBuilders.terms(CLASSES_AGG_NAME)
|
|
|
- .field(actualField)
|
|
|
- .subAggregation(AggregationBuilders.avg(PER_CLASS_ACCURACY_AGG_NAME).script(accuracyScript)),
|
|
|
- AggregationBuilders.avg(OVERALL_ACCURACY_AGG_NAME).script(accuracyScript)),
|
|
|
- List.of());
|
|
|
+ return Tuple.tuple(aggs, pipelineAggs);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public void process(Aggregations aggs) {
|
|
|
- if (result != null) {
|
|
|
- return;
|
|
|
+ if (overallAccuracy.get() == null && aggs.get(OVERALL_ACCURACY_AGG_NAME) instanceof NumericMetricsAggregation.SingleValue) {
|
|
|
+ NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
|
|
+ overallAccuracy.set(overallAccuracyAgg.value());
|
|
|
}
|
|
|
- Terms classesAgg = aggs.get(CLASSES_AGG_NAME);
|
|
|
- NumericMetricsAggregation.SingleValue overallAccuracyAgg = aggs.get(OVERALL_ACCURACY_AGG_NAME);
|
|
|
- List<ActualClass> actualClasses = new ArrayList<>(classesAgg.getBuckets().size());
|
|
|
- for (Terms.Bucket bucket : classesAgg.getBuckets()) {
|
|
|
- String actualClass = bucket.getKeyAsString();
|
|
|
- long actualClassDocCount = bucket.getDocCount();
|
|
|
- NumericMetricsAggregation.SingleValue accuracyAgg = bucket.getAggregations().get(PER_CLASS_ACCURACY_AGG_NAME);
|
|
|
- actualClasses.add(new ActualClass(actualClass, actualClassDocCount, accuracyAgg.value()));
|
|
|
+ matrix.process(aggs);
|
|
|
+ if (result.get() == null && matrix.getResult().isPresent()) {
|
|
|
+ if (matrix.getResult().get().getOtherActualClassCount() > 0) {
|
|
|
+ // This means there were more than {@code MAX_CLASSES_CARDINALITY} buckets.
|
|
|
+ // We cannot calculate per-class accuracy accurately, so we fail.
|
|
|
+ throw ExceptionsHelper.badRequestException(
|
|
|
+ "Cannot calculate per-class accuracy. Cardinality of field [{}] is too high", actualField.get());
|
|
|
+ }
|
|
|
+ result.set(new Result(computePerClassAccuracy(matrix.getResult().get()), overallAccuracy.get()));
|
|
|
}
|
|
|
- result = new Result(actualClasses, overallAccuracyAgg.value());
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public Optional<EvaluationMetricResult> getResult() {
|
|
|
- return Optional.ofNullable(result);
|
|
|
+ public Optional<Result> getResult() {
|
|
|
+ return Optional.ofNullable(result.get());
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Computes the per-class accuracy results based on multiclass confusion matrix's result.
|
|
|
+ * Time complexity of this method is linear in the number of confusion matrix cells, i.e. O(n^2) where n is the matrix dimension.
|
|
|
+ * This method is visible for testing only.
|
|
|
+ */
|
|
|
+ static List<PerClassResult> computePerClassAccuracy(MulticlassConfusionMatrix.Result matrixResult) {
|
|
|
+ assert matrixResult.getOtherActualClassCount() == 0;
|
|
|
+ // Number of actual classes taken into account
|
|
|
+ int n = matrixResult.getConfusionMatrix().size();
|
|
|
+ // Total number of documents taken into account
|
|
|
+ long totalDocCount =
|
|
|
+ matrixResult.getConfusionMatrix().stream().mapToLong(MulticlassConfusionMatrix.ActualClass::getActualClassDocCount).sum();
|
|
|
+ List<PerClassResult> classes = new ArrayList<>(n);
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
+ String className = matrixResult.getConfusionMatrix().get(i).getActualClass();
|
|
|
+ // Start with the assumption that all the docs were predicted correctly.
|
|
|
+ long correctDocCount = totalDocCount;
|
|
|
+ for (int j = 0; j < n; ++j) {
|
|
|
+ if (i != j) {
|
|
|
+ // Subtract errors (false negatives)
|
|
|
+ correctDocCount -= matrixResult.getConfusionMatrix().get(i).getPredictedClasses().get(j).getCount();
|
|
|
+ // Subtract errors (false positives)
|
|
|
+ correctDocCount -= matrixResult.getConfusionMatrix().get(j).getPredictedClasses().get(i).getCount();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Subtract errors (false negatives) for classes other than explicitly listed in confusion matrix
|
|
|
+ correctDocCount -= matrixResult.getConfusionMatrix().get(i).getOtherPredictedClassDocCount();
|
|
|
+ classes.add(new PerClassResult(className, ((double)correctDocCount) / totalDocCount));
|
|
|
+ }
|
|
|
+ return classes;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
+ matrix.writeTo(out);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -131,25 +190,26 @@ public class Accuracy implements EvaluationMetric {
|
|
|
public boolean equals(Object o) {
|
|
|
if (this == o) return true;
|
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
|
- return true;
|
|
|
+ Accuracy that = (Accuracy) o;
|
|
|
+ return Objects.equals(this.matrix, that.matrix);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hashCode(NAME.getPreferredName());
|
|
|
+ return Objects.hash(matrix);
|
|
|
}
|
|
|
|
|
|
public static class Result implements EvaluationMetricResult {
|
|
|
|
|
|
- private static final ParseField ACTUAL_CLASSES = new ParseField("actual_classes");
|
|
|
+ private static final ParseField CLASSES = new ParseField("classes");
|
|
|
private static final ParseField OVERALL_ACCURACY = new ParseField("overall_accuracy");
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
private static final ConstructingObjectParser<Result, Void> PARSER =
|
|
|
- new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<ActualClass>) a[0], (double) a[1]));
|
|
|
+ new ConstructingObjectParser<>("accuracy_result", true, a -> new Result((List<PerClassResult>) a[0], (double) a[1]));
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, ACTUAL_CLASSES);
|
|
|
+ PARSER.declareObjectArray(constructorArg(), PerClassResult.PARSER, CLASSES);
|
|
|
PARSER.declareDouble(constructorArg(), OVERALL_ACCURACY);
|
|
|
}
|
|
|
|
|
@@ -157,18 +217,18 @@ public class Accuracy implements EvaluationMetric {
|
|
|
return PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- /** List of actual classes. */
|
|
|
- private final List<ActualClass> actualClasses;
|
|
|
- /** Fraction of documents predicted correctly. */
|
|
|
+ /** List of per-class results. */
|
|
|
+ private final List<PerClassResult> classes;
|
|
|
+ /** Fraction of documents for which predicted class equals the actual class. */
|
|
|
private final double overallAccuracy;
|
|
|
|
|
|
- public Result(List<ActualClass> actualClasses, double overallAccuracy) {
|
|
|
- this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, ACTUAL_CLASSES));
|
|
|
+ public Result(List<PerClassResult> classes, double overallAccuracy) {
|
|
|
+ this.classes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(classes, CLASSES));
|
|
|
this.overallAccuracy = overallAccuracy;
|
|
|
}
|
|
|
|
|
|
public Result(StreamInput in) throws IOException {
|
|
|
- this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
|
|
|
+ this.classes = Collections.unmodifiableList(in.readList(PerClassResult::new));
|
|
|
this.overallAccuracy = in.readDouble();
|
|
|
}
|
|
|
|
|
@@ -182,8 +242,8 @@ public class Accuracy implements EvaluationMetric {
|
|
|
return NAME.getPreferredName();
|
|
|
}
|
|
|
|
|
|
- public List<ActualClass> getActualClasses() {
|
|
|
- return actualClasses;
|
|
|
+ public List<PerClassResult> getClasses() {
|
|
|
+ return classes;
|
|
|
}
|
|
|
|
|
|
public double getOverallAccuracy() {
|
|
@@ -192,14 +252,14 @@ public class Accuracy implements EvaluationMetric {
|
|
|
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
- out.writeList(actualClasses);
|
|
|
+ out.writeList(classes);
|
|
|
out.writeDouble(overallAccuracy);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
builder.startObject();
|
|
|
- builder.field(ACTUAL_CLASSES.getPreferredName(), actualClasses);
|
|
|
+ builder.field(CLASSES.getPreferredName(), classes);
|
|
|
builder.field(OVERALL_ACCURACY.getPreferredName(), overallAccuracy);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
@@ -210,54 +270,47 @@ public class Accuracy implements EvaluationMetric {
|
|
|
if (this == o) return true;
|
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
|
Result that = (Result) o;
|
|
|
- return Objects.equals(this.actualClasses, that.actualClasses)
|
|
|
+ return Objects.equals(this.classes, that.classes)
|
|
|
&& this.overallAccuracy == that.overallAccuracy;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(actualClasses, overallAccuracy);
|
|
|
+ return Objects.hash(classes, overallAccuracy);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public static class ActualClass implements ToXContentObject, Writeable {
|
|
|
+ public static class PerClassResult implements ToXContentObject, Writeable {
|
|
|
|
|
|
- private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
|
|
|
- private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
|
|
|
+ private static final ParseField CLASS_NAME = new ParseField("class_name");
|
|
|
private static final ParseField ACCURACY = new ParseField("accuracy");
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private static final ConstructingObjectParser<ActualClass, Void> PARSER =
|
|
|
- new ConstructingObjectParser<>("accuracy_actual_class", true, a -> new ActualClass((String) a[0], (long) a[1], (double) a[2]));
|
|
|
+ private static final ConstructingObjectParser<PerClassResult, Void> PARSER =
|
|
|
+ new ConstructingObjectParser<>("accuracy_per_class_result", true, a -> new PerClassResult((String) a[0], (double) a[1]));
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareString(constructorArg(), ACTUAL_CLASS);
|
|
|
- PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
|
|
|
+ PARSER.declareString(constructorArg(), CLASS_NAME);
|
|
|
PARSER.declareDouble(constructorArg(), ACCURACY);
|
|
|
}
|
|
|
|
|
|
- /** Name of the actual class. */
|
|
|
- private final String actualClass;
|
|
|
- /** Number of documents (examples) belonging to the {code actualClass} class. */
|
|
|
- private final long actualClassDocCount;
|
|
|
- /** Fraction of documents belonging to the {code actualClass} class predicted correctly. */
|
|
|
+ /** Name of the class. */
|
|
|
+ private final String className;
|
|
|
+ /** Fraction of documents that are either true positives or true negatives wrt {@code className}. */
|
|
|
private final double accuracy;
|
|
|
|
|
|
- public ActualClass(
|
|
|
- String actualClass, long actualClassDocCount, double accuracy) {
|
|
|
- this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
|
|
|
- this.actualClassDocCount = actualClassDocCount;
|
|
|
+ public PerClassResult(String className, double accuracy) {
|
|
|
+ this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME);
|
|
|
this.accuracy = accuracy;
|
|
|
}
|
|
|
|
|
|
- public ActualClass(StreamInput in) throws IOException {
|
|
|
- this.actualClass = in.readString();
|
|
|
- this.actualClassDocCount = in.readVLong();
|
|
|
+ public PerClassResult(StreamInput in) throws IOException {
|
|
|
+ this.className = in.readString();
|
|
|
this.accuracy = in.readDouble();
|
|
|
}
|
|
|
|
|
|
- public String getActualClass() {
|
|
|
- return actualClass;
|
|
|
+ public String getClassName() {
|
|
|
+ return className;
|
|
|
}
|
|
|
|
|
|
public double getAccuracy() {
|
|
@@ -266,16 +319,14 @@ public class Accuracy implements EvaluationMetric {
|
|
|
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
- out.writeString(actualClass);
|
|
|
- out.writeVLong(actualClassDocCount);
|
|
|
+ out.writeString(className);
|
|
|
out.writeDouble(accuracy);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
builder.startObject();
|
|
|
- builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
|
|
|
- builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
|
|
|
+ builder.field(CLASS_NAME.getPreferredName(), className);
|
|
|
builder.field(ACCURACY.getPreferredName(), accuracy);
|
|
|
builder.endObject();
|
|
|
return builder;
|
|
@@ -285,15 +336,14 @@ public class Accuracy implements EvaluationMetric {
|
|
|
public boolean equals(Object o) {
|
|
|
if (this == o) return true;
|
|
|
if (o == null || getClass() != o.getClass()) return false;
|
|
|
- ActualClass that = (ActualClass) o;
|
|
|
- return Objects.equals(this.actualClass, that.actualClass)
|
|
|
- && this.actualClassDocCount == that.actualClassDocCount
|
|
|
+ PerClassResult that = (PerClassResult) o;
|
|
|
+ return Objects.equals(this.className, that.className)
|
|
|
&& this.accuracy == that.accuracy;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(actualClass, actualClassDocCount, accuracy);
|
|
|
+ return Objects.hash(className, accuracy);
|
|
|
}
|
|
|
}
|
|
|
}
|