Browse Source

[ML] Make OpenAI embeddings parser more flexible (#106808)

Fixes a parse failure that was dependent on the order of the fields
David Kyle 1 year ago
parent
commit
b40b17601c

+ 5 - 0
docs/changelog/106808.yaml

@@ -0,0 +1,5 @@
+pr: 106808
+summary: Make OpenAI embeddings parser more flexible
+area: Machine Learning
+type: bug
+issues: []

+ 27 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/XContentUtils.java

@@ -37,17 +37,42 @@ public class XContentUtils {
      * @throws IllegalStateException if the field cannot be found
      */
     public static void positionParserAtTokenAfterField(XContentParser parser, String field, String errorMsgTemplate) throws IOException {
-        XContentParser.Token token;
+        XContentParser.Token token = parser.nextToken();
 
-        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+        while (token != null && token != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME && parser.currentName().equals(field)) {
                 parser.nextToken();
                 return;
             }
+            token = parser.nextToken();
         }
 
         throw new IllegalStateException(format(errorMsgTemplate, field));
     }
 
+    /**
+     * Advances the parser, consuming and discarding tokens, until the
+     * parser points to the end of the current object. Nested objects
+     * and arrays are skipped.
+     *
+     * If successful, the parser's current token is the end object token.
+     *
+     * @param parser the parser to advance
+     * @throws IOException if reading the next token from the parser fails
+     */
+    public static void consumeUntilObjectEnd(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.nextToken();
+
+        // token == null when correctly formed input has
+        // been fully parsed.
+        while (token != null && token != XContentParser.Token.END_OBJECT) {
+            if (token == XContentParser.Token.START_OBJECT || token == XContentParser.Token.START_ARRAY) {
+                parser.skipChildren();
+            }
+
+            token = parser.nextToken();
+        }
+    }
+
     private XContentUtils() {}
 }

+ 6 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java

@@ -159,18 +159,17 @@ public class CohereEmbeddingsResponseEntity {
     }
 
     private static InferenceServiceResults parseEmbeddingsObject(XContentParser parser) throws IOException {
-        XContentParser.Token token;
+        XContentParser.Token token = parser.nextToken();
 
-        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+        while (token != null && token != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME) {
                 var embeddingValueParser = EMBEDDING_PARSERS.get(parser.currentName());
-                if (embeddingValueParser == null) {
-                    continue;
+                if (embeddingValueParser != null) {
+                    parser.nextToken();
+                    return embeddingValueParser.apply(parser);
                 }
-
-                parser.nextToken();
-                return embeddingValueParser.apply(parser);
             }
+            token = parser.nextToken();
         }
 
         throw new IllegalStateException(

+ 4 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/huggingface/HuggingFaceElserResponseEntity.java

@@ -83,13 +83,15 @@ public class HuggingFaceElserResponseEntity {
         XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, parser);
 
         List<SparseEmbeddingResults.WeightedToken> weightedTokens = new ArrayList<>();
-
-        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
+        token = parser.nextToken();
+        while (token != null && token != XContentParser.Token.END_OBJECT) {
             XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, token, parser);
             var floatToken = parser.nextToken();
             XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, floatToken, parser);
 
             weightedTokens.add(new SparseEmbeddingResults.WeightedToken(parser.currentName(), parser.floatValue()));
+
+            token = parser.nextToken();
         }
 
         // prevent an out of bounds if for some reason the truncation list is smaller than the results

+ 3 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntity.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.inference.external.request.Request;
 import java.io.IOException;
 import java.util.List;
 
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
 
@@ -95,11 +96,8 @@ public class OpenAiEmbeddingsResponseEntity {
         positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE);
 
         List<Float> embeddingValues = XContentParserUtils.parseList(parser, OpenAiEmbeddingsResponseEntity::parseEmbeddingList);
-
-        // the parser is currently sitting at an ARRAY_END so go to the next token
-        parser.nextToken();
-        // if there are additional fields within this object, lets skip them, so we can begin parsing the next embedding array
-        parser.skipChildren();
+        // parse and discard the rest of the object
+        consumeUntilObjectEnd(parser);
 
         return new TextEmbeddingResults.Embedding(embeddingValues);
     }

+ 132 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/XContentUtilsTests.java

@@ -8,12 +8,15 @@
 package org.elasticsearch.xpack.inference.external.response;
 
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentEOFException;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentType;
 
 import java.io.IOException;
 import java.util.Locale;
 
