Prechádzať zdrojové kódy

[ML] Implement JSONPath replacement for Inference API (#127036)

* Adding initial extractor

* Finishing tests

* Addressing feedback
Jonathan Buttner 6 mesiacov pred
rodič
commit
3156cc7c0f

+ 209 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java

@@ -0,0 +1,209 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.common.Strings;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+/**
+ * Extracts fields from a {@link Map}.
+ *
+ * Uses a subset of the JSONPath schema to extract fields from a map.
+ * For more information <a href="https://en.wikipedia.org/wiki/JSONPath">see here</a>.
+ *
+ * This implementation differs in how it handles lists in that JSONPath will flatten inner lists. This implementation
+ * preserves inner lists.
+ *
+ * Examples of the schema:
+ *
+ * <pre>
+ * {@code
+ * $.field1.array[*].field2
+ * $.field1.field2
+ * }
+ * </pre>
+ *
+ * Given the map
+ * <pre>
+ * {@code
+ * {
+ *     "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
+ *     "latency": 38,
+ *     "usage": {
+ *         "token_count": 3072
+ *     },
+ *     "result": {
+ *         "embeddings": [
+ *             {
+ *                 "index": 0,
+ *                 "embedding": [
+ *                     2,
+ *                     4
+ *                 ]
+ *             },
+ *             {
+ *                 "index": 1,
+ *                 "embedding": [
+ *                     1,
+ *                     2
+ *                 ]
+ *             }
+ *         ]
+ *     }
+ * }
+ * }
+ * </pre>
+ *
+ * <pre>
+ * {@code
+ * var embeddings = MapPathExtractor.extract(map, "$.result.embeddings[*].embedding");
+ * }
+ * </pre>
+ *
+ * Will result in:
+ *
+ * <pre>
+ * {@code
+ * [
+ *   [2, 4],
+ *   [1, 2]
+ * ]
+ * }
+ * </pre>
+ *
+ * This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array.
+ * this implementation will preserve each nested list while gathering the results.
+ *
+ * For example
+ *
+ * <pre>
+ * {@code
+ * {
+ *   "result": [
+ *     {
+ *       "key": [
+ *         {
+ *           "a": 1.1
+ *         },
+ *         {
+ *           "a": 2.2
+ *         }
+ *       ]
+ *     },
+ *     {
+ *       "key": [
+ *         {
+ *           "a": 3.3
+ *         },
+ *         {
+ *           "a": 4.4
+ *         }
+ *       ]
+ *     }
+ *   ]
+ * }
+ * }
+ * {@code var embeddings = MapPathExtractor.extract(map, "$.result[*].key[*].a");}
+ *
+ * JSONPath: {@code [1.1, 2.2, 3.3, 4.4]}
+ * This implementation: {@code [[1.1, 2.2], [3.3, 4.4]]}
+ * </pre>
+ */
+public class MapPathExtractor {
+
+    private static final String DOLLAR = "$";
+
+    // default for testing
+    static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)");
+    static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)");
+
+    public static Object extract(Map<String, Object> data, String path) {
+        if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) {
+            return null;
+        }
+
+        var cleanedPath = path.trim();
+
+        if (cleanedPath.startsWith(DOLLAR)) {
+            cleanedPath = cleanedPath.substring(DOLLAR.length());
+        } else {
+            throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath));
+        }
+
+        return navigate(data, cleanedPath);
+    }
+
+    private static Object navigate(Object current, String remainingPath) {
+        if (current == null || remainingPath == null || remainingPath.isEmpty()) {
+            return current;
+        }
+
+        var dotFieldMatcher = dotFieldPattern.matcher(remainingPath);
+        var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath);
+
+        if (dotFieldMatcher.matches()) {
+            String field = dotFieldMatcher.group(1);
+            if (field == null || field.isEmpty()) {
+                throw new IllegalArgumentException(
+                    Strings.format(
+                        "Unable to extract field from remaining path [%s]. Fields must be delimited by a dot character.",
+                        remainingPath
+                    )
+                );
+            }
+
+            String nextPath = dotFieldMatcher.group(2);
+            if (current instanceof Map<?, ?> currentMap) {
+                var fieldFromMap = currentMap.get(field);
+                if (fieldFromMap == null) {
+                    throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field));
+                }
+
+                return navigate(currentMap.get(field), nextPath);
+            } else {
+                throw new IllegalArgumentException(
+                    Strings.format(
+                        "Current path [%s] matched the dot field pattern but the current object is not a map, "
+                            + "found invalid type [%s] instead.",
+                        remainingPath,
+                        current.getClass().getSimpleName()
+                    )
+                );
+            }
+        } else if (arrayWildcardMatcher.matches()) {
+            String nextPath = arrayWildcardMatcher.group(1);
+            if (current instanceof List<?> list) {
+                List<Object> results = new ArrayList<>();
+
+                for (Object item : list) {
+                    Object result = navigate(item, nextPath);
+                    if (result != null) {
+                        results.add(result);
+                    }
+                }
+
+                return results;
+            } else {
+                throw new IllegalArgumentException(
+                    Strings.format(
+                        "Current path [%s] matched the array field pattern but the current object is not a list, "
+                            + "found invalid type [%s] instead.",
+                        remainingPath,
+                        current.getClass().getSimpleName()
+                    )
+                );
+            }
+        }
+
+        throw new IllegalArgumentException(Strings.format("Invalid path received [%s], unable to extract a field name.", remainingPath));
+    }
+}

