Browse Source

[ML] adding `baseline` field to total_feature_importance objects (#63098)

This adds a new `baseline` field to the feature importance values. 

This field contains the baseline importance for a given feature and class.
Benjamin Trent 5 years ago
parent
commit
8eb83d3b34

+ 12 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -40,6 +40,7 @@ public class TotalFeatureImportance implements ToXContentObject {
     public static final ParseField IMPORTANCE = new ParseField("importance");
     public static final ParseField CLASSES = new ParseField("classes");
     public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
+    public static final ParseField BASELINE = new ParseField("baseline");
     public static final ParseField MIN = new ParseField("min");
     public static final ParseField MAX = new ParseField("max");
 
@@ -102,22 +103,25 @@ public class TotalFeatureImportance implements ToXContentObject {
 
         public static final ConstructingObjectParser<Importance, Void> PARSER = new ConstructingObjectParser<>(NAME,
             true,
-            a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
+            a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
 
         static {
             PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
             PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
             PARSER.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
+            PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
         }
 
         private final double meanMagnitude;
         private final double min;
         private final double max;
+        private final Double baseline;
 
-        public Importance(double meanMagnitude, double min, double max) {
+        public Importance(double meanMagnitude, double min, double max, Double baseline) {
             this.meanMagnitude = meanMagnitude;
             this.min = min;
             this.max = max;
+            this.baseline = baseline;
         }
 
         @Override
@@ -127,12 +131,13 @@ public class TotalFeatureImportance implements ToXContentObject {
             Importance that = (Importance) o;
             return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
                 Double.compare(that.min, min) == 0 &&
-                Double.compare(that.max, max) == 0;
+                Double.compare(that.max, max) == 0 &&
+                Objects.equals(that.baseline, baseline);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(meanMagnitude, min, max);
+            return Objects.hash(meanMagnitude, min, max, baseline);
         }
 
         @Override
@@ -141,6 +146,9 @@ public class TotalFeatureImportance implements ToXContentObject {
             builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
             builder.field(MIN.getPreferredName(), min);
             builder.field(MAX.getPreferredName(), max);
+            if (baseline != null) {
+                builder.field(BASELINE.getPreferredName(), baseline);
+            }
             builder.endObject();
             return builder;
         }

+ 5 - 1
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java

@@ -50,7 +50,11 @@ public class TotalFeatureImportanceTests extends AbstractXContentTestCase<TotalF
     }
 
     private static TotalFeatureImportance.Importance randomImportance() {
-        return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
+        return new TotalFeatureImportance.Importance(
+            randomDouble(),
+            randomDouble(),
+            randomDouble(),
+            randomBoolean() ? null : randomDouble());
     }
 
     @Override

+ 14 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

@@ -35,6 +35,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
     public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude");
     public static final ParseField MIN = new ParseField("min");
     public static final ParseField MAX = new ParseField("max");
+    public static final ParseField BASELINE = new ParseField("baseline");
 
     // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
     public static final ConstructingObjectParser<TotalFeatureImportance, Void> LENIENT_PARSER = createParser(true);
@@ -124,27 +125,31 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
         private static ConstructingObjectParser<Importance, Void> createParser(boolean ignoreUnknownFields) {
             ConstructingObjectParser<Importance, Void> parser = new ConstructingObjectParser<>(NAME,
                 ignoreUnknownFields,
-                a -> new Importance((double)a[0], (double)a[1], (double)a[2]));
+                a -> new Importance((double)a[0], (double)a[1], (double)a[2], (Double)a[3]));
             parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE);
             parser.declareDouble(ConstructingObjectParser.constructorArg(), MIN);
             parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX);
+            parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
             return parser;
         }
 
         private final double meanMagnitude;
         private final double min;
         private final double max;
+        private final Double baseline;
 
-        public Importance(double meanMagnitude, double min, double max) {
+        public Importance(double meanMagnitude, double min, double max, Double baseline) {
             this.meanMagnitude = meanMagnitude;
             this.min = min;
             this.max = max;
+            this.baseline = baseline;
         }
 
         public Importance(StreamInput in) throws IOException {
             this.meanMagnitude = in.readDouble();
             this.min = in.readDouble();
             this.max = in.readDouble();
+            this.baseline = in.readOptionalDouble();
         }
 
         @Override
@@ -154,12 +159,13 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             Importance that = (Importance) o;
             return Double.compare(that.meanMagnitude, meanMagnitude) == 0 &&
                 Double.compare(that.min, min) == 0 &&
-                Double.compare(that.max, max) == 0;
+                Double.compare(that.max, max) == 0 &&
+                Objects.equals(that.baseline, baseline);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(meanMagnitude, min, max);
+            return Objects.hash(meanMagnitude, min, max, baseline);
         }
 
         @Override
@@ -167,6 +173,7 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             out.writeDouble(meanMagnitude);
             out.writeDouble(min);
             out.writeDouble(max);
+            out.writeOptionalDouble(baseline);
         }
 
         @Override
@@ -179,6 +186,9 @@ public class TotalFeatureImportance implements ToXContentObject, Writeable {
             map.put(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude);
             map.put(MIN.getPreferredName(), min);
             map.put(MAX.getPreferredName(), max);
+            if (baseline != null) {
+                map.put(BASELINE.getPreferredName(), baseline);
+            }
             return map;
         }
     }

+ 6 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json

@@ -85,6 +85,9 @@
                 },
                 "mean_magnitude": {
                   "type": "double"
+                },
+                "baseline": {
+                  "type": "double"
                 }
               }
             },
@@ -105,6 +108,9 @@
                     },
                     "mean_magnitude": {
                       "type": "double"
+                    },
+                    "baseline": {
+                      "type": "double"
                     }
                   }
                 },

+ 5 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java

@@ -41,7 +41,11 @@ public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCas
     }
 
     private static TotalFeatureImportance.Importance randomImportance() {
-        return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble());
+        return new TotalFeatureImportance.Importance(
+            randomDouble(),
+            randomDouble(),
+            randomDouble(),
+            randomBoolean() ? null : randomDouble());
     }
 
     @Before