|
@@ -6,6 +6,7 @@
|
|
|
package org.elasticsearch.xpack.ml.integration;
|
|
|
|
|
|
import com.google.common.collect.Ordering;
|
|
|
+import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
|
|
import org.elasticsearch.action.bulk.BulkResponse;
|
|
|
import org.elasticsearch.action.get.GetResponse;
|
|
@@ -37,10 +38,10 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
|
|
|
|
|
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
- private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
|
|
- private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
|
|
- private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
|
|
- private static final List<String> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat", "cow"));
|
|
|
+ private static final String NUMERICAL_FIELD = "numerical-field"; // Source field mapped as type=double by indexData.
|
|
|
+ private static final String KEYWORD_FIELD = "keyword-field"; // Source field mapped as type=keyword; passed to Classification in the tests.
|
|
|
+ private static final List<Double> NUMERICAL_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0, 4.0)); // Cycled through (index modulo size) when rows are indexed.
|
|
|
+ private static final List<String> KEYWORD_FIELD_VALUES = Collections.unmodifiableList(Arrays.asList("dog", "cat")); // Two distinct classes; predictions are asserted to be one of these.
|
|
|
|
|
|
private String jobId;
|
|
|
private String sourceIndex;
|
|
@@ -53,36 +54,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
|
|
initialize("classification_single_numeric_feature_and_mixed_data_set");
|
|
|
+ indexData(sourceIndex, 300, 50, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
|
|
|
|
|
- { // Index 350 rows, 300 of them being training rows.
|
|
|
- client().admin().indices().prepareCreate(sourceIndex)
|
|
|
- .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
|
|
|
- .get();
|
|
|
-
|
|
|
- BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
|
- .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
- for (int i = 0; i < 300; i++) {
|
|
|
- Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
|
- String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
|
-
|
|
|
- IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
|
- .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
|
|
- bulkRequestBuilder.add(indexRequest);
|
|
|
- }
|
|
|
- for (int i = 300; i < 350; i++) {
|
|
|
- Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
|
-
|
|
|
- IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
|
- .source(NUMERICAL_FEATURE_FIELD, field);
|
|
|
- bulkRequestBuilder.add(indexRequest);
|
|
|
- }
|
|
|
- BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
|
- if (bulkResponse.hasFailures()) {
|
|
|
- fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
|
|
|
+ DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
|
|
registerAnalytics(config);
|
|
|
putAnalytics(config);
|
|
|
|
|
@@ -97,10 +71,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
Map<String, Object> destDoc = getDestDoc(config, hit);
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
- assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
|
|
+ assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
|
|
+ assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
- assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
|
|
+ assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD)));
|
|
|
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
|
|
}
|
|
|
|
|
@@ -117,9 +91,9 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
|
|
initialize("classification_only_training_data_and_training_percent_is_100");
|
|
|
- indexTrainingData(sourceIndex, 300);
|
|
|
+ indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
|
|
|
|
|
- DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(DEPENDENT_VARIABLE_FIELD));
|
|
|
+ DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
|
|
registerAnalytics(config);
|
|
|
putAnalytics(config);
|
|
|
|
|
@@ -133,8 +107,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
- assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
|
|
+ assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
|
|
+ assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
assertThat(resultsObject.get("is_training"), is(true));
|
|
|
assertThat(resultsObject.containsKey("top_classes"), is(false));
|
|
@@ -153,7 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
|
|
initialize("classification_only_training_data_and_training_percent_is_50");
|
|
|
- indexTrainingData(sourceIndex, 300);
|
|
|
+ indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
|
|
|
|
|
DataFrameAnalyticsConfig config =
|
|
|
buildAnalytics(
|
|
@@ -161,7 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
sourceIndex,
|
|
|
destIndex,
|
|
|
null,
|
|
|
- new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
|
|
|
+ new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, null, 50.0));
|
|
|
registerAnalytics(config);
|
|
|
putAnalytics(config);
|
|
|
|
|
@@ -176,8 +150,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
- assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
|
|
+ assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
|
|
+ assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
|
|
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
// Let's just assert there's both training and non-training results
|
|
@@ -205,7 +179,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
@AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/issues/712")
|
|
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows_TopClassesRequested() throws Exception {
|
|
|
initialize("classification_top_classes_requested");
|
|
|
- indexTrainingData(sourceIndex, 300);
|
|
|
+ indexData(sourceIndex, 300, 0, NUMERICAL_FIELD_VALUES, KEYWORD_FIELD_VALUES);
|
|
|
|
|
|
int numTopClasses = 2;
|
|
|
DataFrameAnalyticsConfig config =
|
|
@@ -214,7 +188,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
sourceIndex,
|
|
|
destIndex,
|
|
|
null,
|
|
|
- new Classification(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
|
|
|
+ new Classification(KEYWORD_FIELD, BoostedTreeParamsTests.createRandom(), null, numTopClasses, null));
|
|
|
registerAnalytics(config);
|
|
|
putAnalytics(config);
|
|
|
|
|
@@ -229,8 +203,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
Map<String, Object> destDoc = getDestDoc(config, hit);
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
- assertThat((String) resultsObject.get("variable_prediction"), is(in(DEPENDENT_VARIABLE_VALUES)));
|
|
|
+ assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true));
|
|
|
+ assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES)));
|
|
|
assertTopClasses(resultsObject, numTopClasses);
|
|
|
}
|
|
|
|
|
@@ -245,25 +219,47 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
"Finished analysis");
|
|
|
}
|
|
|
|
|
|
+ public void testDependentVariableCardinalityTooHighError() { // Starting a classification job on a keyword field with three distinct values must be rejected.
|
|
|
+ initialize("cardinality_too_high");
|
|
|
+ indexData(sourceIndex, 6, 5, NUMERICAL_FIELD_VALUES, Arrays.asList("dog", "cat", "fox")); // 6 training rows cycle through all three keyword values; the 5 non-training rows carry only the numerical field.
|
|
|
+
|
|
|
+ DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
|
|
+ registerAnalytics(config);
|
|
|
+ putAnalytics(config);
|
|
|
+
|
|
|
+ ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> startAnalytics(jobId)); // The failure is expected from the start call, after registerAnalytics/putAnalytics above.
|
|
|
+ assertThat(e.status().getStatus(), equalTo(400)); // Expected HTTP status code of the rejection.
|
|
|
+ assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]")); // Message names the field, the limit (2) and the observed lower bound (3).
|
|
|
+ }
|
|
|
+
|
|
|
private void initialize(String jobId) { // Stores the job id and derives the source and destination index names from it.
|
|
|
this.jobId = jobId;
|
|
|
this.sourceIndex = jobId + "_source_index"; // e.g. "cardinality_too_high_source_index"
|
|
|
this.destIndex = sourceIndex + "_results"; // Built from sourceIndex, so it ends in "_source_index_results".
|
|
|
}
|
|
|
|
|
|
- private static void indexTrainingData(String sourceIndex, int numRows) {
|
|
|
+ private static void indexData(String sourceIndex,
|
|
|
+ int numTrainingRows, int numNonTrainingRows,
|
|
|
+ List<Double> numericalFieldValues, List<String> keywordFieldValues) {
|
|
|
client().admin().indices().prepareCreate(sourceIndex)
|
|
|
- .addMapping("_doc", NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=keyword")
|
|
|
+ .addMapping("_doc", NUMERICAL_FIELD, "type=double", KEYWORD_FIELD, "type=keyword")
|
|
|
.get();
|
|
|
|
|
|
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
|
|
- for (int i = 0; i < numRows; i++) {
|
|
|
- Double field = NUMERICAL_FEATURE_VALUES.get(i % 3);
|
|
|
- String value = DEPENDENT_VARIABLE_VALUES.get(i % 3);
|
|
|
+ for (int i = 0; i < numTrainingRows; i++) {
|
|
|
+ Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
|
|
|
+ String keywordValue = keywordFieldValues.get(i % keywordFieldValues.size());
|
|
|
+
|
|
|
+ IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
|
+ .source(NUMERICAL_FIELD, numericalValue, KEYWORD_FIELD, keywordValue);
|
|
|
+ bulkRequestBuilder.add(indexRequest);
|
|
|
+ }
|
|
|
+ for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
|
|
+ Double numericalValue = numericalFieldValues.get(i % numericalFieldValues.size());
|
|
|
|
|
|
IndexRequest indexRequest = new IndexRequest(sourceIndex)
|
|
|
- .source(NUMERICAL_FEATURE_FIELD, field, DEPENDENT_VARIABLE_FIELD, value);
|
|
|
+ .source(NUMERICAL_FIELD, numericalValue);
|
|
|
bulkRequestBuilder.add(indexRequest);
|
|
|
}
|
|
|
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
|
@@ -302,10 +298,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
classNames.add((String) topClass.get("class_name"));
|
|
|
classProbabilities.add((Double) topClass.get("class_probability"));
|
|
|
}
|
|
|
- // Assert that all the class names come from the set of dependent variable values.
|
|
|
- classNames.forEach(className -> assertThat(className, is(in(DEPENDENT_VARIABLE_VALUES))));
|
|
|
+ // Assert that all the predicted class names come from the set of keyword field values.
|
|
|
+ classNames.forEach(className -> assertThat(className, is(in(KEYWORD_FIELD_VALUES))));
|
|
|
// Assert that the first class listed in top classes is the same as the predicted class.
|
|
|
- assertThat(classNames.get(0), equalTo(resultsObject.get("variable_prediction")));
|
|
|
+ assertThat(classNames.get(0), equalTo(resultsObject.get("keyword-field_prediction")));
|
|
|
// Assert that all the class probabilities lie within [0, 1] interval.
|
|
|
classProbabilities.forEach(p -> assertThat(p, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))));
|
|
|
// Assert that the top classes are listed in the order of decreasing probabilities.
|