Browse Source

[ML] Support boolean fields for DF analytics (#46037)

This commit adds support for `boolean` fields in data frame
analytics (and currently both outlier detection and regression).
The analytics process expects `boolean` fields to be encoded as
integers with 0 or 1 value.
Dimitris Athanasiou 6 years ago
parent
commit
9c6d9a9049

+ 51 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java

@@ -11,9 +11,12 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.document.DocumentField;
 import org.elasticsearch.common.regex.Regex;
 import org.elasticsearch.index.IndexSettings;
+import org.elasticsearch.index.mapper.BooleanFieldMapper;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
@@ -33,6 +36,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
+import java.util.TreeSet;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -101,6 +105,7 @@ public class ExtractedFieldsDetector {
                     IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey());
             }
         }
+        extractedFields = fetchBooleanFieldsAsIntegers(extractedFields);
         return extractedFields;
     }
 
@@ -144,6 +149,8 @@ public class ExtractedFieldsDetector {
                     LOGGER.debug("[{}] field [{}] is compatible as it is numerical", config.getId(), field);
                 } else if (config.getAnalysis().supportsCategoricalFields() && CATEGORICAL_TYPES.containsAll(fieldTypes)) {
                     LOGGER.debug("[{}] field [{}] is compatible as it is categorical", config.getId(), field);
+                } else if (isBoolean(fieldTypes)) {
+                    LOGGER.debug("[{}] field [{}] is compatible as it is boolean", config.getId(), field);
                 } else {
                     LOGGER.debug("[{}] Removing field [{}] because its types are not supported; types {}; supported {}",
                         config.getId(), field, fieldTypes, getSupportedTypes());
@@ -154,10 +161,11 @@ public class ExtractedFieldsDetector {
     }
 
     private Set<String> getSupportedTypes() {
-        Set<String> supportedTypes = new HashSet<>(NUMERICAL_TYPES);
+        Set<String> supportedTypes = new TreeSet<>(NUMERICAL_TYPES);
         if (config.getAnalysis().supportsCategoricalFields()) {
             supportedTypes.addAll(CATEGORICAL_TYPES);
         }
+        supportedTypes.add(BooleanFieldMapper.CONTENT_TYPE);
         return supportedTypes;
     }
 
@@ -211,4 +219,46 @@ public class ExtractedFieldsDetector {
         }
         return new ExtractedFields(adjusted);
     }
+
+    private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) {
+        List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
+        for (ExtractedField field : extractedFields.getAllFields()) {
+            if (isBoolean(field.getTypes())) {
+                adjusted.add(new BooleanAsInteger(field));
+            } else {
+                adjusted.add(field);
+            }
+        }
+        return new ExtractedFields(adjusted);
+    }
+
+    private static boolean isBoolean(Set<String> types) {
+        return types.size() == 1 && types.contains(BooleanFieldMapper.CONTENT_TYPE);
+    }
+
+    /**
+     * We convert boolean fields to integers with values 0, 1 as this is the preferred
+     * way to consume such features in the analytics process.
+     */
+    private static class BooleanAsInteger extends ExtractedField {
+
+        protected BooleanAsInteger(ExtractedField field) {
+            super(field.getAlias(), field.getName(), Collections.singleton(BooleanFieldMapper.CONTENT_TYPE), ExtractionMethod.DOC_VALUE);
+        }
+
+        @Override
+        public Object[] value(SearchHit hit) {
+            DocumentField keyValue = hit.field(name);
+            if (keyValue != null) {
+                List<Object> values = keyValue.getValues().stream().map(v -> Boolean.TRUE.equals(v) ? 1 : 0).collect(Collectors.toList());
+                return values.toArray(new Object[0]);
+            }
+            return new Object[0];
+        }
+
+        @Override
+        public boolean supportsFromSource() {
+            return false;
+        }
+    }
 }

+ 45 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java

@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.dataframe.extractor;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.fieldcaps.FieldCapabilities;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
+import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -17,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
 import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
 import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
 import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;
+import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -26,6 +28,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
+import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
@@ -67,7 +70,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         assertThat(allFields.get(0).getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
     }
 
