|
@@ -5,7 +5,6 @@
|
|
|
*/
|
|
|
package org.elasticsearch.xpack.core.ml.inference.results;
|
|
|
|
|
|
-import org.elasticsearch.Version;
|
|
|
import org.elasticsearch.common.ParseField;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
@@ -26,157 +25,101 @@ import java.util.stream.Collectors;
|
|
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
|
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
|
|
|
|
|
-public class FeatureImportance implements Writeable, ToXContentObject {
|
|
|
+public class ClassificationFeatureImportance extends AbstractFeatureImportance {
|
|
|
|
|
|
private final List<ClassImportance> classImportance;
|
|
|
- private final double importance;
|
|
|
private final String featureName;
|
|
|
- static final String IMPORTANCE = "importance";
|
|
|
+
|
|
|
static final String FEATURE_NAME = "feature_name";
|
|
|
static final String CLASSES = "classes";
|
|
|
|
|
|
- public static FeatureImportance forRegression(String featureName, double importance) {
|
|
|
- return new FeatureImportance(featureName, importance, null);
|
|
|
- }
|
|
|
-
|
|
|
- public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
|
|
|
- return new FeatureImportance(featureName,
|
|
|
- importance,
|
|
|
- classImportance);
|
|
|
- }
|
|
|
-
|
|
|
- public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
|
|
|
- return new FeatureImportance(featureName,
|
|
|
- classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
|
|
|
- classImportance);
|
|
|
- }
|
|
|
-
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
|
|
|
- new ConstructingObjectParser<>("feature_importance",
|
|
|
- a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
|
|
|
+ private static final ConstructingObjectParser<ClassificationFeatureImportance, Void> PARSER =
|
|
|
+ new ConstructingObjectParser<>("classification_feature_importance",
|
|
|
+ a -> new ClassificationFeatureImportance((String) a[0], (List<ClassImportance>) a[1])
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
- PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
|
|
|
- PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
|
|
+ PARSER.declareString(constructorArg(), new ParseField(ClassificationFeatureImportance.FEATURE_NAME));
|
|
|
PARSER.declareObjectArray(optionalConstructorArg(),
|
|
|
(p, c) -> ClassImportance.fromXContent(p),
|
|
|
- new ParseField(FeatureImportance.CLASSES));
|
|
|
+ new ParseField(ClassificationFeatureImportance.CLASSES));
|
|
|
}
|
|
|
|
|
|
- public static FeatureImportance fromXContent(XContentParser parser) {
|
|
|
+ public static ClassificationFeatureImportance fromXContent(XContentParser parser) {
|
|
|
return PARSER.apply(parser, null);
|
|
|
}
|
|
|
|
|
|
- FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
|
|
|
+ public ClassificationFeatureImportance(String featureName, List<ClassImportance> classImportance) {
|
|
|
this.featureName = Objects.requireNonNull(featureName);
|
|
|
- this.importance = importance;
|
|
|
- this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
|
|
|
+ this.classImportance = classImportance == null ? Collections.emptyList() : Collections.unmodifiableList(classImportance);
|
|
|
}
|
|
|
|
|
|
- public FeatureImportance(StreamInput in) throws IOException {
|
|
|
+ public ClassificationFeatureImportance(StreamInput in) throws IOException {
|
|
|
this.featureName = in.readString();
|
|
|
- this.importance = in.readDouble();
|
|
|
- if (in.readBoolean()) {
|
|
|
- if (in.getVersion().before(Version.V_7_10_0)) {
|
|
|
- Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
|
|
|
- this.classImportance = ClassImportance.fromMap(classImportance);
|
|
|
- } else {
|
|
|
- this.classImportance = in.readList(ClassImportance::new);
|
|
|
- }
|
|
|
- } else {
|
|
|
- this.classImportance = null;
|
|
|
- }
|
|
|
+ this.classImportance = in.readList(ClassImportance::new);
|
|
|
}
|
|
|
|
|
|
public List<ClassImportance> getClassImportance() {
|
|
|
return classImportance;
|
|
|
}
|
|
|
|
|
|
- public double getImportance() {
|
|
|
- return importance;
|
|
|
- }
|
|
|
-
|
|
|
+ @Override
|
|
|
public String getFeatureName() {
|
|
|
return featureName;
|
|
|
}
|
|
|
|
|
|
+ public double getTotalImportance() {
|
|
|
+ if (classImportance.size() == 2) {
|
|
|
+ // Binary classification. We can return the first class importance here
|
|
|
+ return Math.abs(classImportance.get(0).getImportance());
|
|
|
+ }
|
|
|
+ return classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum();
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public void writeTo(StreamOutput out) throws IOException {
|
|
|
- out.writeString(this.featureName);
|
|
|
- out.writeDouble(this.importance);
|
|
|
- out.writeBoolean(this.classImportance != null);
|
|
|
- if (this.classImportance != null) {
|
|
|
- if (out.getVersion().before(Version.V_7_10_0)) {
|
|
|
- out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
|
|
|
- } else {
|
|
|
- out.writeList(this.classImportance);
|
|
|
- }
|
|
|
- }
|
|
|
+ out.writeString(featureName);
|
|
|
+ out.writeList(classImportance);
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
public Map<String, Object> toMap() {
|
|
|
Map<String, Object> map = new LinkedHashMap<>();
|
|
|
map.put(FEATURE_NAME, featureName);
|
|
|
- map.put(IMPORTANCE, importance);
|
|
|
- if (classImportance != null) {
|
|
|
+ if (classImportance.isEmpty() == false) {
|
|
|
map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
|
|
|
}
|
|
|
return map;
|
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
- builder.startObject();
|
|
|
- builder.field(FEATURE_NAME, featureName);
|
|
|
- builder.field(IMPORTANCE, importance);
|
|
|
- if (classImportance != null && classImportance.isEmpty() == false) {
|
|
|
- builder.field(CLASSES, classImportance);
|
|
|
- }
|
|
|
- builder.endObject();
|
|
|
- return builder;
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public boolean equals(Object object) {
|
|
|
if (object == this) { return true; }
|
|
|
if (object == null || getClass() != object.getClass()) { return false; }
|
|
|
- FeatureImportance that = (FeatureImportance) object;
|
|
|
+ ClassificationFeatureImportance that = (ClassificationFeatureImportance) object;
|
|
|
return Objects.equals(featureName, that.featureName)
|
|
|
- && Objects.equals(importance, that.importance)
|
|
|
&& Objects.equals(classImportance, that.classImportance);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public int hashCode() {
|
|
|
- return Objects.hash(featureName, importance, classImportance);
|
|
|
+ return Objects.hash(featureName, classImportance);
|
|
|
}
|
|
|
|
|
|
public static class ClassImportance implements Writeable, ToXContentObject {
|
|
|
|
|
|
static final String CLASS_NAME = "class_name";
|
|
|
+ static final String IMPORTANCE = "importance";
|
|
|
|
|
|
private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
|
|
|
- new ConstructingObjectParser<>("feature_importance_class_importance",
|
|
|
- a -> new ClassImportance((String) a[0], (Double) a[1])
|
|
|
+ new ConstructingObjectParser<>("classification_feature_importance_class_importance",
|
|
|
+ a -> new ClassImportance(a[0], (Double) a[1])
|
|
|
);
|
|
|
|
|
|
static {
|
|
|
PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
|
|
|
- PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
|
|
|
- }
|
|
|
-
|
|
|
- private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
|
|
|
- return new ClassImportance(entry.getKey(), entry.getValue());
|
|
|
- }
|
|
|
-
|
|
|
- private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
|
|
|
- return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
|
|
|
- }
|
|
|
-
|
|
|
- private static Map<String, Double> toMap(List<ClassImportance> importances) {
|
|
|
- return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
|
|
|
+ PARSER.declareDouble(constructorArg(), new ParseField(IMPORTANCE));
|
|
|
}
|
|
|
|
|
|
public static ClassImportance fromXContent(XContentParser parser) {
|
|
@@ -219,11 +162,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
|
|
|
|
|
|
@Override
|
|
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
- builder.startObject();
|
|
|
- builder.field(CLASS_NAME, className);
|
|
|
- builder.field(IMPORTANCE, importance);
|
|
|
- builder.endObject();
|
|
|
- return builder;
|
|
|
+ return builder.map(toMap());
|
|
|
}
|
|
|
|
|
|
@Override
|