Browse Source

[Inference API] Use ObjectParser instead of manual parsing in GoogleVertexAiRerankResponseEntity (#110363)

Tim Grein 1 year ago
parent
commit
e7c3e353f6

+ 58 - 16
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java

@@ -8,7 +8,9 @@
 package org.elasticsearch.xpack.inference.external.response.googlevertexai;
 
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
-import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ObjectParser;
+import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -21,10 +23,9 @@ import java.util.List;
 
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
-import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd;
+import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
-import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterFieldCurrentFlatObj;
 
 public class GoogleVertexAiRerankResponseEntity {
 
@@ -90,27 +91,68 @@ public class GoogleVertexAiRerankResponseEntity {
 
             positionParserAtTokenAfterField(jsonParser, "records", FAILED_TO_FIND_FIELD_TEMPLATE);
 
-            List<RankedDocsResults.RankedDoc> rankedDocs = parseList(jsonParser, GoogleVertexAiRerankResponseEntity::parseRankedDoc);
+            var rankedDocs = doParse(jsonParser);
 
             return new RankedDocsResults(rankedDocs);
         }
     }
 
-    private static RankedDocsResults.RankedDoc parseRankedDoc(XContentParser parser, Integer index) throws IOException {
-        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+    private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
+        return parseList(parser, (listParser, index) -> {
+            var parsedRankedDoc = RankedDoc.parse(parser);
 
-        positionParserAtTokenAfterFieldCurrentFlatObj(parser, "content", FAILED_TO_FIND_FIELD_TEMPLATE);
-        XContentParser.Token token = parser.currentToken();
-        XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_STRING, token, parser);
-        String content = parser.text();
+            if (parsedRankedDoc.content == null) {
+                throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.CONTENT.getPreferredName()));
+            }
 
-        positionParserAtTokenAfterFieldCurrentFlatObj(parser, "score", FAILED_TO_FIND_FIELD_TEMPLATE);
-        token = parser.currentToken();
-        XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser);
-        float score = parser.floatValue();
+            if (parsedRankedDoc.score == null) {
+                throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
+            }
 
-        consumeUntilObjectEnd(parser);
+            return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
+        });
+    }
+
+    private record RankedDoc(@Nullable Float score, @Nullable String content) {
+
+        private static final ParseField CONTENT = new ParseField("content");
+        private static final ParseField SCORE = new ParseField("score");
+        private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
+            "google_vertex_ai_rerank_response",
+            true,
+            Builder::new
+        );
+
+        static {
+            PARSER.declareString(Builder::setContent, CONTENT);
+            PARSER.declareFloat(Builder::setScore, SCORE);
+        }
+
+        public static RankedDoc parse(XContentParser parser) {
+            Builder builder = PARSER.apply(parser, null);
+            return builder.build();
+        }
 
-        return new RankedDocsResults.RankedDoc(index, score, content);
+        private static final class Builder {
+
+            private String content;
+            private Float score;
+
+            private Builder() {}
+
+            public Builder setScore(Float score) {
+                this.score = score;
+                return this;
+            }
+
+            public Builder setContent(String content) {
+                this.content = content;
+                return this;
+            }
+
+            public RankedDoc build() {
+                return new RankedDoc(score, content);
+            }
+        }
     }
 }