浏览代码

[ML][Inference] fix support for nested fields (#50258)

This fixes support for nested fields

We now support fully nested, fully collapsed, or a mix of both on inference docs. 

ES mappings allow the `_source` to be any combination of nested objects and dot-delimited field names.
So, we do our best to find the right path down the Map for the desired field.
Benjamin Trent 5 年之前
父节点
当前提交
e9e6a4a7b4
共有 16 个文件被更改,包括 582 次插入33 次删除
  1. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java
  2. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java
  3. 2 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java
  4. 4 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
  5. 133 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java
  6. 18 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java
  7. 15 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java
  8. 20 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java
  9. 57 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java
  10. 52 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java
  11. 193 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java
  12. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java
  13. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java
  14. 41 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java
  15. 7 7
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java
  16. 32 16
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -103,7 +104,7 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
 
     @Override
     public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
         if (value == null) {
             return;
         }

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -86,7 +87,7 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
 
     @Override
     public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
         if (value == null) {
             return;
         }

+ 2 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java

@@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -114,7 +115,7 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
 
     @Override
     public void process(Map<String, Object> fields) {
-        Object value = fields.get(field);
+        Object value = MapHelper.dig(field, fields);
         if (value == null) {
             return;
         }

+ 4 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfi
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.io.IOException;
 import java.util.ArrayDeque;
@@ -129,7 +130,9 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
                 "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
         }
 
-        List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
+        List<Double> features = featureNames.stream()
+            .map(f -> InferenceHelpers.toDouble(MapHelper.dig(f, fields)))
+            .collect(Collectors.toList());
         return infer(features, config);
     }
 

+ 133 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/MapHelper.java