-    public void testDetect_GivenNonNumericField() {
+    public void testDetect_GivenOutlierDetectionAndNonNumericField() {
         FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
             .addAggregatableField("some_keyword", "keyword").build();
 
@@ -76,7 +79,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
 
         assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]." +
-            " Supported types are [scaled_float, double, byte, short, half_float, integer, float, long]."));
+            " Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short]."));
     }
 
     public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericTypes() {
@@ -88,7 +91,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
 
         assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " +
-            "Supported types are [scaled_float, double, byte, short, half_float, integer, float, long]."));
+            "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short]."));
     }
 
     public void testDetect_GivenOutlierDetectionAndMultipleFields() {
@@ -96,6 +99,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             .addAggregatableField("some_float", "float")
             .addAggregatableField("some_long", "long")
             .addAggregatableField("some_keyword", "keyword")
+            .addAggregatableField("some_boolean", "boolean")
             .build();
 
         ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
@@ -103,9 +107,9 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         ExtractedFields extractedFields = extractedFieldsDetector.detect();
 
         List<ExtractedField> allFields = extractedFields.getAllFields();
-        assertThat(allFields.size(), equalTo(2));
+        assertThat(allFields.size(), equalTo(3));
         assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()),
-            containsInAnyOrder("some_float", "some_long"));
+            containsInAnyOrder("some_float", "some_long", "some_boolean"));
         assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()),
             contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
     }
@@ -115,6 +119,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             .addAggregatableField("some_float", "float")
             .addAggregatableField("some_long", "long")
             .addAggregatableField("some_keyword", "keyword")
+            .addAggregatableField("some_boolean", "boolean")
             .addAggregatableField("foo", "keyword")
             .build();
 
@@ -123,9 +128,9 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         ExtractedFields extractedFields = extractedFieldsDetector.detect();
 
         List<ExtractedField> allFields = extractedFields.getAllFields();
-        assertThat(allFields.size(), equalTo(4));
+        assertThat(allFields.size(), equalTo(5));
         assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toList()),
-            contains("foo", "some_float", "some_keyword", "some_long"));
+            containsInAnyOrder("foo", "some_float", "some_keyword", "some_long", "some_boolean"));
         assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()),
             contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)));
     }
@@ -153,7 +158,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
 
         assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " +
-            "Supported types are [scaled_float, double, byte, short, half_float, integer, float, long]."));
+            "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short]."));
     }
 
     public void testDetect_ShouldSortFieldsAlphabetically() {
@@ -207,7 +212,7 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             SOURCE_INDEX, buildOutlierDetectionConfig(desiredFields), RESULTS_FIELD, false, 100, fieldCapabilities);
         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect());
         assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " +
-            "Supported types are [scaled_float, double, byte, short, half_float, integer, float, long]."));
+            "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short]."));
     }
 
     public void testDetectedExtractedFields_GivenInclusionsAndExclusions() {
@@ -336,6 +341,37 @@ public class ExtractedFieldsDetectorTests extends ESTestCase {
             contains(equalTo(ExtractedField.ExtractionMethod.SOURCE)));
     }
 
+    public void testDetect_GivenBooleanField() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("some_boolean", "boolean")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildOutlierDetectionConfig(), RESULTS_FIELD, false, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        List<ExtractedField> allFields = extractedFields.getAllFields();
+        assertThat(allFields.size(), equalTo(1));
+        ExtractedField booleanField = allFields.get(0);
+        assertThat(booleanField.getTypes(), contains("boolean"));
+        assertThat(booleanField.getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE));
+
+        SearchHit hit = new SearchHitBuilder(42).addField("some_boolean", true).build();
+        Object[] values = booleanField.value(hit);
+        assertThat(values.length, equalTo(1));
+        assertThat(values[0], equalTo(1));
+
+        hit = new SearchHitBuilder(42).addField("some_boolean", false).build();
+        values = booleanField.value(hit);
+        assertThat(values.length, equalTo(1));
+        assertThat(values[0], equalTo(0));
+
+        hit = new SearchHitBuilder(42).addField("some_boolean", Arrays.asList(false, true, false)).build();
+        values = booleanField.value(hit);
+        assertThat(values.length, equalTo(3));
+        assertThat(values, arrayContaining(0, 1, 0));
+    }
+
     private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
         return buildOutlierDetectionConfig(null);
     }