|
@@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is;
|
|
|
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
private static final String NUMERICAL_FEATURE_FIELD = "feature";
|
|
|
+ private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
|
|
|
private static final String DEPENDENT_VARIABLE_FIELD = "variable";
|
|
|
private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
|
|
|
+ private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
|
|
|
private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
|
|
|
|
|
|
private String jobId;
|
|
@@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
|
|
initialize("regression_single_numeric_feature_and_mixed_data_set");
|
|
|
+ String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
|
|
indexData(sourceIndex, 300, 50);
|
|
|
|
|
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
|
|
@@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
// it seems for this case values can be as far off as 2.0
|
|
|
|
|
|
// double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
|
|
|
- // double predictionValue = (double) resultsObject.get("variable_prediction");
|
|
|
+ // double predictionValue = (double) resultsObject.get(predictedClassField);
|
|
|
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
+ assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
|
|
|
- assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
|
|
|
+ assertThat(
|
|
|
+ resultsObject.toString(),
|
|
|
+ resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
|
|
|
+ || resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
|
|
|
+ is(true));
|
|
|
}
|
|
|
|
|
|
assertProgress(jobId, 100, 100, 100, 100);
|
|
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
assertModelStatePersisted(stateDocId());
|
|
|
assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
assertThatAuditMessagesMatch(jobId,
|
|
|
"Created analytics with analysis type [regression]",
|
|
|
"Estimated memory usage for this analytics to be",
|
|
@@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
|
|
initialize("regression_only_training_data_and_training_percent_is_100");
|
|
|
+ String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
|
|
indexData(sourceIndex, 350, 0);
|
|
|
|
|
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
|
@@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
+ assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
assertThat(resultsObject.get("is_training"), is(true));
|
|
|
}
|
|
@@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
assertModelStatePersisted(stateDocId());
|
|
|
assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
assertThatAuditMessagesMatch(jobId,
|
|
|
"Created analytics with analysis type [regression]",
|
|
|
"Estimated memory usage for this analytics to be",
|
|
@@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
|
|
initialize("regression_only_training_data_and_training_percent_is_50");
|
|
|
+ String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
|
|
indexData(sourceIndex, 350, 0);
|
|
|
|
|
|
DataFrameAnalyticsConfig config =
|
|
@@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
+ assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
// Let's just assert there's both training and non-training results
|
|
|
if ((boolean) resultsObject.get("is_training")) {
|
|
@@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
assertModelStatePersisted(stateDocId());
|
|
|
assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
assertThatAuditMessagesMatch(jobId,
|
|
|
"Created analytics with analysis type [regression]",
|
|
|
"Estimated memory usage for this analytics to be",
|
|
@@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testStopAndRestart() throws Exception {
|
|
|
initialize("regression_stop_and_restart");
|
|
|
+ String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
|
|
indexData(sourceIndex, 350, 0);
|
|
|
|
|
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
|
@@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
for (SearchHit hit : sourceData.getHits()) {
|
|
|
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
|
|
|
|
|
|
- assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
|
|
+ assertThat(resultsObject.containsKey(predictedClassField), is(true));
|
|
|
assertThat(resultsObject.containsKey("is_training"), is(true));
|
|
|
assertThat(resultsObject.get("is_training"), is(true));
|
|
|
}
|
|
@@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
assertModelStatePersisted(stateDocId());
|
|
|
assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
}
|
|
|
|
|
|
public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
|
|
@@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
|
|
|
initialize("regression_delete_expired_data");
|
|
|
+ String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
|
|
|
indexData(sourceIndex, 100, 0);
|
|
|
|
|
|
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
|
|
@@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
|
|
assertModelStatePersisted(stateDocId());
|
|
|
assertInferenceModelPersisted(jobId);
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
|
|
|
// Call _delete_expired_data API and check nothing was deleted
|
|
|
assertThat(deleteExpiredData().isDeleted(), is(true));
|
|
@@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
|
|
|
}
|
|
|
|
|
|
+ public void testDependentVariableIsLong() throws Exception {
|
|
|
+ initialize("regression_dependent_variable_is_long");
|
|
|
+ String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction";
|
|
|
+ indexData(sourceIndex, 100, 0);
|
|
|
+
|
|
|
+ DataFrameAnalyticsConfig config =
|
|
|
+ buildAnalytics(
|
|
|
+ jobId,
|
|
|
+ sourceIndex,
|
|
|
+ destIndex,
|
|
|
+ null,
|
|
|
+ new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
|
|
|
+ registerAnalytics(config);
|
|
|
+ putAnalytics(config);
|
|
|
+
|
|
|
+ assertIsStopped(jobId);
|
|
|
+ assertProgress(jobId, 0, 0, 0, 0);
|
|
|
+
|
|
|
+ startAnalytics(jobId);
|
|
|
+ waitUntilAnalyticsIsStopped(jobId);
|
|
|
+ assertProgress(jobId, 100, 100, 100, 100);
|
|
|
+
|
|
|
+ assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
|
|
|
+ }
|
|
|
+
|
|
|
private void initialize(String jobId) {
|
|
|
this.jobId = jobId;
|
|
|
this.sourceIndex = jobId + "_source_index";
|
|
@@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
|
|
|
private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
|
|
|
client().admin().indices().prepareCreate(sourceIndex)
|
|
|
- .setMapping(NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
|
|
|
+ .setMapping(
|
|
|
+ NUMERICAL_FEATURE_FIELD, "type=double",
|
|
|
+ DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long",
|
|
|
+ DEPENDENT_VARIABLE_FIELD, "type=double")
|
|
|
.get();
|
|
|
|
|
|
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
|
|
@@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
for (int i = 0; i < numTrainingRows; i++) {
|
|
|
List<Object> source = List.of(
|
|
|
NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
|
|
+ DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()),
|
|
|
DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
|
|
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
|
|
bulkRequestBuilder.add(indexRequest);
|
|
|
}
|
|
|
for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
|
|
|
- List<Object> source = List.of(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
|
|
|
+ List<Object> source = List.of(
|
|
|
+ NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
|
|
|
+ DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()));
|
|
|
IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
|
|
|
bulkRequestBuilder.add(indexRequest);
|
|
|
}
|
|
@@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|
|
}
|
|
|
|
|
|
private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
|
|
|
- assertThat(destDoc.containsKey("ml"), is(true));
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
- Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
|
|
- return resultsObject;
|
|
|
+ return getFieldValue(destDoc, "ml");
|
|
|
}
|
|
|
|
|
|
protected String stateDocId() {
|