Pārlūkot izejas kodu

Require that the dependent variable column has at most 2 distinct values in classfication analysis. (#47858)

Przemysław Witek 6 gadi atpakaļ
vecāks
revīzija
390a8292f1

+ 6 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -152,6 +152,12 @@ public class Classification implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.categorical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        // This restriction is due to the fact that currently the C++ backend only supports binomial classification.
+        return Collections.singletonMap(dependentVariable, 2L);
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

@@ -28,6 +28,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
      */
     List<RequiredField> getRequiredFields();
 
+    /**
+     * @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
+     */
+    Map<String, Long> getFieldCardinalityLimits();
+
     /**
      * @return {@code true} if this analysis supports data frame rows with missing values
      */

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

@@ -218,6 +218,11 @@ public class OutlierDetection implements DataFrameAnalysis {
         return Collections.emptyList();
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return false;

+ 5 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -139,6 +139,11 @@ public class Regression implements DataFrameAnalysis {
         return Collections.singletonList(new RequiredField(dependentVariable, Types.numerical()));
     }
 
+    @Override
+    public Map<String, Long> getFieldCardinalityLimits() {
+        return Collections.emptyMap();
+    }
+
     @Override
     public boolean supportsMissingValues() {
         return true;

+ 7 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -13,6 +13,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class ClassificationTests extends AbstractSerializingTestCase<Classification> {
 
@@ -65,4 +68,8 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
 
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
+
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
 }

+ 6 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

@@ -15,6 +15,8 @@ import java.util.Map;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDetection> {
 
@@ -82,6 +84,10 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
         assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));

+ 6 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -14,6 +14,8 @@ import java.io.IOException;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
 
 public class RegressionTests extends AbstractSerializingTestCase<Regression> {
 
@@ -66,6 +68,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
     }
 
+    public void testFieldCardinalityLimitsIsNonNull() {
+        assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue())));
+    }
+
     public void testGetStateDocId() {
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));

+ 53 - 57
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.integration;
 
 import com.google.common.collect.Ordering;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.get.GetResponse;
@@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
 
 public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
-    private static final String NUMERICAL_FEATURE_FIELD = "feature";
-    private static final String DEPENDENT_VARIABLE_FIELD = "variable";
-    private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
-    private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
+    private static final String NUMERICAL_FIELD = "numerical-field";
+    private static final String KEYWORD_FIELD = "keyword-field";
+    private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+    private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat"));
 
     private String jobId;
     private String sourceIndex;
@@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("classification_single_numeric_feature_and_mixed_data_set");
+        indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
-        {  // Index 350 rows, 300 of them being training rows.
-            client().admin().indices().prepareCreate(sourceIndex)
-                .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
-                .get();
-
-            BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
-                .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-            for (int i = 0; i < 300; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-                String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            for (int i = 300; i < 350; i++) {
-                Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-
-                IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                    .source(NUMERICAL_FEATURE_FIELD, field);
-                bulkRequestBuilder.add(indexRequest);
-            }
-            BulkResponse bulkResponse = bulkRequestBuilder.get();
-            if (bulkResponse.hasFailures()) {
-                fail("Failed to index data: " + bulkResponse.buildFailureMessage());
-            }
-        }
-
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
-            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
+            assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
         }
 
@@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_100");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
-        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(true));
             assertThat(resultsObject.containsKey("top_classes"), is(false));
@@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
         initialize("classification_only_training_data_and_training_percent_is_50");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         DataFrameAnalyticsConfig config =
             buildAnalytics(
@@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
 
             assertThat(resultsObject.containsKey("is_training"), is(true));
             // Let's just assert there's both training and non-training results
@@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     @AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
         initialize("classification_top_classes_requested");
-        indexTrainingData(sourceIndex, 300);
+        indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
 
         int numTopClasses = 2;
         DataFrameAnalyticsConfig config =
@@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
                 sourceIndex,
                 destIndex,
                 null,
-                new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
+                new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
         registerAnalytics(config);
         putAnalytics(config);
 
@@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             Map<String, Object> destDoc = getDestDoc(config, hit);
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
-            assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
+            assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
+            assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
             assertTopClasses(resultsObject, numTopClasses);
         }
 
@@ -245,25 +219,47 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             "Finished analysis");
     }
 
+    public void testDependentVariableCardinalityTooHighError() {
+        initialize("cardinality_too_high");
+        indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox"));
+
+        DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId));
+        assertThat(e.status().getStatus(), equalTo(400));
+        assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
         this.destIndex = sourceIndex + "_results";
     }
 
