|
@@ -5,109 +5,179 @@
|
|
|
*/
|
|
|
package org.elasticsearch.xpack.core.ml.dataframe.evaluation;
|
|
|
|
|
|
+import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
|
|
import org.elasticsearch.plugins.spi.NamedXContentProvider;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Classification;
|
|
|
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
|
|
|
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix;
|
|
|
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression;
|
|
|
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
|
|
|
-import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
|
|
|
|
|
|
-import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.List;
|
|
|
|
|
|
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
|
|
|
|
|
|
- @Override
|
|
|
- public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
|
|
- List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
|
|
|
+ /**
|
|
|
+ * Constructs the name under which a metric (or metric result) is registered.
|
|
|
+ * The name is prefixed with evaluation name so that registered names are unique.
|
|
|
+ *
|
|
|
+ * @param evaluationName name of the evaluation
|
|
|
+ * @param metricName name of the metric
|
|
|
+ * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
|
|
|
+ */
|
|
|
+ public static String registeredMetricName(ParseField evaluationName, ParseField metricName) {
|
|
|
+ return registeredMetricName(evaluationName.getPreferredName(), metricName.getPreferredName());
|
|
|
+ }
|
|
|
|
|
|
- // Evaluations
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME,
|
|
|
- BinarySoftClassification::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent));
|
|
|
+ /**
|
|
|
+ * Constructs the name under which a metric (or metric result) is registered.
|
|
|
+ * The name is prefixed with evaluation name so that registered names are unique.
|
|
|
+ *
|
|
|
+ * @param evaluationName name of the evaluation
|
|
|
+ * @param metricName name of the metric
|
|
|
+ * @return name appropriate for registering a metric (or metric result) in {@link NamedXContentRegistry}
|
|
|
+ */
|
|
|
+ public static String registeredMetricName(String evaluationName, String metricName) {
|
|
|
+ return evaluationName + "." + metricName;
|
|
|
+ }
|
|
|
|
|
|
- // Soft classification metrics
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME,
|
|
|
- ConfusionMatrix::fromXContent));
|
|
|
+ @Override
|
|
|
+ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
|
|
|
+ return Arrays.asList(
|
|
|
+ // Evaluations
|
|
|
+ new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, BinarySoftClassification::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(Evaluation.class, Classification.NAME, Classification::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(Evaluation.class, Regression.NAME, Regression::fromXContent),
|
|
|
|
|
|
- // Classification metrics
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, MulticlassConfusionMatrix.NAME,
|
|
|
- MulticlassConfusionMatrix::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(ClassificationMetric.class, Accuracy.NAME, Accuracy::fromXContent));
|
|
|
+ // Soft classification metrics
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME)),
|
|
|
+ AucRoc::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(BinarySoftClassification.NAME, Precision.NAME)),
|
|
|
+ Precision::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(BinarySoftClassification.NAME, Recall.NAME)),
|
|
|
+ Recall::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME)),
|
|
|
+ ConfusionMatrix::fromXContent),
|
|
|
|
|
|
- // Regression metrics
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent));
|
|
|
- namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, RSquared::fromXContent));
|
|
|
+ // Classification metrics
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME)),
|
|
|
+ MulticlassConfusionMatrix::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(Classification.NAME, Accuracy.NAME)),
|
|
|
+ Accuracy::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME)),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME)),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::fromXContent),
|
|
|
|
|
|
- return namedXContent;
|
|
|
+ // Regression metrics
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(Regression.NAME, MeanSquaredError.NAME)),
|
|
|
+ MeanSquaredError::fromXContent),
|
|
|
+ new NamedXContentRegistry.Entry(EvaluationMetric.class,
|
|
|
+ new ParseField(registeredMetricName(Regression.NAME, RSquared.NAME)),
|
|
|
+ RSquared::fromXContent)
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
- public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
|
|
- List<NamedWriteableRegistry.Entry> namedWriteables = new ArrayList<>();
|
|
|
-
|
|
|
- // Evaluations
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(),
|
|
|
- BinarySoftClassification::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Classification.NAME.getPreferredName(),
|
|
|
- Classification::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, Regression.NAME.getPreferredName(), Regression::new));
|
|
|
-
|
|
|
- // Evaluation Metrics
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(),
|
|
|
- AucRoc::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(),
|
|
|
- Precision::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(),
|
|
|
- Recall::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(),
|
|
|
- ConfusionMatrix::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class,
|
|
|
- MulticlassConfusionMatrix.NAME.getPreferredName(),
|
|
|
- MulticlassConfusionMatrix::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(ClassificationMetric.class, Accuracy.NAME.getPreferredName(), Accuracy::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
|
|
- MeanSquaredError.NAME.getPreferredName(),
|
|
|
- MeanSquaredError::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class,
|
|
|
- RSquared.NAME.getPreferredName(),
|
|
|
- RSquared::new));
|
|
|
+ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
|
|
|
+ return Arrays.asList(
|
|
|
+ // Evaluations
|
|
|
+ new NamedWriteableRegistry.Entry(Evaluation.class,
|
|
|
+ BinarySoftClassification.NAME.getPreferredName(),
|
|
|
+ BinarySoftClassification::new),
|
|
|
+ new NamedWriteableRegistry.Entry(Evaluation.class,
|
|
|
+ Classification.NAME.getPreferredName(),
|
|
|
+ Classification::new),
|
|
|
+ new NamedWriteableRegistry.Entry(Evaluation.class,
|
|
|
+ Regression.NAME.getPreferredName(),
|
|
|
+ Regression::new),
|
|
|
|
|
|
- // Evaluation Metrics Results
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(),
|
|
|
- AucRoc.Result::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME,
|
|
|
- ScoreByThresholdResult::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(),
|
|
|
- ConfusionMatrix.Result::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
- MulticlassConfusionMatrix.NAME.getPreferredName(),
|
|
|
- MulticlassConfusionMatrix.Result::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
- Accuracy.NAME.getPreferredName(),
|
|
|
- Accuracy.Result::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
- MeanSquaredError.NAME.getPreferredName(),
|
|
|
- MeanSquaredError.Result::new));
|
|
|
- namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
- RSquared.NAME.getPreferredName(),
|
|
|
- RSquared.Result::new));
|
|
|
+ // Evaluation metrics
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
|
|
|
+ AucRoc::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, Precision.NAME),
|
|
|
+ Precision::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, Recall.NAME),
|
|
|
+ Recall::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
|
|
|
+ ConfusionMatrix::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
|
|
+ MulticlassConfusionMatrix::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(Classification.NAME, Accuracy.NAME),
|
|
|
+ Accuracy::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
|
|
+ MeanSquaredError::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetric.class,
|
|
|
+ registeredMetricName(Regression.NAME, RSquared.NAME),
|
|
|
+ RSquared::new),
|
|
|
|
|
|
- return namedWriteables;
|
|
|
+ // Evaluation metrics results
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, AucRoc.NAME),
|
|
|
+ AucRoc.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, ScoreByThresholdResult.NAME),
|
|
|
+ ScoreByThresholdResult::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(BinarySoftClassification.NAME, ConfusionMatrix.NAME),
|
|
|
+ ConfusionMatrix.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME),
|
|
|
+ MulticlassConfusionMatrix.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(Classification.NAME, Accuracy.NAME),
|
|
|
+ Accuracy.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.NAME),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Precision.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(
|
|
|
+ Classification.NAME, org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.NAME),
|
|
|
+ org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Recall.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(Regression.NAME, MeanSquaredError.NAME),
|
|
|
+ MeanSquaredError.Result::new),
|
|
|
+ new NamedWriteableRegistry.Entry(EvaluationMetricResult.class,
|
|
|
+ registeredMetricName(Regression.NAME, RSquared.NAME),
|
|
|
+ RSquared.Result::new)
|
|
|
+ );
|
|
|
}
|
|
|
}
|