@@ -0,0 +1,133 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.utils;
+
+import org.elasticsearch.common.Nullable;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Stack;
+
+public final class MapHelper {
+
+    private MapHelper() {}
+
+    /**
+     * This eagerly digs (depth first search, longer keys first) through the map by tokenizing the provided path on '.'.
+     *
+     * It is possible for ES _source docs to have "mixed" path formats. So, we should search all potential paths
+     * given the current knowledge of the map.
+     *
+     * Examples:
+     *
+     * The following maps would return `2` given the path "a.b.c.d"
+     *
+     * {
+     *     "a.b.c.d" : 2
+     * }
+     * {
+     *     "a" :{"b": {"c": {"d" : 2}}}
+     * }
+     * {
+     *     "a" :{"b.c": {"d" : 2}}
+     * }
+     * {
+     *     "a" :{"b": {"c": {"d" : 5}}},
+     *     "a.b" :{"c": {"d" : 2}} // we choose the first one found; longer keys ("a.b") are explored before shorter ones ("a")
+     * }
+     * {
+     *     "a" :{"b": {"c": {"NOT_d" : 2, "d": 2}}}
+     * }
+     *
+     * Conceptual "Worst case" 5 potential paths explored for "a.b.c.d" until 2 is finally returned
+     * {
+     *     "a.b.c": {"not_d": 2},
+     *     "a.b": {"c": {"not_d": 2}},
+     *     "a": {"b.c": {"not_d": 2}},
+     *     "a": {"b" :{ "c.not_d": 2}},
+     *     "a" :{"b": {"c": {"not_d" : 2}}},
+     *     "a" :{"b": {"c": {"d" : 2}}},
+     * }
+     *
+     * We don't exhaustively create all potential paths.
+     * If we did, this would result in 2^(n-1) total possible paths, where {@code n = path.split("\\.").length}.
+     * Instead we lazily create potential paths once we know that they are possibilities.
+     *
+     * @param path Dot delimited path containing the field desired
+     * @param map The {@link Map} map to dig
+     * @return The found object. Returns {@code null} if not found
+     * @throws IllegalArgumentException if the path contains an empty segment (e.g. a leading dot or consecutive dots)
+     */
+    @Nullable
+    public static Object dig(String path, Map<String, Object> map) {
+        // short cut: the whole path is itself a top-level key, so no search (and no path validation) is needed
+        if (map.keySet().contains(path)) {
+            return map.get(path);
+        }
+        String[] fields = path.split("\\."); // String.split drops trailing empty strings, so a trailing '.' is not seen below
+        if (Arrays.stream(fields).anyMatch(String::isEmpty)) {
+            throw new IllegalArgumentException("Empty path detected. Invalid field name");
+        }
+        Stack<PotentialPath> pathStack = new Stack<>();
+        pathStack.push(new PotentialPath(map, 0)); // start at the root map with no path tokens consumed
+        return explore(fields, pathStack);
+    }
+
+    @SuppressWarnings("unchecked")
+    private static Object explore(String[] path, Stack<PotentialPath> pathStack) { // depth-first search over candidate (map, position) pairs
+        while (pathStack.empty() == false) {
+            PotentialPath potentialPath = pathStack.pop(); // last pushed first, i.e. the candidate reached via the longest key
+            int endPos = potentialPath.pathPosition + 1; // exclusive end of the key currently being tried
+            int startPos = potentialPath.pathPosition; // first path token not yet consumed
+            Map<String, Object> map = potentialPath.map;
+            String candidateKey = null;
+            while(endPos <= path.length) { // try ever longer dot-joined keys at this level
+                candidateKey = mergePath(path, startPos, endPos);
+                Object next = map.get(candidateKey);
+                if (endPos == path.length && next != null) { // exit early, we reached the full path and found something
+                    return next;
+                }
+                if (next instanceof Map<?, ?>) { // we found another map, continue exploring down this path
+                    pathStack.push(new PotentialPath((Map<String, Object>)next, endPos));
+                }
+                endPos++; // widen the key by one more path token
+            }
+            if (candidateKey != null && map.containsKey(candidateKey)) { // exit early: the full remaining key exists but maps to null
+                return map.get(candidateKey);
+            }
+        }
+
+        return null; // every candidate path was exhausted without a match
+    }
+
+    private static String mergePath(String[] path, int start, int end) { // joins path[start] .. path[end - 1] with '.'
+        if (start + 1 == end) { // early exit, single token, no need to create sb
+            return path[start];
+        }
+
+        StringBuilder sb = new StringBuilder();
+        for (int i = start; i < end - 1; i++) { // every token but the last is followed by a separator
+            sb.append(path[i]);
+            sb.append(".");
+        }
+        sb.append(path[end - 1]); // last token, no trailing separator
+        return sb.toString();
+    }
+
+    private static class PotentialPath { // one pending search state: a map plus how much of the path led to it
+
+        // The map from which exploration resumes
+        private final Map<String, Object> map;
+        // Index of the first token of the requested path that has not been consumed yet
+        private final int pathPosition;
+
+        private PotentialPath(Map<String, Object> map, int pathPosition) {
+            this.map = map;
+            this.pathPosition = pathPosition;
+        }
+
+    }
+}

+ 18 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java

@@ -65,4 +65,22 @@ public class FrequencyEncodingTests extends PreProcessingTests<FrequencyEncoding
         testProcess(encoding, fieldValues, matchers);
     }
 
