Browse Source

Do not copy mapping from dependent variable to prediction field in regression analysis (#51227)

Przemysław Witek 5 years ago
parent
commit
927b14e0bf

+ 27 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.mapper.FieldAliasMapper;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 
 import java.io.IOException;
@@ -28,6 +29,7 @@ import java.util.stream.Stream;
 
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
 
 public class Classification implements DataFrameAnalysis {
 
@@ -246,12 +248,32 @@ public class Classification implements DataFrameAnalysis {
         return Collections.singletonMap(dependentVariable, 2L);
     }
 
+    @SuppressWarnings("unchecked")
     @Override
-    public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
-        return new HashMap<>() {{
-            put(resultsFieldName + "." + predictionFieldName, dependentVariable);
-            put(resultsFieldName + ".top_classes.class_name", dependentVariable);
-        }};
+    public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
+        Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
+        if ((dependentVariableMapping instanceof Map) == false) {
+            return Collections.emptyMap();
+        }
+        Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
+        // If the source field is an alias, fetch the concrete field that the alias points to.
+        if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
+            String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
+            dependentVariableMapping = extractMapping(path, mappingsProperties);
+        }
+        // We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
+        // Hence, we need to check the "instanceof" condition again.
+        if ((dependentVariableMapping instanceof Map) == false) {
+            return Collections.emptyMap();
+        }
+        Map<String, Object> additionalProperties = new HashMap<>();
+        additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
+        additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
+        return additionalProperties;
+    }
+
+    private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
+        return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
     }
 
     @Override

+ 4 - 6
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

@@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
     Map<String, Long> getFieldCardinalityLimits();
 
     /**
-     * Returns fields for which the mappings should be copied from source index to destination index.
-     * Each entry of the returned {@link Map} is of the form:
-     *   key   - field path in the destination index
-     *   value - field path in the source index from which the mapping should be taken
+     * Returns fields for which the mappings should be either predefined or copied from source index to destination index.
      *
+     * @param mappingsProperties mappings.properties portion of the index mappings
      * @param resultsFieldName name of the results field under which all the results are stored
-     * @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
+     * @return {@link Map} containing fields for which the mappings should be handled explicitly
      */
-    Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
+    Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
 
     /**
      * @return {@code true} if this analysis supports data frame rows with missing values

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

@@ -230,7 +230,7 @@ public class OutlierDetection implements DataFrameAnalysis {
     }
 
     @Override
-    public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
+    public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
         return Collections.emptyMap();
     }
 

+ 4 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -187,8 +187,10 @@ public class Regression implements DataFrameAnalysis {
     }
 
     @Override
-    public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
-        return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
+    public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
+        // Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
+        // high (over 10M) values of dependent variable.
+        return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
     }
 
     @Override

+ 37 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -20,13 +20,16 @@ import org.elasticsearch.test.AbstractSerializingTestCase;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
+import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
@@ -172,8 +175,40 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
         assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
     }
 
-    public void testFieldMappingsToCopyIsNonEmpty() {
-        assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
+    public void testGetExplicitlyMappedFields() {
+        assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
+        assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
+        assertThat(
+            new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
+            is(anEmptyMap()));
+        assertThat(
+            new Classification("foo").getExplicitlyMappedFields(
+                Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
+                "results"),
+            allOf(
+                hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
+                hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
+        assertThat(
+            new Classification("foo").getExplicitlyMappedFields(
+                new HashMap<>() {{
+                    put("foo", new HashMap<>() {{
+                        put("type", "alias");
+                        put("path", "bar");
+                    }});
+                    put("bar", Collections.singletonMap("type", "long"));
+                }},
+                "results"),
+            allOf(
+                hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
+                hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
+        assertThat(
+            new Classification("foo").getExplicitlyMappedFields(
+                Collections.singletonMap("foo", new HashMap<>() {{
+                    put("type", "alias");
+                    put("path", "missing");
+                }}),
+                "results"),
+            is(anEmptyMap()));
     }
 
     public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

+ 2 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

@@ -92,8 +92,8 @@ public class OutlierDetectionTests extends AbstractSerializingTestCase<OutlierDe
         assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
     }
 
-    public void testFieldMappingsToCopyIsEmpty() {
-        assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
+    public void testGetExplicitlyMappedFields() {
+        assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
     }
 
     public void testGetStateDocId() {

+ 6 - 3
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.anEmptyMap;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasEntry;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.notNullValue;
@@ -42,7 +43,7 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         return createRandom();
     }
 
-    public static Regression createRandom() {
+    private static Regression createRandom() {
         String dependentVariableName = randomAlphaOfLength(10);
         BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
         String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
@@ -109,8 +110,10 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
         assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
     }
 
-    public void testFieldMappingsToCopyIsNonEmpty() {
-        assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
+    public void testGetExplicitlyMappedFields() {
+        assertThat(
+            new Regression("foo").getExplicitlyMappedFields(null, "results"),
+            hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
     }
 
     public void testGetStateDocId() {

+ 7 - 39
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

@@ -7,8 +7,6 @@ package org.elasticsearch.xpack.ml.integration;
 
 import com.google.common.collect.Ordering;
 import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.admin.indices.get.GetIndexAction;
-import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.bulk.BulkRequestBuilder;
 import org.elasticsearch.action.bulk.BulkResponse;
@@ -42,7 +40,6 @@ import java.util.Map;
 import java.util.Set;
 
 import static java.util.stream.Collectors.toList;
-import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
@@ -116,7 +113,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
             "Estimated memory usage for this analytics to be",
@@ -157,7 +154,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
             "Estimated memory usage for this analytics to be",
@@ -220,7 +217,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [classification]",
             "Estimated memory usage for this analytics to be",
@@ -308,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
 
@@ -366,7 +363,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
 
@@ -385,7 +382,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
 
@@ -404,7 +401,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
-        assertMlResultsFieldMappings(predictedClassField, "keyword");
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
         assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
     }
 
@@ -564,15 +561,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         return destDoc;
     }
 
-    /**
-     * Wrapper around extractValue that:
-     * - allows dots (".") in the path elements provided as arguments
-     * - supports implicit casting to the appropriate type
-     */
-    private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
-        return (T)extractValue(String.join(".", path), doc);
-    }
-
     private static <T> void assertTopClasses(Map<String, Object> resultsObject,
                                              int numTopClasses,
                                              String dependentVariable,
@@ -656,26 +644,6 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         }
     }
 