-    private static void indexTrainingData(String sourceIndex, int numRows) {
+    private static void indexData(String sourceIndex,
+                                  int numTrainingRows, int numNonTrainingRows,
+                                  List<Double> numericalFieldValues, List<String> keywordFieldValues) {
         client().admin().indices().prepareCreate(sourceIndex)
-            .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
+            .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
             .get();
 
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
             .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
-        for (int i = 0; i < numRows; i++) {
-            Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
-            String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
+        for (int i = 0; i < numTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
+            String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
+
+            IndexRequest indexRequest = new IndexRequest(sourceIndex)
+                .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
+            bulkRequestBuilder.add(indexRequest);
+        }
+        for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
+            Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
 
             IndexRequest indexRequest = new IndexRequest(sourceIndex)
-                .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
+                .source(NUMERICAL_FIELD, numericalValue);
             bulkRequestBuilder.add(indexRequest);
         }
         BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             classNames.add((String) topClass.get("class_name"));
             classProbabilities.add((Double) topClass.get("class_probability"));
         }
-        // Assert that all the class names come from the set of dependent variable values.
-        classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
+        // Assert that all the predicted class names come from the set of keyword field values.
+        classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
         // Assert that the first class listed in top classes is the same as the predicted class.
-        assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
+        assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
         // Assert that all the class probabilities lie within [0, 1] interval.
         classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
         // Assert that the top classes are listed in the order of decreasing probabilities.

+ 65 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java

@@ -13,6 +13,9 @@ import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.collect.ImmutableOpenMap;
@@ -22,6 +25,10 @@ import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.Cardinality;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -34,6 +41,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class DataFrameDataExtractorFactory {
 
@@ -172,13 +180,65 @@ public class DataFrameDataExtractorFactory {
                                                       boolean isTaskRestarting,
                                                       ActionListener<ExtractedFields> listener) {
         AtomicInteger docValueFieldsLimitHolder = new AtomicInteger();
+        AtomicReference<ExtractedFields> extractedFieldsHolder = new AtomicReference<>();
 
-        // Step 3. Extract fields (if possible) and notify listener
+        // Step 4. Check fields cardinality vs limits and notify listener
+        ActionListener<SearchResponse> checkCardinalityHandler = ActionListener.wrap(
+            searchResponse -> {
+                if (searchResponse != null) {
+                    Aggregations aggs = searchResponse.getAggregations();
+                    if (aggs == null) {
+                        listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                        return;
+                    }
+                    for (Map.Entry<String, Long> entry : config.getAnalysis().getFieldCardinalityLimits().entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        Cardinality cardinality = aggs.get(fieldName);
+                        if (cardinality == null) {
+                            listener.onFailure(ExceptionsHelper.serverError("Unexpected null response when gathering field cardinalities"));
+                            return;
+                        }
+                        if (cardinality.getValue() > limit) {
+                            listener.onFailure(
+                                ExceptionsHelper.badRequestException(
+                                    "Field [{}] must have at most [{}] distinct values but there were at least [{}]",
+                                    fieldName, limit, cardinality.getValue()));
+                            return;
+                        }
+                    }
+                }
+                listener.onResponse(extractedFieldsHolder.get());
+            },
+            listener::onFailure
+        );
+
+        // Step 3. Extract fields (if possible)
         ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
-            fieldCapabilitiesResponse -> listener.onResponse(
-                new ExtractedFieldsDetector(
-                        index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
-                    .detect()),
+            fieldCapabilitiesResponse -> {
+                extractedFieldsHolder.set(
+                    new ExtractedFieldsDetector(
+                            index, config, resultsField, isTaskRestarting, docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse)
+                        .detect());
+
+                Map<String, Long> fieldCardinalityLimits = config.getAnalysis().getFieldCardinalityLimits();
+                if (fieldCardinalityLimits.isEmpty()) {
+                    checkCardinalityHandler.onResponse(null);
+                } else {
+                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0);
+                    for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) {
+                        String fieldName = entry.getKey();
+                        Long limit = entry.getValue();
+                        searchSourceBuilder.aggregation(
+                            AggregationBuilders.cardinality(fieldName)
+                                .field(fieldName)
+                                .precisionThreshold(limit + 1));
+                    }
+                    SearchRequest searchRequest = new SearchRequest(config.getSource().getIndex()).source(searchSourceBuilder);
+                    ClientHelper.executeWithHeadersAsync(
+                        config.getHeaders(), ClientHelper.ML_ORIGIN, client, SearchAction.INSTANCE, searchRequest, checkCardinalityHandler);
+                }
+            },
             listener::onFailure
         );
 

+ 50 - 1
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml

@@ -69,7 +69,7 @@
         body:
           mappings:
             properties:
-              long_field: { "type": "long" }
+              long_field: { type: "long" }
 
   - do:
       ml.put_data_frame_analytics:
@@ -140,3 +140,52 @@
       catch: /dest index \[non-empty-dest\] must be empty/
       ml.start_data_frame_analytics:
         id: "start_given_empty_dest_index"
+
+---
+"Test start classification analysis when the dependent variable cardinality is too high":
+  - do:
+      indices.create:
+        index: index-with-dep-var-with-too-high-card
+        body:
+          mappings:
+            properties:
+              numeric_field: { type: "long" }
+              keyword_field: { type: "keyword" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 1.0, keyword_field: "class_a" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 2.0, keyword_field: "class_b" }
+
+  - do:
+      index:
+        index: index-with-dep-var-with-too-high-card
+        body: { numeric_field: 3.0, keyword_field: "class_c" }
+
+  - do:
+      indices.refresh:
+        index: index-with-dep-var-with-too-high-card
+
+  - do:
+      ml.put_data_frame_analytics:
+        id: "too-high-card"
+        body: >
+          {
+            "source": {
+              "index": "index-with-dep-var-with-too-high-card"
+            },
+            "dest": {
+              "index": "index-with-dep-var-with-too-high-card-dest"
+            },
+            "analysis": { "classification": { "dependent_variable": "keyword_field" } }
+          }
+
+  - do:
+      catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/
+      ml.start_data_frame_analytics:
+        id: "too-high-card"