Browse Source

[ML] Use query in cardinality check (#49939)

When checking the cardinality of a field, the query should be taken into account. The user might know about some bad data in their index and want to filter down to the target_field values they care about.
Benjamin Trent 5 years ago
parent
commit
dd66fae755

+ 24 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -15,6 +15,8 @@ import org.elasticsearch.action.index.IndexAction;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -228,7 +230,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction");
     }
 
-    public void testDependentVariableCardinalityTooHighError() {
+    public void testDependentVariableCardinalityTooHighError() throws Exception {
         initialize("cardinality_too_high");
         indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
         // Index one more document with a class different than the two already used.
@@ -246,6 +248,27 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
     }
 
+    public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRange() throws Exception {
+        initialize("cardinality_too_high_with_query");
+        indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
+        // Index one more document with a class different than the two already used.
+        client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex)
+            .source(KEYWORD_FIELD, "fox")
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE))
+            .actionGet();
+        QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(KEYWORD_FIELD, KEYWORD_FIELD_VALUES));
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD), query);
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        // Should not throw
+        startAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+
+        assertProgress(jobId, 100, 100, 100, 100);
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";

+ 10 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

@@ -16,6 +16,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
@@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
 import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
 import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
+import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
 import org.hamcrest.Matcher;
 import org.hamcrest.Matchers;
@@ -161,10 +163,16 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
     }
 
     protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
-                                                             @Nullable String resultsField, DataFrameAnalysis analysis) {
+                                                             @Nullable String resultsField, DataFrameAnalysis analysis) throws Exception {
+        return buildAnalytics(id, sourceIndex, destIndex, resultsField, analysis, QueryBuilders.matchAllQuery());
+    }
+
+    protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
+                                                             @Nullable String resultsField, DataFrameAnalysis analysis,
+                                                             QueryBuilder queryBuilder) throws Exception {
         return new DataFrameAnalyticsConfig.Builder()
             .setId(id)
-            .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null))
+            .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, QueryProvider.fromParsedQuery(queryBuilder), null))
             .setDest(new DataFrameAnalyticsDest(destIndex, resultsField))
             .setAnalysis(analysis)
             .build();

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java

@@ -109,7 +109,7 @@ public class ExtractedFieldsDetectorFactory {
             listener::onFailure
         );
 
-        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0).query(config.getSource().getParsedQuery());
         for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
             String fieldName = entry.getKey();
             Long limit = entry.getValue();