-    private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
-        Map<String, Object> mappings =
-            client()
-                .execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
-                .actionGet()
-                .mappings()
-                .get(destIndex)
-                .sourceAsMap();
-        assertThat(
-            mappings.toString(),
-            getFieldValue(
-                mappings,
-                "properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
-            equalTo(expectedType));
-        assertThat(
-            mappings.toString(),
-            getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
-            equalTo(expectedType));
-    }
-
     private String stateDocId() {
         return jobId + "_classification_state#1";
     }

+ 34 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.ml.integration;
 
+import org.elasticsearch.action.admin.indices.get.GetIndexAction;
+import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
 import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
 import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
 import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
@@ -53,6 +55,7 @@ import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.equalTo;
@@ -281,4 +284,35 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
             .get();
         assertThat(searchResponse.getHits().getHits().length, equalTo(1));
     }
+
+    protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
+        Map<String, Object> mappings =
+            client()
+                .execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
+                .actionGet()
+                .mappings()
+                .get(index)
+                .sourceAsMap();
+        assertThat(
+            mappings.toString(),
+            getFieldValue(
+                mappings,
+                "properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
+            equalTo(expectedType));
+        if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
+            assertThat(
+                mappings.toString(),
+                getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
+                equalTo(expectedType));
+        }
+    }
+
+    /**
+     * Wrapper around extractValue that:
+     * - allows dots (".") in the path elements provided as arguments
+     * - supports implicit casting to the appropriate type
+     */
+    protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
+        return (T)extractValue(String.join(".", path), doc);
+    }
 }

+ 56 - 12
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

@@ -35,8 +35,10 @@ import static org.hamcrest.Matchers.is;
 public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static final String NUMERICAL_FEATURE_FIELD = "feature";
+    private static final String DISCRETE_NUMERICAL_FEATURE_FIELD = "discrete-feature";
     private static final String DEPENDENT_VARIABLE_FIELD = "variable";
     private static final List<Double> NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(1.0, 2.0, 3.0));
+    private static final List<Long> DISCRETE_NUMERICAL_FEATURE_VALUES = Collections.unmodifiableList(Arrays.asList(10L, 20L, 30L));
     private static final List<Double> DEPENDENT_VARIABLE_VALUES = Collections.unmodifiableList(Arrays.asList(10.0, 20.0, 30.0));
 
     private String jobId;
@@ -50,6 +52,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
         initialize("regression_single_numeric_feature_and_mixed_data_set");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
         indexData(sourceIndex, 300, 50);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null,
@@ -78,19 +81,24 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
             // it seems for this case values can be as far off as 2.0
 
             // double featureValue = (double) destDoc.get(NUMERICAL_FEATURE_FIELD);