+    public void testProcessWithNestedField() { // the encoded field is addressed by a dotted path into a nested map
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
+            v -> randomDoubleBetween(0.0, 1.0, false))); // random frequency per known value
+        String encodedFeatureName = "encoded";
+        FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
+
+        Map<String, Object> fieldValues = new HashMap<>() {{ // nested form: {"categorical": {"child": "farequote"}}
+            put("categorical", new HashMap<>(){{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+        assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote"))); // read back from the top-level "encoded" key
+    }
+
 }

+ 15 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java

@@ -67,4 +67,19 @@ public class OneHotEncodingTests extends PreProcessingTests<OneHotEncoding> {
         testProcess(encoding, fieldValues, matchers);
     }
 
+    public void testProcessWithNestedField() { // the encoded field is addressed by a dotted path into a nested map
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
+        OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
+        Map<String, Object> fieldValues = new HashMap<>() {{ // nested form: {"categorical": {"child": "farequote"}}
+            put("categorical", new HashMap<>(){{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+        assertThat(fieldValues.get("Column_farequote"), equalTo(1)); // the column mapped to the nested value is set to 1
+    }
+
 }

+ 20 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java

@@ -68,4 +68,24 @@ public class TargetMeanEncodingTests extends PreProcessingTests<TargetMeanEncodi
         testProcess(encoding, fieldValues, matchers);
     }
 
+    public void testProcessWithNestedField() { // the encoded field is addressed by a dotted path into a nested map
+        String field = "categorical.child";
+        List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
+        Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
+            v -> randomDoubleBetween(0.0, 1.0, false))); // random target mean per known value
+        String encodedFeatureName = "encoded";
+        Double defaultvalue = randomDouble();
+        TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
+
+        Map<String, Object> fieldValues = new HashMap<>() {{ // nested form: {"categorical": {"child": "farequote"}}
+            put("categorical", new HashMap<>(){{
+                put("child", "farequote");
+            }});
+        }};
+
+        encoding.process(fieldValues);
+
+        assertThat(fieldValues.get("encoded"), equalTo(valueMap.get("farequote"))); // mapped value expected, not the default
+    }
+
 }

+ 57 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

@@ -445,6 +445,63 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
             closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
     }
 
+    public void testInferNestedFields() { // ensemble inference where both features are supplied as nested maps
+        List<String> featureNames = Arrays.asList("foo.baz", "bar.biz"); // dotted feature names
+        Tree tree1 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(0.3))
+            .addNode(TreeNode.builder(2)
+                .setThreshold(0.8)
+                .setSplitFeature(1)
+                .setLeftChild(3)
+                .setRightChild(4))
+            .addNode(TreeNode.builder(3).setLeafValue(0.1))
+            .addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
+        Tree tree2 = Tree.builder()
+            .setFeatureNames(featureNames)
+            .setRoot(TreeNode.builder(0)
+                .setLeftChild(1)
+                .setRightChild(2)
+                .setSplitFeature(0)
+                .setThreshold(0.5))
+            .addNode(TreeNode.builder(1).setLeafValue(1.5))
+            .addNode(TreeNode.builder(2).setLeafValue(0.9))
+            .build();
+        Ensemble ensemble = Ensemble.builder()
+            .setTargetType(TargetType.REGRESSION)
+            .setFeatureNames(featureNames)
+            .setTrainedModels(Arrays.asList(tree1, tree2))
+            .setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5})) // equal weight for both trees
+            .build();
+
+        Map<String, Object> featureMap = new HashMap<>() {{ // {"foo": {"baz": 0.4}, "bar": {"biz": 0.0}}
+            put("foo", new HashMap<>(){{
+                put("baz", 0.4);
+            }});
+            put("bar", new HashMap<>(){{
+                put("biz", 0.0);
+            }});
+        }};
+        assertThat(0.9, // presumably 0.5 * 0.3 + 0.5 * 1.5, i.e. both trees take the left branch - TODO confirm split semantics
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        featureMap = new HashMap<>() {{ // {"foo": {"baz": 2.0}, "bar": {"biz": 0.7}}
+            put("foo", new HashMap<>(){{
+                put("baz", 2.0);
+            }});
+            put("bar", new HashMap<>(){{
+                put("biz", 0.7);
+            }});
+        }};
+        assertThat(0.5, // presumably 0.5 * 0.1 + 0.5 * 0.9 - TODO confirm split semantics
+            closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+    }
+
     public void testOperationsEstimations() {
         Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
         Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);

+ 52 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

@@ -169,6 +169,58 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
             closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
     }
 
