Jelajahi Sumber

[ML] updating feature_importance results mapping (#61104)

This updates the feature_importance results mapping to match the change made in elastic/ml-cpp#1387.
Benjamin Trent 5 tahun lalu
induk
melakukan
69f706634e

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -312,7 +312,7 @@ public class Classification implements DataFrameAnalysis {
     @Override
     public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
         Map<String, Object> additionalProperties = new HashMap<>();
-        additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
+        additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.classificationFeatureImportanceMapping());
         Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
         if ((dependentVariableMapping instanceof Map) == false) {
             return additionalProperties;

+ 32 - 8
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MapUtils.java

@@ -18,22 +18,46 @@ import java.util.Map;
 
 final class MapUtils {
 
-    private static final Map<String, Object> FEATURE_IMPORTANCE_MAPPING;
-    static {
-        Map<String, Object> featureImportanceMappingProperties = new HashMap<>();
+    private static Map<String, Object> createFeatureImportanceMapping(Map<String, Object> featureImportanceMappingProperties){
         featureImportanceMappingProperties.put("feature_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
-        featureImportanceMappingProperties.put("importance",
-            Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
         Map<String, Object> featureImportanceMapping = new HashMap<>();
         // TODO sorted indices don't support nested types
         //featureImportanceMapping.put("dynamic", true);
         //featureImportanceMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
         featureImportanceMapping.put("properties", featureImportanceMappingProperties);
-        FEATURE_IMPORTANCE_MAPPING = Collections.unmodifiableMap(featureImportanceMapping);
+        return featureImportanceMapping;
+    }
+
+    private static final Map<String, Object> CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
+    static {
+        Map<String, Object> classImportancePropertiesMapping = new HashMap<>();
+        // TODO sorted indices don't support nested types
+        //classImportancePropertiesMapping.put("dynamic", true);
+        //classImportancePropertiesMapping.put("type", ObjectMapper.NESTED_CONTENT_TYPE);
+        classImportancePropertiesMapping.put("class_name", Collections.singletonMap("type", KeywordFieldMapper.CONTENT_TYPE));
+        classImportancePropertiesMapping.put("importance",
+            Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
+        Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
+        featureImportancePropertiesMapping.put("classes", Collections.singletonMap("properties", classImportancePropertiesMapping));
+        CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING =
+            Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
+    }
+
+    private static final Map<String, Object> REGRESSION_FEATURE_IMPORTANCE_MAPPING;
+    static {
+        Map<String, Object> featureImportancePropertiesMapping = new HashMap<>();
+        featureImportancePropertiesMapping.put("importance",
+            Collections.singletonMap("type", NumberFieldMapper.NumberType.DOUBLE.typeName()));
+        REGRESSION_FEATURE_IMPORTANCE_MAPPING =
+            Collections.unmodifiableMap(createFeatureImportanceMapping(featureImportancePropertiesMapping));
+    }
+
+    static Map<String, Object> regressionFeatureImportanceMapping() {
+        return REGRESSION_FEATURE_IMPORTANCE_MAPPING;
     }
 
-    static Map<String, Object> featureImportanceMapping() {
-        return FEATURE_IMPORTANCE_MAPPING;
+    static Map<String, Object> classificationFeatureImportanceMapping() {
+        return CLASSIFICATION_FEATURE_IMPORTANCE_MAPPING;
     }
 
     private MapUtils() {}

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -233,7 +233,7 @@ public class Regression implements DataFrameAnalysis {
     @Override
     public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
         Map<String, Object> additionalProperties = new HashMap<>();
-        additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
+        additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.regressionFeatureImportanceMapping());
         // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
         // high (over 10M) values of dependent variable.
         additionalProperties.put(resultsFieldName + "." + predictionFieldName,

+ 117 - 19
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

@@ -5,6 +5,7 @@
  */
 package org.elasticsearch.xpack.core.ml.inference.results;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -16,65 +17,74 @@ import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
 public class FeatureImportance implements Writeable, ToXContentObject {
 
-    private final Map<String, Double> classImportance;
+    private final List<ClassImportance> classImportance;
     private final double importance;
     private final String featureName;
     static final String IMPORTANCE = "importance";
     static final String FEATURE_NAME = "feature_name";
-    static final String CLASS_IMPORTANCE = "class_importance";
+    static final String CLASSES = "classes";
 
     public static FeatureImportance forRegression(String featureName, double importance) {
         return new FeatureImportance(featureName, importance, null);
     }
 
-    public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
-        return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
+    public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
+        return new FeatureImportance(featureName,
+            classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
+            classImportance);
     }
 
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<FeatureImportance, Void> PARSER =
         new ConstructingObjectParser<>("feature_importance",
-            a -> new FeatureImportance((String) a[0], (Double) a[1], (Map<String, Double>) a[2])
+            a -> new FeatureImportance((String) a[0], (Double) a[1], (List<ClassImportance>) a[2])
         );
 
     static {
         PARSER.declareString(constructorArg(), new ParseField(FeatureImportance.FEATURE_NAME));
         PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
-        PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
-            new ParseField(FeatureImportance.CLASS_IMPORTANCE));
+        PARSER.declareObjectArray(optionalConstructorArg(),
+            (p, c) -> ClassImportance.fromXContent(p),
+            new ParseField(FeatureImportance.CLASSES));
     }
 
     public static FeatureImportance fromXContent(XContentParser parser) {
         return PARSER.apply(parser, null);
     }
 
-    FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
+    FeatureImportance(String featureName, double importance, List<ClassImportance> classImportance) {
         this.featureName = Objects.requireNonNull(featureName);
         this.importance = importance;
-        this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
+        this.classImportance = classImportance == null ? null : Collections.unmodifiableList(classImportance);
     }
 
     public FeatureImportance(StreamInput in) throws IOException {
         this.featureName = in.readString();
         this.importance = in.readDouble();
         if (in.readBoolean()) {
-            this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+            if (in.getVersion().before(Version.V_7_10_0)) {
+                Map<String, Double> classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
+                this.classImportance = ClassImportance.fromMap(classImportance);
+            } else {
+                this.classImportance = in.readList(ClassImportance::new);
+            }
         } else {
             this.classImportance = null;
         }
     }
 
-    public Map<String, Double> getClassImportance() {
+    public List<ClassImportance> getClassImportance() {
         return classImportance;
     }
 
@@ -92,7 +102,11 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         out.writeDouble(this.importance);
         out.writeBoolean(this.classImportance != null);
         if (this.classImportance != null) {
-            out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
+            if (out.getVersion().before(Version.V_7_10_0)) {
+                out.writeMap(ClassImportance.toMap(this.classImportance), StreamOutput::writeString, StreamOutput::writeDouble);
+            } else {
+                out.writeList(this.classImportance);
+            }
         }
     }
 
@@ -101,7 +115,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         map.put(FEATURE_NAME, featureName);
         map.put(IMPORTANCE, importance);
         if (classImportance != null) {
-            classImportance.forEach(map::put);
+            map.put(CLASSES, classImportance.stream().map(ClassImportance::toMap).collect(Collectors.toList()));
         }
         return map;
     }
@@ -112,11 +126,7 @@ public class FeatureImportance implements Writeable, ToXContentObject {
         builder.field(FEATURE_NAME, featureName);
         builder.field(IMPORTANCE, importance);
         if (classImportance != null && classImportance.isEmpty() == false) {
-            builder.startObject(CLASS_IMPORTANCE);
-            for (Map.Entry<String, Double> entry : classImportance.entrySet()) {
-                builder.field(entry.getKey(), entry.getValue());
-            }
-            builder.endObject();
+            builder.field(CLASSES, classImportance);
         }
         builder.endObject();
         return builder;
@@ -136,4 +146,92 @@ public class FeatureImportance implements Writeable, ToXContentObject {
     public int hashCode() {
         return Objects.hash(featureName, importance, classImportance);
     }
+
+    public static class ClassImportance implements Writeable, ToXContentObject {
+
+        static final String CLASS_NAME = "class_name";
+
+        private static final ConstructingObjectParser<ClassImportance, Void> PARSER =
+            new ConstructingObjectParser<>("feature_importance_class_importance",
+                a -> new ClassImportance((String) a[0], (Double) a[1])
+            );
+
+        static {
+            PARSER.declareString(constructorArg(), new ParseField(CLASS_NAME));
+            PARSER.declareDouble(constructorArg(), new ParseField(FeatureImportance.IMPORTANCE));
+        }
+
+        private static ClassImportance fromMapEntry(Map.Entry<String, Double> entry) {
+            return new ClassImportance(entry.getKey(), entry.getValue());
+        }
+
+        private static List<ClassImportance> fromMap(Map<String, Double> classImportanceMap) {
+            return classImportanceMap.entrySet().stream().map(ClassImportance::fromMapEntry).collect(Collectors.toList());
+        }
+
+        private static Map<String, Double> toMap(List<ClassImportance> importances) {
+            return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
+        }
+
+        public static ClassImportance fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private final String className;
+        private final double importance;
+
+        public ClassImportance(String className, double importance) {
+            this.className = className;
+            this.importance = importance;
+        }
+
+        public ClassImportance(StreamInput in) throws IOException {
+            this.className = in.readString();
+            this.importance = in.readDouble();
+        }
+
+        public String getClassName() {
+            return className;
+        }
+
+        public double getImportance() {
+            return importance;
+        }
+
+        public Map<String, Object> toMap() {
+            Map<String, Object> map = new LinkedHashMap<>();
+            map.put(CLASS_NAME, className);
+            map.put(IMPORTANCE, importance);
+            return map;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(className);
+            out.writeDouble(importance);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(CLASS_NAME, className);
+            builder.field(IMPORTANCE, importance);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            ClassImportance that = (ClassImportance) o;
+            return Double.compare(that.importance, importance) == 0 &&
+                Objects.equals(className, that.className);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(className, importance);
+        }
+    }
 }

+ 4 - 3
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

@@ -15,7 +15,6 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
-import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -139,11 +138,13 @@ public final class InferenceHelpers {
             if (v.length == 1) {
                 importances.add(FeatureImportance.forRegression(k, v[0]));
             } else {
-                Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
+                List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
                 // If the classificationLabels exist, their length must match leaf_value length
                 assert classificationLabels == null || classificationLabels.size() == v.length;
                 for (int i = 0; i < v.length; i++) {
-                    classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
+                    classImportance.add(new FeatureImportance.ClassImportance(
+                        classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
+                        v[i]));
                 }
                 importances.add(FeatureImportance.forClassification(k, classImportance));
             }

+ 6 - 6
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -259,12 +259,12 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
 
     public void testGetExplicitlyMappedFields() {
         assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"),
-            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
+            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
         assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"),
-            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
+            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
         assertThat(
             new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
-            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
+            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
         Map<String, Object> explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
             Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
             "results");
@@ -272,7 +272,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             allOf(
                 hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
                 hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
-        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
+        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
 
         explicitlyMappedFields = new Classification("foo").getExplicitlyMappedFields(
             new HashMap<>() {{
@@ -287,7 +287,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
             allOf(
                 hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
                 hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
-        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
+        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.classificationFeatureImportanceMapping()));
 
         assertThat(
             new Classification("foo").getExplicitlyMappedFields(
@@ -296,7 +296,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
                     put("path", "missing");
                 }}),
                 "results"),
-            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.featureImportanceMapping())));
+            equalTo(Collections.singletonMap("results.feature_importance", MapUtils.classificationFeatureImportanceMapping())));
     }
 
     public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -206,7 +206,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
     public void testGetExplicitlyMappedFields() {
         Map<String, Object> explicitlyMappedFields = new Regression("foo").getExplicitlyMappedFields(null, "results");
         assertThat(explicitlyMappedFields, hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
-        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.featureImportanceMapping()));
+        assertThat(explicitlyMappedFields, hasEntry("results.feature_importance", MapUtils.regressionFeatureImportanceMapping()));
     }
 
     public void testGetStateDocId() {

+ 9 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java

@@ -152,8 +152,15 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
             FeatureImportance importance = importanceList.get(i);
             assertThat(objectMap.get("feature_name"), equalTo(importance.getFeatureName()));
             assertThat(objectMap.get("importance"), equalTo(importance.getImportance()));
+            @SuppressWarnings("unchecked")
+            List<Map<String, Object>> classImportances = (List<Map<String, Object>>)objectMap.get("classes");
             if (importance.getClassImportance() != null) {
-                importance.getClassImportance().forEach((k, v) -> assertThat(objectMap.get(k), equalTo(v)));
+                for (int j = 0; j < importance.getClassImportance().size(); j++) {
+                    Map<String, Object> classMap = classImportances.get(j);
+                    FeatureImportance.ClassImportance classImportance = importance.getClassImportance().get(j);
+                    assertThat(classMap.get("class_name"), equalTo(classImportance.getClassName()));
+                    assertThat(classMap.get("importance"), equalTo(classImportance.getImportance()));
+                }
             }
         }
     }
@@ -205,7 +212,7 @@ public class ClassificationInferenceResultsTests extends AbstractWireSerializing
         expected = "{\"predicted_value\":\"label1\",\"prediction_probability\":1.0,\"prediction_score\":1.0}";
         assertEquals(expected, stringRep);
 
-        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
+        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
         TopClassEntry tp = new TopClassEntry("class", 1.0, 1.0);
         result = new ClassificationInferenceResults(1.0, "label1", Collections.singletonList(tp),
             Collections.singletonList(fi), config,

+ 2 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportanceTests.java

@@ -10,7 +10,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.test.AbstractSerializingTestCase;
 
 import java.io.IOException;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -29,7 +28,8 @@ public class FeatureImportanceTests extends AbstractSerializingTestCase<FeatureI
             randomAlphaOfLength(10),
             Stream.generate(() -> randomAlphaOfLength(10))
                 .limit(randomLongBetween(2, 10))
-                .collect(Collectors.toMap(Function.identity(), (k) -> randomDoubleBetween(-10, 10, false))));
+                .map(name -> new FeatureImportance.ClassImportance(name, randomDoubleBetween(-10, 10, false)))
+                .collect(Collectors.toList()));
     }
 
     @Override

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java

@@ -92,7 +92,7 @@ public class RegressionInferenceResultsTests extends AbstractWireSerializingTest
         String expected = "{\"" + resultsField + "\":1.0}";
         assertEquals(expected, stringRep);
 
-        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyMap());
+        FeatureImportance fi = new FeatureImportance("foo", 1.0, Collections.emptyList());
         result = new RegressionInferenceResults(1.0, resultsField, Collections.singletonList(fi));
         stringRep = Strings.toString(result);
         expected = "{\"" + resultsField + "\":1.0,\"feature_importance\":[{\"feature_name\":\"foo\",\"importance\":1.0}]}";