-            // double predictionValue = (double) resultsObject.get("variable_prediction");
+            // double predictionValue = (double) resultsObject.get(predictedClassField);
             // assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
+            assertThat(resultsObject.containsKey(predictedClassField), is(true));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD)));
-            assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true));
+            assertThat(
+                resultsObject.toString(),
+                resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD)
+                    || resultsObject.containsKey("feature_importance." + DISCRETE_NUMERICAL_FEATURE_FIELD),
+                is(true));
         }
 
         assertProgress(jobId, 100, 100, 100, 100);
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
             "Estimated memory usage for this analytics to be",
@@ -103,6 +111,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
         initialize("regression_only_training_data_and_training_percent_is_100");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
         indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@@ -119,7 +128,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
+            assertThat(resultsObject.containsKey(predictedClassField), is(true));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(true));
         }
@@ -128,6 +137,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
             "Estimated memory usage for this analytics to be",
@@ -140,6 +150,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
         initialize("regression_only_training_data_and_training_percent_is_50");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
         indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config =
@@ -164,7 +175,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
+            assertThat(resultsObject.containsKey(predictedClassField), is(true));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             // Let's just assert there's both training and non-training results
             if ((boolean) resultsObject.get("is_training")) {
@@ -180,6 +191,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
         assertThatAuditMessagesMatch(jobId,
             "Created analytics with analysis type [regression]",
             "Estimated memory usage for this analytics to be",
@@ -192,6 +204,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testStopAndRestart() throws Exception {
         initialize("regression_stop_and_restart");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
         indexData(sourceIndex, 350, 0);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@@ -233,7 +246,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (SearchHit hit : sourceData.getHits()) {
             Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit));
 
-            assertThat(resultsObject.containsKey("variable_prediction"), is(true));
+            assertThat(resultsObject.containsKey(predictedClassField), is(true));
             assertThat(resultsObject.containsKey("is_training"), is(true));
             assertThat(resultsObject.get("is_training"), is(true));
         }
@@ -242,6 +255,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
     }
 
     public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception {
@@ -289,6 +303,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
         initialize("regression_delete_expired_data");
+        String predictedClassField = DEPENDENT_VARIABLE_FIELD + "_prediction";
         indexData(sourceIndex, 100, 0);
 
         DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD));
@@ -301,6 +316,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
         assertModelStatePersisted(stateDocId());
         assertInferenceModelPersisted(jobId);
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
 
         // Call _delete_expired_data API and check nothing was deleted
         assertThat(deleteExpiredData().isDeleted(), is(true));
@@ -319,6 +335,31 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         assertThat(stateIndexSearchResponse.getHits().getTotalHits().value, equalTo(0L));
     }
 
+    public void testDependentVariableIsLong() throws Exception {
+        initialize("regression_dependent_variable_is_long");
+        String predictedClassField = DISCRETE_NUMERICAL_FEATURE_FIELD + "_prediction";
+        indexData(sourceIndex, 100, 0);
+
+        DataFrameAnalyticsConfig config =
+            buildAnalytics(
+                jobId,
+                sourceIndex,
+                destIndex,
+                null,
+                new Regression(DISCRETE_NUMERICAL_FEATURE_FIELD, BoostedTreeParams.builder().build(), null, null, null));
+        registerAnalytics(config);
+        putAnalytics(config);
+
+        assertIsStopped(jobId);
+        assertProgress(jobId, 0, 0, 0, 0);
+
+        startAnalytics(jobId);
+        waitUntilAnalyticsIsStopped(jobId);
+        assertProgress(jobId, 100, 100, 100, 100);
+
+        assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
+    }
+
     private void initialize(String jobId) {
         this.jobId = jobId;
         this.sourceIndex = jobId + "_source_index";
@@ -327,7 +368,10 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
 
     private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows) {
         client().admin().indices().prepareCreate(sourceIndex)
-            .setMapping(NUMERICAL_FEATURE_FIELD, "type=double", DEPENDENT_VARIABLE_FIELD, "type=double")
+            .setMapping(
+                NUMERICAL_FEATURE_FIELD, "type=double",
+                DISCRETE_NUMERICAL_FEATURE_FIELD, "type=long",
+                DEPENDENT_VARIABLE_FIELD, "type=double")
             .get();
 
         BulkRequestBuilder bulkRequestBuilder = client().prepareBulk()
@@ -335,12 +379,15 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
         for (int i = 0; i < numTrainingRows; i++) {
             List<Object> source = List.of(
                 NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
+                DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()),
                 DEPENDENT_VARIABLE_FIELD, DEPENDENT_VARIABLE_VALUES.get(i % DEPENDENT_VARIABLE_VALUES.size()));
             IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
             bulkRequestBuilder.add(indexRequest);
         }
         for (int i = numTrainingRows; i < numTrainingRows + numNonTrainingRows; i++) {
-            List<Object> source = List.of(NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()));
+            List<Object> source = List.of(
+                NUMERICAL_FEATURE_FIELD, NUMERICAL_FEATURE_VALUES.get(i % NUMERICAL_FEATURE_VALUES.size()),
+                DISCRETE_NUMERICAL_FEATURE_FIELD, DISCRETE_NUMERICAL_FEATURE_VALUES.get(i % DISCRETE_NUMERICAL_FEATURE_VALUES.size()));
             IndexRequest indexRequest = new IndexRequest(sourceIndex).source(source.toArray());
             bulkRequestBuilder.add(indexRequest);
         }