+ 187 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java

@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class MapPathExtractorTests extends ESTestCase {
+    public void testExtract_RetrievesListOfLists() {
+        Map<String, Object> input = Map.of(
+            "result",
+            Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
+        );
+
+        assertThat(MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), is(List.of(List.of(1, 2), List.of(3, 4))));
+    }
+
+    public void testExtract_IteratesListOfMapsToListOfStrings() {
+        Map<String, Object> input = Map.of(
+            "result",
+            List.of(Map.of("key", List.of("value1", "value2")), Map.of("key", List.of("value3", "value4")))
+        );
+
+        assertThat(
+            MapPathExtractor.extract(input, "$.result[*].key[*]"),
+            is(List.of(List.of("value1", "value2"), List.of("value3", "value4")))
+        );
+    }
+
+    public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() {
+        Map<String, Object> input = Map.of(
+            "result",
+            List.of(
+                Map.of("key", List.of(Map.of("a", 1.1d), Map.of("a", 2.2d))),
+                Map.of("key", List.of(Map.of("a", 3.3d), Map.of("a", 4.4d)))
+            )
+        );
+
+        assertThat(MapPathExtractor.extract(input, "$.result[*].key[*].a"), is(List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d))));
+    }
+
+    public void testExtract_ReturnsNullForEmptyList() {
+        Map<String, Object> input = Map.of();
+
+        assertNull(MapPathExtractor.extract(input, "$.awesome"));
+    }
+
+    public void testExtract_ReturnsNull_WhenTheInputMapIsNull() {
+        assertNull(MapPathExtractor.extract(null, "$.result"));
+    }
+
+    public void testExtract_ReturnsNull_WhenPathIsNull() {
+        assertNull(MapPathExtractor.extract(Map.of("key", "value"), null));
+    }
+
+    public void testExtract_ReturnsNull_WhenPathIsWhiteSpace() {
+        assertNull(MapPathExtractor.extract(Map.of("key", "value"), "    "));
+    }
+
+    public void testExtract_ThrowsException_WhenPathDoesNotStartWithDollarSign() {
+        var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(Map.of("key", "value"), ".key"));
+        assertThat(exception.getMessage(), is("Path [.key] must start with a dollar sign ($)"));
+    }
+
+    public void testExtract_ThrowsException_WhenCannotFindField() {
+        Map<String, Object> input = Map.of("result", "key");
+
+        var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$.awesome"));
+        assertThat(exception.getMessage(), is("Unable to find field [awesome] in map"));
+    }
+
+    public void testExtract_ThrowsAnException_WhenThePathIsInvalid() {
+        Map<String, Object> input = Map.of("result", "key");
+
+        var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$awesome"));
+        assertThat(exception.getMessage(), is("Invalid path received [awesome], unable to extract a field name."));
+    }
+
+    public void testExtract_ThrowsException_WhenMissingArraySyntax() {
+        Map<String, Object> input = Map.of(
+            "result",
+            Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
+        );
+
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            // embeddings is missing [*] to indicate that it is an array
+            () -> MapPathExtractor.extract(input, "$.result.embeddings.embedding")
+        );
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Current path [.embedding] matched the dot field pattern but the current object "
+                    + "is not a map, found invalid type [List12] instead."
+            )
+        );
+    }
+
+    public void testExtract_ThrowsException_WhenHasArraySyntaxButIsAMap() {
+        Map<String, Object> input = Map.of(
+            "result",
+            Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4))))
+        );
+
+        var exception = expectThrows(
+            IllegalArgumentException.class,
+            // result is not an array
+            () -> MapPathExtractor.extract(input, "$.result[*].embeddings[*].embedding")
+        );
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Current path [[*].embeddings[*].embedding] matched the array field pattern but the current "
+                    + "object is not a list, found invalid type [Map1] instead."
+            )
+        );
+    }
+
+    public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty() {
+        Map<String, Object> input = Map.of("result", List.of());
+
+        assertThat(MapPathExtractor.extract(input, "$.result"), is(List.of()));
+    }
+
+    public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty_PathIncludesArray() {
+        Map<String, Object> input = Map.of("result", List.of());
+
+        assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(List.of()));
+    }
+
+    public void testDotFieldPattern() {
+        {
+            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc.123");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is("abc"));
+            assertThat(matcher.group(2), is(".123"));
+        }
+        {
+            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[*].123");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is("abc"));
+            assertThat(matcher.group(2), is("[*].123"));
+        }
+        {
+            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[.123");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is("abc"));
+            assertThat(matcher.group(2), is("[.123"));
+        }
+        {
+            var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is("abc"));
+            assertThat(matcher.group(2), is(""));
+        }
+    }
+
+    public void testArrayWildcardPattern() {
+        {
+            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*].abc.123");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is(".abc.123"));
+        }
+        {
+            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*]");
+            assertTrue(matcher.matches());
+            assertThat(matcher.group(1), is(""));
+        }
+        {
+            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[1].abc");
+            assertFalse(matcher.matches());
+        }
+        {
+            var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[].abc");
+            assertFalse(matcher.matches());
+        }
+    }
+}