Browse Source

Add painless method getByPath, get value from nested collections with dotted path (#43170)

Given a nested structure composed of Lists and Maps, getByPath will return the value
keyed by path.  getByPath is a method on Lists and Maps.

The path is string Map keys and integer List indices separated by dot. An optional third
argument returns a default value if the path lookup fails due to a missing value.

Eg.
['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key1') = ['c', 'd']
['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key1.0') = 'c'
['key0': ['a', 'b'], 'key1': ['c', 'd']].getByPath('key2', 'x') = 'x'
[['key0': 'value0'], ['key1': 'value1']].getByPath('1.key1') = 'value1'

Throws IllegalArgumentException if an item cannot be found and a default is not given.
Throws NumberFormatException if a path element operating on a List is not an integer.

Fixes #42769
Stuart Tettemer 6 years ago
parent
commit
2c8e9aeb2a

+ 6 - 0
docs/painless/painless-api-reference/painless-api-reference-score/packages.asciidoc

@@ -30,6 +30,8 @@ See the <<painless-api-reference-score, Score API>> for a high-level overview of
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -84,6 +86,8 @@ See the <<painless-api-reference-score, Score API>> for a high-level overview of
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -138,6 +142,8 @@ See the <<painless-api-reference-score, Score API>> for a high-level overview of
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()

+ 48 - 0
docs/painless/painless-api-reference/painless-api-reference-shared/packages.asciidoc

@@ -4335,6 +4335,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -4386,6 +4388,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -4500,6 +4504,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -4666,6 +4672,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -5367,6 +5375,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -5457,6 +5467,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -5502,6 +5514,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -5668,6 +5682,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -5764,6 +5780,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Deque.html#getFirst()[getFirst]()
 * def {java11-javadoc}/java.base/java/util/Deque.html#getLast()[getLast]()
 * int getLength()
@@ -5836,6 +5854,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -6056,6 +6076,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(BiFunction)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -6157,6 +6179,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * def {java11-javadoc}/java.base/java/util/NavigableMap.html#floorKey(java.lang.Object)[floorKey](def)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -6642,6 +6666,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * def {java11-javadoc}/java.base/java/util/SortedMap.html#firstKey()[firstKey]()
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -6844,6 +6870,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * def {java11-javadoc}/java.base/java/util/Vector.html#firstElement()[firstElement]()
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -6988,6 +7016,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * def {java11-javadoc}/java.base/java/util/NavigableMap.html#floorKey(java.lang.Object)[floorKey](def)
 * void {java11-javadoc}/java.base/java/util/Map.html#forEach(java.util.function.BiConsumer)[forEach](BiConsumer)
 * def {java11-javadoc}/java.base/java/util/Map.html#get(java.lang.Object)[get](def)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * def {java11-javadoc}/java.base/java/util/Map.html#getOrDefault(java.lang.Object,java.lang.Object)[getOrDefault](def, def)
 * Map groupBy(BiFunction)
 * int {java11-javadoc}/java.base/java/lang/Object.html#hashCode()[hashCode]()
@@ -7158,6 +7188,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * def {java11-javadoc}/java.base/java/util/Vector.html#firstElement()[firstElement]()
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * def {java11-javadoc}/java.base/java/util/List.html#get(int)[get](int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * Map groupBy(Function)
 * int {java11-javadoc}/java.base/java/util/List.html#hashCode()[hashCode]()
@@ -8016,6 +8048,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * Boolean get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * boolean getValue()
 * Map groupBy(Function)
@@ -8071,6 +8105,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * BytesRef get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * BytesRef getValue()
 * Map groupBy(Function)
@@ -8126,6 +8162,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * JodaCompatibleZonedDateTime get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * JodaCompatibleZonedDateTime getValue()
 * Map groupBy(Function)
@@ -8181,6 +8219,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * Double get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * double getValue()
 * Map groupBy(Function)
@@ -8240,6 +8280,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * double geohashDistance(String)
 * double geohashDistanceWithDefault(String, double)
 * GeoPoint get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * double getLat()
 * double[] getLats()
 * int getLength()
@@ -8301,6 +8343,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * Long get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * long getValue()
 * Map groupBy(Function)
@@ -8356,6 +8400,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * String get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * String getValue()
 * Map groupBy(Function)
@@ -8415,6 +8461,8 @@ See the <<painless-api-reference-shared, Shared API>> for a high-level overview
 * List findResults(Function)
 * void {java11-javadoc}/java.base/java/lang/Iterable.html#forEach(java.util.function.Consumer)[forEach](Consumer)
 * String get(int)
+* Object getByPath(String)
+* Object getByPath(String, Object)
 * int getLength()
 * String getValue()
 * Map groupBy(Function)

+ 113 - 0
modules/lang-painless/src/main/java/org/elasticsearch/painless/api/Augmentation.java

@@ -25,6 +25,7 @@ import java.util.Base64;
 import java.util.Collection;
 import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.TreeMap;
 import java.util.function.BiConsumer;
@@ -34,6 +35,7 @@ import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.ObjIntConsumer;
 import java.util.function.Predicate;
+import java.util.function.Supplier;
 import java.util.function.ToDoubleFunction;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -552,4 +554,115 @@ public class Augmentation {
         // O(N) or faster depending on implementation
         return result.toArray(new String[0]);
     }
+
+    /**
+     * Access values in nested containers with a dot separated path.  Path elements are treated
+     * as strings for Maps and integers for Lists.
+     * @throws IllegalArgumentException if any of the following:
+     *  - path is empty
+     *  - path contains a trailing '.' or a repeated '.'
+     *  - an element of the path does not exist, ie key or index not present
+     *  - there is a non-container type at a non-terminal path element
+     *  - a path element for a List is not an integer
+     * @return object at path
+     */
+    public static <E> Object getByPath(List<E> receiver, String path) {
+        return getByPathDispatch(receiver, splitPath(path), 0, throwCantFindValue(path));
+    }
+
+    /**
+     * Same as {@link #getByPath(List, String)}, but for Map.
+     */
+    public static <K,V> Object getByPath(Map<K,V> receiver, String path) {
+        return getByPathDispatch(receiver, splitPath(path), 0, throwCantFindValue(path));
+    }
+
+    /**
+     * Same as {@link #getByPath(List, String)}, but with a default value.
+     * @return element at path or {@code defaultValue} if the terminal path element does not exist.
+     */
+    public static <E> Object getByPath(List<E> receiver, String path, Object defaultValue) {
+        return getByPathDispatch(receiver, splitPath(path), 0, () -> defaultValue);
+    }
+
+    /**
+     * Same as {@link #getByPath(List, String, Object)}, but for Map.
+     */
+    public static <K,V> Object getByPath(Map<K,V> receiver, String path, Object defaultValue) {
+        return getByPathDispatch(receiver, splitPath(path), 0, () -> defaultValue);
+    }
+
+    // Dispatches to getByPathMap, getByPathList or returns obj if done. See handleMissing for dealing with missing
+    // elements.
+    private static Object getByPathDispatch(Object obj, String[] elements, int i, Supplier<Object> defaultSupplier) {
+        if (i > elements.length - 1) {
+            return obj;
+        } else if (elements[i].length() == 0 ) {
+            String format = "Extra '.' in path [%s] at index [%d]";
+            throw new IllegalArgumentException(String.format(Locale.ROOT, format, String.join(".", elements), i));
+        } else if (obj instanceof Map<?,?>) {
+            return getByPathMap((Map<?,?>) obj, elements, i, defaultSupplier);
+        } else if (obj instanceof List<?>) {
+            return getByPathList((List<?>) obj, elements, i, defaultSupplier);
+        }
+        return handleMissing(obj, elements, i, defaultSupplier);
+    }
+
+    // lookup existing key in map, call back to dispatch.
+    private static <K,V> Object getByPathMap(Map<K,V> map, String[] elements, int i, Supplier<Object> defaultSupplier) {
+        String element = elements[i];
+        if (map.containsKey(element)) {
+            return getByPathDispatch(map.get(element), elements, i + 1, defaultSupplier);
+        }
+        return handleMissing(map, elements, i, defaultSupplier);
+    }
+
+    // lookup existing index in list, call back to dispatch.  Throws IllegalArgumentException with NumberFormatException
+    // if index can't be parsed as an int.
+    private static <E> Object getByPathList(List<E> list, String[] elements, int i, Supplier<Object> defaultSupplier) {
+        String element = elements[i];
+        try {
+            int elemInt = Integer.parseInt(element);
+            if (list.size() >= elemInt) {
+                return getByPathDispatch(list.get(elemInt), elements, i + 1, defaultSupplier);
+            }
+        } catch (NumberFormatException e) {
+            String format = "Could not parse [%s] as a int index into list at path [%s] and index [%d]";
+            throw new IllegalArgumentException(String.format(Locale.ROOT, format, element, String.join(".", elements), i), e);
+        }
+        return handleMissing(list, elements, i, defaultSupplier);
+    }
+
+    // Split path on '.', throws IllegalArgumentException for empty paths and paths ending in '.'
+    private static String[] splitPath(String path) {
+        if (path.length() == 0) {
+            throw new IllegalArgumentException("Missing path");
+        }
+        if (path.endsWith(".")) {
+            String format = "Trailing '.' in path [%s]";
+            throw new IllegalArgumentException(String.format(Locale.ROOT, format, path));
+        }
+        return path.split("\\.");
+    }
+
+    // A supplier that throws IllegalArgumentException
+    private static Supplier<Object> throwCantFindValue(String path) {
+        return () -> {
+            throw new IllegalArgumentException(String.format(Locale.ROOT, "Could not find value at path [%s]", path));
+        };
+    }
+
+    // Use defaultSupplier if at last path element, otherwise throw IllegalArgumentException
+    private static Object handleMissing(Object obj, String[] elements, int i, Supplier<Object> defaultSupplier) {
+        if (obj instanceof List || obj instanceof Map) {
+            if (elements.length - 1 == i) {
+                return defaultSupplier.get();
+            }
+            String format = "Container does not have [%s], for non-terminal index [%d] in path [%s]";
+            throw new IllegalArgumentException(String.format(Locale.ROOT, format, elements[i], i, String.join(".", elements)));
+        }
+        String format = "Non-container [%s] at [%s], index [%d] in path [%s]";
+        throw new IllegalArgumentException(
+            String.format(Locale.ROOT, format, obj.getClass().getName(), elements[i], i, String.join(".", elements)));
+    }
 }

