|
@@ -15,6 +15,8 @@ import org.elasticsearch.action.index.IndexAction;
|
|
|
import org.elasticsearch.action.index.IndexRequest;
|
|
|
import org.elasticsearch.action.search.SearchResponse;
|
|
|
import org.elasticsearch.action.support.WriteRequest;
|
|
|
+import org.elasticsearch.index.query.QueryBuilder;
|
|
|
+import org.elasticsearch.index.query.QueryBuilders;
|
|
|
import org.elasticsearch.search.SearchHit;
|
|
|
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
|
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
|
@@ -228,7 +230,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction");
|
|
|
}
|
|
|
|
|
|
- public void testDependentVariableCardinalityTooHighError() {
|
|
|
+ public void testDependentVariableCardinalityTooHighError() throws Exception {
|
|
|
initialize("cardinality_too_high");
|
|
|
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
|
|
|
// Index one more document with a class different than the two already used.
|
|
@@ -246,6 +248,27 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]"));
|
|
|
}
|
|
|
|
|
|
+ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRange() throws Exception { // Same data as testDependentVariableCardinalityTooHighError, but the config carries a source query; the job is expected to run to completion.
|
|
|
+ initialize("cardinality_too_high_with_query");
|
|
|
+ indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
|
|
|
+ // Index one more document whose class ("fox") differs from the two already used.
|
|
|
+ client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex)
|
|
|
+ .source(KEYWORD_FIELD, "fox")
|
|
|
+ .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)) // refresh immediately so the extra document is searchable before the job starts
|
|
|
+ .actionGet();
|
|
|
+ QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(KEYWORD_FIELD, KEYWORD_FIELD_VALUES)); // keep only documents whose keyword field matches one of KEYWORD_FIELD_VALUES; presumably this excludes "fox" (the constant is defined outside this view)
|
|
|
+ // The query is handed to the analytics config together with the classification on KEYWORD_FIELD.
|
|
|
+ DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD), query);
|
|
|
+ registerAnalytics(config);
|
|
|
+ putAnalytics(config);
|
|
|
+
|
|
|
+ // Should not throw: without the query, testDependentVariableCardinalityTooHighError expects the "must have at most [2] distinct values" error for this data.
|
|
|
+ startAnalytics(jobId);
|
|
|
+ waitUntilAnalyticsIsStopped(jobId);
|
|
|
+ // Every progress value passed to assertProgress is expected to be 100 once the job has stopped (presumably one value per phase; the helper is defined elsewhere).
|
|
|
+ assertProgress(jobId, 100, 100, 100, 100);
|
|
|
+ }
|
|
|
+
|
|
|
private void initialize(String jobId) {
|
|
|
this.jobId = jobId;
|
|
|
this.sourceIndex = jobId + "_source_index";
|