+import static org.hamcrest.Matchers.containsString;
+
 public class XContentUtilsTests extends ESTestCase {
 
     public void testMoveToFirstToken() throws IOException {
@@ -83,4 +86,133 @@ public class XContentUtilsTests extends ESTestCase {
             assertEquals(String.format(Locale.ROOT, errorFormat, missingField), exception.getMessage());
         }
     }
+
+    public void testPositionParserAtTokenAfterField_ThrowsWithMalformedJSON() throws IOException {
+        var json = """
+            {
+                "key": "value",
+                "foo": "bar"
+            """;
+        var errorFormat = "Error: %s";
+        var missingField = "missing field";
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            var exception = expectThrows(
+                XContentEOFException.class,
+                () -> XContentUtils.positionParserAtTokenAfterField(parser, missingField, errorFormat)
+            );
+
+            assertThat(exception.getMessage(), containsString("Unexpected end-of-input"));
+        }
+    }
+
+    public void testConsumeUntilObjectEnd() throws IOException {
+        var json = """
+            {
+                "key": "value",
+                "foo": true,
+                "bar": 0.1
+            }
+            """;
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+            assertNull(parser.nextToken());
+        }
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            parser.nextToken();
+            parser.nextToken();
+            assertEquals(XContentParser.Token.VALUE_STRING, parser.nextToken());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+            assertNull(parser.nextToken()); // fully parsed
+        }
+    }
+
+    public void testConsumeUntilObjectEnd_SkipArray() throws IOException {
+        var json = """
+            {
+                "key": "value",
+                "skip_array": [1.0, 2.0, 3.0]
+            }
+            """;
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+            assertNull(parser.nextToken());
+        }
+    }
+
+    public void testConsumeUntilObjectEnd_SkipNestedObject() throws IOException {
+        var json = """
+            {
+                "key": "value",
+                "skip_obj": {
+                  "foo": "bar"
+                }
+            }
+            """;
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+            assertNull(parser.nextToken()); // fully parsed
+        }
+    }
+
+    public void testConsumeUntilObjectEnd_InArray() throws IOException {
+        var json = """
+            [
+                {
+                    "key": "value",
+                    "skip_obj": {
+                      "foo": "bar"
+                    }
+                },
+                {
+                    "key": "value",
+                    "skip_array": [1.0, 2.0, 3.0]
+                },
+                {
+                    "key": "value",
+                    "skip_field1": "f1",
+                    "skip_field2": "f2"
+                }
+            ]
+            """;
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            assertEquals(XContentParser.Token.START_ARRAY, parser.nextToken());
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+
+            // Parser now inside object 1
+            assertEquals(XContentParser.Token.FIELD_NAME, parser.nextToken());
+            assertEquals("key", parser.currentName());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+
+            // Start of object 2
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+
+            // Start of object 3
+            assertEquals(XContentParser.Token.START_OBJECT, parser.nextToken());
+            assertEquals(XContentParser.Token.FIELD_NAME, parser.nextToken());
+            assertEquals(XContentParser.Token.VALUE_STRING, parser.nextToken());
+            assertEquals(XContentParser.Token.FIELD_NAME, parser.nextToken());
+            assertEquals("skip_field1", parser.currentName());
+            XContentUtils.consumeUntilObjectEnd(parser);
+            assertEquals(XContentParser.Token.END_OBJECT, parser.currentToken());
+
+            assertEquals(XContentParser.Token.END_ARRAY, parser.nextToken());
+            assertNull(parser.nextToken()); // fully parsed
+        }
+    }
 }

+ 59 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiEmbeddingsResponseEntityTests.java

@@ -327,4 +327,63 @@ public class OpenAiEmbeddingsResponseEntityTests extends ESTestCase {
             is("Failed to parse object: expecting token of type [VALUE_NUMBER] but found [START_OBJECT]")
         );
     }
+
+    public void testFieldsInDifferentOrderServer() throws IOException {
+        // The fields of the objects in the data array are reordered
+        String response = """
+            {
+                "created": 1711530064,
+                "object": "list",
+                "id": "6667830b-716b-4796-9a61-33b67b5cc81d",
+                "model": "mxbai-embed-large-v1",
+                "data": [
+                    {
+                        "embedding": [
+                            -0.9,
+                            0.5,
+                            0.3
+                        ],
+                        "index": 0,
+                        "object": "embedding"
+                    },
+                    {
+                        "index": 0,
+                        "embedding": [
+                            0.1,
+                            0.5
+                        ],
+                        "object": "embedding"
+                    },
+                    {
+                        "object": "embedding",
+                        "index": 0,
+                        "embedding": [
+                            0.5,
+                            0.5
+                        ]
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0
+                }
+            }""";
+
+        TextEmbeddingResults parsedResults = OpenAiEmbeddingsResponseEntity.fromResponse(
+            mock(Request.class),
+            new HttpResult(mock(HttpResponse.class), response.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(
+            parsedResults.embeddings(),
+            is(
+                List.of(
+                    new TextEmbeddingResults.Embedding(List.of(-0.9F, 0.5F, 0.3F)),
+                    new TextEmbeddingResults.Embedding(List.of(0.1F, 0.5F)),
+                    new TextEmbeddingResults.Embedding(List.of(0.5F, 0.5F))
+                )
+            )
+        );
+    }
 }