@@ -363,10 +410,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
     }
 
     private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Object> destDoc) {
-        assertThat(destDoc.containsKey("ml"), is(true));
-        @SuppressWarnings("unchecked")
-        Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
-        return resultsObject;
+        return getFieldValue(destDoc, "ml");
     }
 
     protected String stateDocId() {

+ 1 - 26
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java

@@ -24,7 +24,6 @@ import org.elasticsearch.cluster.metadata.MappingMetaData;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexSortConfig;
-import org.elasticsearch.index.mapper.FieldAliasMapper;
 import org.elasticsearch.index.mapper.KeywordFieldMapper;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ClientHelper;
@@ -40,7 +39,6 @@ import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 
-import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
 
 /**
@@ -158,36 +156,13 @@ public final class DataFrameAnalyticsIndex {
         return maxValue;
     }
 
-    @SuppressWarnings("unchecked")
     private static Map<String, Object> createAdditionalMappings(DataFrameAnalyticsConfig config, Map<String, Object> mappingsProperties) {
         Map<String, Object> properties = new HashMap<>();
         properties.put(ID_COPY, Map.of("type", KeywordFieldMapper.CONTENT_TYPE));
-        for (Map.Entry<String, String> entry
-                : config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) {
-            String destFieldPath = entry.getKey();
-            String sourceFieldPath = entry.getValue();
-            Object sourceFieldMapping = extractMapping(sourceFieldPath, mappingsProperties);
-            if (sourceFieldMapping instanceof Map) {
-                Map<String, Object> sourceFieldMappingAsMap = (Map) sourceFieldMapping;
-                // If the source field is an alias, fetch the concrete field that the alias points to.
-                if (FieldAliasMapper.CONTENT_TYPE.equals(sourceFieldMappingAsMap.get("type"))) {
-                    String path = (String) sourceFieldMappingAsMap.get(FieldAliasMapper.Names.PATH);
-                    sourceFieldMapping = extractMapping(path, mappingsProperties);
-                }
-            }
-            // We may have updated the value of {@code sourceFieldMapping} in the "if" block above.
-            // Hence, we need to check the "instanceof" condition again.
-            if (sourceFieldMapping instanceof Map) {
-                properties.put(destFieldPath, sourceFieldMapping);
-            }
-        }
+        properties.putAll(config.getAnalysis().getExplicitlyMappedFields(mappingsProperties, config.getDest().getResultsField()));
         return properties;
     }
 
-    private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
-        return extractValue(String.join("." + PROPERTIES + ".", path.split("\\.")), mappingsProperties);
-    }
-
     private static Map<String, Object> createMetaData(String analyticsId, Clock clock) {
         Map<String, Object> metadata = new HashMap<>();
         metadata.put(CREATION_DATE_MILLIS, clock.millis());

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java

@@ -190,7 +190,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
 
     public void testCreateDestinationIndex_Regression() throws IOException {
         Map<String, Object> map = testCreateDestinationIndex(new Regression(NUMERICAL_FIELD));
-        assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
+        assertThat(extractValue("_doc.properties.ml.numerical-field_prediction.type", map), equalTo("double"));
     }
 
     public void testCreateDestinationIndex_Classification() throws IOException {
@@ -291,7 +291,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase {
 
     public void testUpdateMappingsToDestIndex_Regression() throws IOException {
         Map<String, Object> map = testUpdateMappingsToDestIndex(new Regression(NUMERICAL_FIELD));
-        assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("integer"));
+        assertThat(extractValue("properties.ml.numerical-field_prediction.type", map), equalTo("double"));
     }
 
     public void testUpdateMappingsToDestIndex_Classification() throws IOException {