+ 4 - 0
modules/lang-painless/src/main/resources/org/elasticsearch/painless/spi/java.util.txt

@@ -126,6 +126,8 @@ class java.util.List {
   int org.elasticsearch.painless.api.Augmentation getLength()
   void sort(Comparator)
   List subList(int,int)
+  Object org.elasticsearch.painless.api.Augmentation getByPath(String)
+  Object org.elasticsearch.painless.api.Augmentation getByPath(String, Object)
 }
 
 class java.util.ListIterator {
@@ -161,6 +163,8 @@ class java.util.Map {
   void replaceAll(BiFunction)
   int size()
   Collection values()
+  Object org.elasticsearch.painless.api.Augmentation getByPath(String)
+  Object org.elasticsearch.painless.api.Augmentation getByPath(String, Object)
 
   # some adaptations of groovy methods
   List org.elasticsearch.painless.api.Augmentation collect(BiFunction)

+ 0 - 1
modules/lang-painless/src/test/java/org/elasticsearch/painless/AugmentationTests.java

@@ -232,7 +232,6 @@ public class AugmentationTests extends ScriptTestCase {
             new SplitCase("1\n1.1.\r\n1\r\n111", "\r\n"),
         };
         for (SplitCase split : cases) {
-            //System.out.println(String.format("Splitting '%s' by '%s' %d times", split.input, split.token, split.count));
             assertArrayEquals(
                 split.input.split(Pattern.quote(split.token), split.count),
                 (String[])exec("return \""+split.input+"\".splitOnToken(\""+split.token+"\", "+split.count+");")

+ 259 - 0
modules/lang-painless/src/test/java/org/elasticsearch/painless/GetByPathAugmentationTests.java

@@ -0,0 +1,259 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.painless;
+
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+public class GetByPathAugmentationTests extends ScriptTestCase {
+
+    private final String k001Key = "k011";
+    private final String k001Value = "b";
+    private final Map<String,String> k001Obj = new HashMap<>();
+    private final String k001MapStr = "['" + k001Key + "': '" + k001Value + "']";
+    private final String mapMapList = "['k0': ['k01': [['k010': 'a'], " + k001MapStr + "]], 'k1': ['q']]";
+
+    private final String l2m2l1Index0 = "ll0";
+    private final String l2m2l1Index1 = "ll1";
+    private final List<String> l2m2l1Obj = new ArrayList<>();
+    private final String l2m2l1Str = "['" + l2m2l1Index0 + "', '" + l2m2l1Index1 + "']";
+    private final String listMapListList = "[['m0':'v0'],['m1':'v1'],['m2':['l0','l1', " + l2m2l1Str + "]]]";
+
+    private final String mapList = "['key0': ['a', 'b'], 'key1': ['c', 'd']]";
+    private final String mapMap = "['a': ['b': 'c']]";
+
+    public GetByPathAugmentationTests() {
+        l2m2l1Obj.add(l2m2l1Index0);
+        l2m2l1Obj.add(l2m2l1Index1);
+        k001Obj.put(k001Key, k001Value);
+    }
+
+    private String toScript(String collection, String key) {
+        return String.format(Locale.ROOT, "return %s.getByPath('%s')", collection, key);
+    }
+
+    private String toScript(String collection, String key, String defaultValue) {
+        return String.format(Locale.ROOT, "return %s.getByPath('%s', %s)", collection, key, defaultValue);
+    }
+
+    private String numberFormat(String unparsable, String path, int i) {
+        String format = "Could not parse [%s] as a int index into list at path [%s] and index [%d]";
+        return String.format(Locale.ROOT, format, unparsable, path, i);
+    }
+
+    private String missingValue(String path) {
+        return String.format(Locale.ROOT, "Could not find value at path [%s]", path);
+    }
+
+    private void assertPathValue(String collection, String key, Object value) {
+        assertEquals(value, exec(toScript(collection, key)));
+    }
+
+    private void assertPathDefaultValue(String collection, String key, Object value, String defaultValue) {
+        assertEquals(value, exec(toScript(collection, key, defaultValue)));
+    }
+
+    private IllegalArgumentException assertPathError(String collection, String key, String message) {
+        return assertPathError(toScript(collection, key), message);
+    }
+
+    private IllegalArgumentException assertPathError(String collection, String key, String defaultValue, String message) {
+        return assertPathError(toScript(collection, key, defaultValue), message);
+    }
+
+    private IllegalArgumentException assertPathError(String script, String message) {
+        IllegalArgumentException illegal = expectScriptThrows(
+            IllegalArgumentException.class,
+            () -> exec(script)
+        );
+        assertEquals(message, illegal.getMessage());
+        return illegal;
+    }
+
+    public void testOneLevelMap() {
+        assertPathValue("['k0':'v0']", "k0", "v0");
+    }
+
+    public void testOneLevelList() {
+        assertPathValue("['a','b','c','d']", "2", "c");
+    }
+
+    public void testTwoLevelMapList() {
+        assertPathValue("['key0': ['a', 'b'], 'key1': ['c', 'd']]", "key1.0", "c");
+    }
+
+    public void testMapDiffSizeList() {
+        assertPathValue("['k0': ['a','b','c','d'], 'k1': ['q']]", "k0.3", "d");
+    }
+
+    public void testBiMapList() {
+        assertPathValue(mapMapList, "k0.k01.1.k011", k001Value);
+    }
+
+    public void testBiMapListObject() {
+        assertPathValue(mapMapList, "k0.k01.1", k001Obj);
+    }
+
+    public void testListMap() {
+        assertPathValue("[['key0': 'value0'], ['key1': 'value1']]", "1.key1", "value1");
+    }
+
+    public void testTriList() {
+        assertPathValue("[['a','b'],['c','d'],[['e','f'],['g','h']]]", "2.1.1", "h");
+    }
+
+    public void testMapBiListObject() {
+        assertPathValue(listMapListList, "2.m2.2", l2m2l1Obj);
+    }
+
+    public void testMapBiList() {
+        assertPathValue(listMapListList, "2.m2.2.1", l2m2l1Index1);
+    }
+
+    public void testGetCollection() {
+        List<String> k1List = new ArrayList<>();
+        k1List.add("c");
+        k1List.add("d");
+        assertPathValue("['key0': ['a', 'b'], 'key1': ['c', 'd']]", "key1", k1List);
+    }
+
+    public void testMapListDefaultOneLevel() {
+        assertPathDefaultValue(mapList, "key2", "x", "'x'");
+    }
+
+    public void testMapListDefaultTwoLevel() {
+        assertPathDefaultValue(mapList, "key1.1", "d", "'x'");
+    }
+
+    public void testBiMapListDefault() {
+        assertPathDefaultValue(mapMapList, "k0.k01.1.k012", "foo", "'foo'");
+    }
+
+    public void testBiMapListDefaultExists() {
+        assertPathDefaultValue(mapMapList, "k0.k01.1.k011", "b", "'foo'");
+    }
+
+    public void testBiMapListDefaultObjectExists() {
+        assertPathDefaultValue(mapMapList, "k0.k01.1", k001Obj, "'foo'");
+    }
+
+    public void testBiMapListDefaultObject() {
+        assertPathDefaultValue(mapMapList, "k0.k01.9", k001Obj, k001MapStr);
+    }
+
+    public void testListMapBiListDefaultExists() {
+        assertPathDefaultValue(listMapListList, "2.m2.2", l2m2l1Obj, "'foo'");
+    }
+
+    public void testListMapBiListDefaultObject() {
+        assertPathDefaultValue(listMapListList, "2.m2.9", l2m2l1Obj, l2m2l1Str);
+    }
+
+    public void testBiListBadIndex() {
+        String path = "1.k0";
+        IllegalArgumentException err = assertPathError("[['a','b'],['c','d']]", path, numberFormat("k0", path, 1));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testBiMapListMissingLast() {
+        String path = "k0.k01.1.k012";
+        assertPathError(mapMapList, path, missingValue(path));
+    }
+
+    public void testBiMapListBadIndex() {
+        String path = "k0.k01.k012";
+        IllegalArgumentException err = assertPathError(mapMapList, path, numberFormat("k012", path, 2));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testListMapBiListMissingObject() {
+        String path = "2.m2.12";
+        assertPathError(listMapListList, path, missingValue(path));
+    }
+
+    public void testListMapBiListBadIndexAtObject() {
+        String path = "2.m2.a8";
+        IllegalArgumentException err = assertPathError(listMapListList, path, numberFormat("a8", path, 2));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testNonContainer() {
+        assertPathError(mapMap, "a.b.c", "Non-container [java.lang.String] at [c], index [2] in path [a.b.c]");
+    }
+
+    public void testMissingPath() {
+        assertPathError(mapMap, "", "Missing path");
+    }
+
+    public void testDoubleDot() {
+        assertPathError(mapMap, "a..b", "Extra '.' in path [a..b] at index [1]");
+    }
+
+    public void testTrailingDot() {
+        assertPathError(mapMap, "a.b.", "Trailing '.' in path [a.b.]");
+    }
+
+    public void testBiListDefaultBadIndex() {
+        String path = "1.k0";
+        IllegalArgumentException err = assertPathError(
+            "[['a','b'],['c','d']]",
+            path,
+            "'foo'",
+            numberFormat("k0", path, 1));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testBiMapListDefaultBadIndex() {
+        String path = "k0.k01.k012";
+        IllegalArgumentException err = assertPathError(
+            mapMapList,
+            path,
+            "'foo'",
+            numberFormat("k012", path, 2));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testListMapBiListObjectDefaultBadIndex() {
+        String path = "2.m2.a8";
+        IllegalArgumentException err = assertPathError(
+            listMapListList,
+            path,
+            "'foo'",
+            numberFormat("a8", path, 2));
+        assertEquals(err.getCause().getClass(), NumberFormatException.class);
+    }
+
+    public void testNonContainerDefaultBadIndex() {
+        assertPathError(mapMap, "a.b.c", "'foo'",
+            "Non-container [java.lang.String] at [c], index [2] in path [a.b.c]");
+    }
+
+    public void testDoubleDotDefault() {
+        assertPathError(mapMap, "a..b", "'foo'", "Extra '.' in path [a..b] at index [1]");
+    }
+
+    public void testTrailingDotDefault() {
+        assertPathError(mapMap, "a.b.", "'foo'", "Trailing '.' in path [a.b.]");
+    }
+}