+    public void testInferNestedFields() {
+        // Build a tree with 2 nodes and 3 leaves using 2 features
+        // The leaves have unique values 0.1, 0.2, 0.3
+        Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
+        TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5);
+        builder.addLeaf(rootNode.getRightChild(), 0.3);
+        TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8);
+        builder.addLeaf(leftChildNode.getLeftChild(), 0.1);
+        builder.addLeaf(leftChildNode.getRightChild(), 0.2);
+
+        List<String> featureNames = Arrays.asList("foo.baz", "bar.biz"); // dotted names; values are supplied as nested maps below
+        Tree tree = builder.setFeatureNames(featureNames).build();
+
+        // This feature vector should hit the right child of the root node
+        Map<String, Object> featureMap = new HashMap<>() {{
+            put("foo", new HashMap<>(){{
+                put("baz", 0.6);
+            }});
+            put("bar", new HashMap<>(){{
+                put("biz", 0.0);
+            }});
+        }};
+        assertThat(0.3,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        // This should hit the left child of the left child of the root node
+        // i.e. it takes the path left, left
+        featureMap = new HashMap<>() {{
+            put("foo", new HashMap<>(){{
+                put("baz", 0.3);
+            }});
+            put("bar", new HashMap<>(){{
+                put("biz", 0.7);
+            }});
+        }};
+        assertThat(0.1,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+
+        // This should hit the right child of the left child of the root node
+        // i.e. it takes the path left, right
+        featureMap = new HashMap<>() {{
+            put("foo", new HashMap<>(){{
+                put("baz", 0.3);
+            }});
+            put("bar", new HashMap<>(){{
+                put("biz", 0.9);
+            }});
+        }};
+        assertThat(0.2,
+            closeTo(((SingleValueInferenceResults)tree.infer(featureMap, RegressionConfig.EMPTY_PARAMS)).value(), 0.00001));
+    }
+
     public void testTreeClassificationProbability() {
         // Build a tree with 2 nodes and 3 leaves using 2 features
         // The leaves have unique values 0.1, 0.2, 0.3

+ 193 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/MapHelperTests.java

@@ -0,0 +1,193 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.utils;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+
+public class MapHelperTests extends ESTestCase {
+
+    public void testAbsolutePathStringAsKey() { // the whole dotted path is itself a single top-level key
+        String path = "a.b.c.d";
+        Map<String, Object> map = Collections.singletonMap(path, 2);
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+        assertThat(MapHelper.dig(path, Collections.emptyMap()), is(nullValue())); // nothing to find in an empty map
+    }
+
+    public void testSimplePath() { // fully nested: one map level per path token
+        String path = "a.b.c.d";
+        Map<String, Object> map = Collections.singletonMap("a",
+            Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("d", 2))));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a",
+            Collections.singletonMap("b",
+                Collections.singletonMap("e", // Not part of path
+                    Collections.singletonMap("d", 2))));
+        assertThat(MapHelper.dig(path, map), is(nullValue())); // the chain breaks at "e", so nothing is found
+    }
+
+    public void testSimplePathReturningMap() { // the path stops at an intermediate level, so the value found is a map
+        String path = "a.b.c";
+        Map<String, Object> map = Collections.singletonMap("a",
+            Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("d", 2))));
+        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
+    }
+
+    public void testSimpleMixedPath() { // keys mix nesting with dot-joined segments; each map has a single candidate route
+        String path = "a.b.c.d";
+        Map<String, Object> map = Collections.singletonMap("a", // a -> b.c -> d
+            Collections.singletonMap("b.c",
+                    Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a.b", // a.b -> c -> d
+            Collections.singletonMap("c",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a.b.c", // a.b.c -> d
+                Collections.singletonMap("d", 2));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a", // a -> b -> c.d
+            Collections.singletonMap("b",
+                Collections.singletonMap("c.d", 2)));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a", // a -> b.c.d
+            Collections.singletonMap("b.c.d", 2));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a.b", // a.b -> c.d
+            Collections.singletonMap("c.d", 2));
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = Collections.singletonMap("a", // wrong middle key
+            Collections.singletonMap("b.foo",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+
+        map = Collections.singletonMap("a", // wrong leaf key
+            Collections.singletonMap("b.c",
+                Collections.singletonMap("foo", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+
+        map = Collections.singletonMap("x", // wrong root key
+            Collections.singletonMap("b.c",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+    }
+
+    public void testSimpleMixedPathReturningMap() { // mixed key styles where the value at the end of the path is a map
+        String path = "a.b.c";
+        Map<String, Object> map = Collections.singletonMap("a", // a -> b.c
+            Collections.singletonMap("b.c",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
+
+        map = Collections.singletonMap("a", // "b.foo" does not match "b.c"
+            Collections.singletonMap("b.foo",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+
+        map = Collections.singletonMap("a", // "b.not_c" does not match "b.c"
+            Collections.singletonMap("b.not_c",
+                Collections.singletonMap("foo", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+
+        map = Collections.singletonMap("x", // wrong root key
+            Collections.singletonMap("b.c",
+                Collections.singletonMap("d", 2)));
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+    }
+
+    public void testMultiplePotentialPaths() { // several keys match a prefix of the path; at most one route reaches "d"
+        String path = "a.b.c.d";
+        Map<String, Object> map = new LinkedHashMap<>() {{ // only the "a.b" branch ends in "d"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("not_d", 5))));
+            put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = new LinkedHashMap<>() {{ // only the "a" branch ends in "d"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("d", 2))));
+            put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 5)));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = new LinkedHashMap<>() {{ // under a -> b, "c" is a dead end but the sibling key "c.d" matches
+            put("a", Collections.singletonMap("b",
+                new HashMap<>() {{
+                    put("c", Collections.singletonMap("not_d", 5));
+                    put("c.d", 2);
+                }}));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = new LinkedHashMap<>() {{ // under a -> b, "c" leads to "d"; the sibling "c.not_d" is never a candidate key
+            put("a", Collections.singletonMap("b",
+                new HashMap<>() {{
+                    put("c", Collections.singletonMap("d", 2));
+                    put("c.not_d", 5);
+                }}));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(2));
+
+        map = new LinkedHashMap<>() {{ // neither branch ends in "d"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("not_d", 5))));
+            put("a.b", Collections.singletonMap("c", Collections.singletonMap("not_d", 2)));
+        }};
+
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+    }
+
+    public void testMultiplePotentialPathsReturningMap() { // competing branches where the value at the end of the path is a map
+        String path = "a.b.c";
+        Map<String, Object> map = new LinkedHashMap<>() {{ // only the "a" branch contains "c"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("c",
+                    Collections.singletonMap("d", 2))));
+            put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
+
+        map = new LinkedHashMap<>() {{ // only the "a.b" branch contains "c"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("not_c",
+                    Collections.singletonMap("d", 2))));
+            put("a.b", Collections.singletonMap("c", Collections.singletonMap("d", 2)));
+        }};
+        assertThat(MapHelper.dig(path, map), equalTo(Collections.singletonMap("d", 2)));
+
+        map = new LinkedHashMap<>() {{ // neither branch contains "c"
+            put("a", Collections.singletonMap("b",
+                Collections.singletonMap("not_c",
+                    Collections.singletonMap("d", 2))));
+            put("a.b", Collections.singletonMap("not_c", Collections.singletonMap("d", 2)));
+        }};
+        assertThat(MapHelper.dig(path, map), is(nullValue()));
+    }
+
+}

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
 
 import java.util.Arrays;
@@ -128,7 +129,7 @@ public class InferenceProcessor extends AbstractProcessor {
         Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
         if (fieldMapping != null) {
             fieldMapping.forEach((src, dest) -> {
-                Object srcValue = fields.remove(src);
+                Object srcValue = MapHelper.dig(src, fields);
                 if (srcValue != null) {
                     fields.put(dest, srcValue);
                 }

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java

@@ -6,7 +6,6 @@
 package org.elasticsearch.xpack.ml.inference.loadingservice;
 
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
@@ -16,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
+import org.elasticsearch.xpack.core.ml.utils.MapHelper;
 
 import java.util.HashSet;
 import java.util.Map;
@@ -61,7 +61,7 @@ public class LocalModel implements Model {
     @Override
     public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
         try {
-            if (Sets.haveEmptyIntersection(fieldNames, fields.keySet())) {
+            if (fieldNames.stream().allMatch(f -> MapHelper.dig(f, fields) == null)) {
                 listener.onResponse(new WarningInferenceResults(Messages.getMessage(INFERENCE_WARNING_ALL_FIELDS_MISSING, modelId)));
                 return;
             }

+ 41 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java

@@ -181,7 +181,7 @@ public class InferenceProcessorTests extends ESTestCase {
         String modelId = "model";
         Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
 
-        Map<String, String> fieldMapping = new HashMap<>(3) {{
+        Map<String, String> fieldMapping = new HashMap<>(5) {{
             put("value1", "new_value1");
             put("value2", "new_value2");
             put("categorical", "new_categorical");
@@ -195,7 +195,7 @@ public class InferenceProcessorTests extends ESTestCase {
             new ClassificationConfig(topNClasses, null, null),
             fieldMapping);
 
-        Map<String, Object> source = new HashMap<>(3){{
+        Map<String, Object> source = new HashMap<>(5){{
             put("value1", 1);
             put("categorical", "foo");
             put("un_touched", "bar");
@@ -203,8 +203,46 @@ public class InferenceProcessorTests extends ESTestCase {
         Map<String, Object> ingestMetadata = new HashMap<>();
         IngestDocument document = new IngestDocument(source, ingestMetadata);
 
-        Map<String, Object> expectedMap = new HashMap<>(2) {{
+        Map<String, Object> expectedMap = new HashMap<>(7) {{
             put("new_value1", 1);
+            put("value1", 1);
+            put("categorical", "foo");
+            put("new_categorical", "foo");
+            put("un_touched", "bar");
+        }};
+        assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
+    }
+
+    public void testGenerateWithMappingNestedFields() {
+        String modelId = "model";
+        Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
+
+        Map<String, String> fieldMapping = new HashMap<>(5) {{
+            put("value1.foo", "new_value1");
+            put("value2", "new_value2");
+            put("categorical.bar", "new_categorical");
+        }};
+
+        InferenceProcessor processor = new InferenceProcessor(client,
+            auditor,
+            "my_processor",
+            "my_field",
+            modelId,
+            new ClassificationConfig(topNClasses, null, null),
+            fieldMapping);
+
+        Map<String, Object> source = new HashMap<>(5){{
+            put("value1", Collections.singletonMap("foo", 1));
+            put("categorical.bar", "foo");
+            put("un_touched", "bar");
+        }};
+        Map<String, Object> ingestMetadata = new HashMap<>();
+        IngestDocument document = new IngestDocument(source, ingestMetadata);
+
+        Map<String, Object> expectedMap = new HashMap<>(7) {{
+            put("new_value1", 1);
+            put("value1", Collections.singletonMap("foo", 1));
+            put("categorical.bar", "foo");
             put("new_categorical", "foo");
             put("un_touched", "bar");
         }};

+ 7 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java

@@ -41,7 +41,7 @@ public class LocalModelTests extends ESTestCase {
 
     public void testClassificationInfer() throws Exception {
         String modelId = "classification_model";
-        List<String> inputFields = Arrays.asList("foo", "bar", "categorical");
+        List<String> inputFields = Arrays.asList("field.foo", "field.bar", "categorical");
         TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
             .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
             .setTrainedModel(buildClassification(false))
@@ -49,8 +49,8 @@ public class LocalModelTests extends ESTestCase {
 
         Model model = new LocalModel(modelId, definition, new TrainedModelInput(inputFields));
         Map<String, Object> fields = new HashMap<>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
+            put("field.foo", 1.0);
+            put("field.bar", 0.5);
             put("categorical", "dog");
         }};
 
@@ -93,8 +93,8 @@ public class LocalModelTests extends ESTestCase {
         Model model = new LocalModel("regression_model", trainedModelDefinition, new TrainedModelInput(inputFields));
 
         Map<String, Object> fields = new HashMap<>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
+            put("field.foo", 1.0);
+            put("field.bar", 0.5);
             put("categorical", "dog");
         }};
 
@@ -147,7 +147,7 @@ public class LocalModelTests extends ESTestCase {
     }
 
     public static TrainedModel buildClassification(boolean includeLabels) {
-        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
         Tree tree1 = Tree.builder()
             .setFeatureNames(featureNames)
             .setRoot(TreeNode.builder(0)
@@ -193,7 +193,7 @@ public class LocalModelTests extends ESTestCase {
     }
 
     public static TrainedModel buildRegression() {
-        List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
+        List<String> featureNames = Arrays.asList("field.foo", "field.bar", "animal_cat", "animal_dog");
         Tree tree1 = Tree.builder()
             .setFeatureNames(featureNames)
             .setRoot(TreeNode.builder(0)

+ 32 - 16
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java

@@ -66,9 +66,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
         oneHotEncoding.put("cat", "animal_cat");
         oneHotEncoding.put("dog", "animal_dog");
         TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
-            .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
             .setParsedDefinition(new TrainedModelDefinition.Builder()
-                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
                 .setTrainedModel(buildClassification(true)))
             .setVersion(Version.CURRENT)
             .setLicenseLevel(License.OperationMode.PLATINUM.description())
@@ -77,9 +77,9 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
             .setEstimatedHeapMemory(0)
             .build();
         TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
-            .setInput(new TrainedModelInput(Arrays.asList("foo", "bar", "categorical")))
+            .setInput(new TrainedModelInput(Arrays.asList("field.foo", "field.bar", "other.categorical")))
             .setParsedDefinition(new TrainedModelDefinition.Builder()
-                .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
+                .setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding)))
                 .setTrainedModel(buildRegression()))
             .setVersion(Version.CURRENT)
             .setEstimatedOperations(0)
@@ -99,26 +99,42 @@ public class ModelInferenceActionIT extends MlSingleNodeTestCase {
 
         List<Map<String, Object>> toInfer = new ArrayList<>();
         toInfer.add(new HashMap<>() {{
-            put("foo", 1.0);
-            put("bar", 0.5);
-            put("categorical", "dog");
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
         }});
         toInfer.add(new HashMap<>() {{
-            put("foo", 0.9);
-            put("bar", 1.5);
-            put("categorical", "cat");
+            put("field", new HashMap<>(){{
+                put("foo", 0.9);
+                put("bar", 1.5);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
         }});
 
         List<Map<String, Object>> toInfer2 = new ArrayList<>();
         toInfer2.add(new HashMap<>() {{
-            put("foo", 0.0);
-            put("bar", 0.01);
-            put("categorical", "dog");
+            put("field", new HashMap<>(){{
+                put("foo", 0.0);
+                put("bar", 0.01);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "dog");
+            }});
         }});
         toInfer2.add(new HashMap<>() {{
-            put("foo", 1.0);
-            put("bar", 0.0);
-            put("categorical", "cat");
+            put("field", new HashMap<>(){{
+                put("foo", 1.0);
+                put("bar", 0.0);
+            }});
+            put("other", new HashMap<>(){{
+                put("categorical", "cat");
+            }});
         }});
 
         // Test regression