Browse Source

Fixing bug setting index when parsing Google Vertex AI results (#117287) (#117358)

* Using record ID as index value when parsing Google Vertex AI rerank results

* Update docs/changelog/117287.yaml

* PR feedback
Ying Mao 10 months ago
parent
commit
95524222b2

+ 5 - 0
docs/changelog/117287.yaml

@@ -0,0 +1,5 @@
+pr: 117287
+summary: Fixing bug setting index when parsing Google Vertex AI results
+area: Machine Learning
+type: bug
+issues: []

+ 25 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntity.java

@@ -30,6 +30,8 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.
 public class GoogleVertexAiRerankResponseEntity {
 
     private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Google Vertex AI rerank response";
+    private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Google Vertex AI rerank "
+        + "response but received [%s]";
 
     /**
      * Parses the Google Vertex AI rerank response.
@@ -109,14 +111,27 @@ public class GoogleVertexAiRerankResponseEntity {
                 throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.SCORE.getPreferredName()));
             }
 
-            return new RankedDocsResults.RankedDoc(index, parsedRankedDoc.score, parsedRankedDoc.content);
+            if (parsedRankedDoc.id == null) {
+                throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDoc.ID.getPreferredName()));
+            }
+
+            try {
+                return new RankedDocsResults.RankedDoc(
+                    Integer.parseInt(parsedRankedDoc.id),
+                    parsedRankedDoc.score,
+                    parsedRankedDoc.content
+                );
+            } catch (NumberFormatException e) {
+                throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
+            }
         });
     }
 
-    private record RankedDoc(@Nullable Float score, @Nullable String content) {
+    private record RankedDoc(@Nullable Float score, @Nullable String content, @Nullable String id) {
 
         private static final ParseField CONTENT = new ParseField("content");
         private static final ParseField SCORE = new ParseField("score");
+        private static final ParseField ID = new ParseField("id");
         private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
             "google_vertex_ai_rerank_response",
             true,
@@ -126,6 +141,7 @@ public class GoogleVertexAiRerankResponseEntity {
         static {
             PARSER.declareString(Builder::setContent, CONTENT);
             PARSER.declareFloat(Builder::setScore, SCORE);
+            PARSER.declareString(Builder::setId, ID);
         }
 
         public static RankedDoc parse(XContentParser parser) {
@@ -137,6 +153,7 @@ public class GoogleVertexAiRerankResponseEntity {
 
             private String content;
             private Float score;
+            private String id;
 
             private Builder() {}
 
@@ -150,8 +167,13 @@ public class GoogleVertexAiRerankResponseEntity {
                 return this;
             }
 
+            public Builder setId(String id) {
+                this.id = id;
+                return this;
+            }
+
             public RankedDoc build() {
-                return new RankedDoc(score, content);
+                return new RankedDoc(score, content, id);
             }
         }
     }

+ 35 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googlevertexai/GoogleVertexAiRerankResponseEntityTests.java

@@ -39,7 +39,7 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
             new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
         );
 
-        assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"))));
+        assertThat(parsedResults.getRankedDocs(), is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"))));
     }
 
     public void testFromResponse_CreatesResultsForMultipleItems() throws IOException {
@@ -68,7 +68,7 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
 
         assertThat(
             parsedResults.getRankedDocs(),
-            is(List.of(new RankedDocsResults.RankedDoc(0, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
+            is(List.of(new RankedDocsResults.RankedDoc(2, 0.97F, "content 2"), new RankedDocsResults.RankedDoc(1, 0.90F, "content 1")))
         );
     }
 
@@ -161,4 +161,37 @@ public class GoogleVertexAiRerankResponseEntityTests extends ESTestCase {
 
         assertThat(thrownException.getMessage(), is("Failed to find required field [score] in Google Vertex AI rerank response"));
     }
+
+    public void testFromResponse_FailsWhenIDFieldIsNotInteger() {
+        String responseJson = """
+            {
+                 "records": [
+                     {
+                         "id": "abcd",
+                         "title": "title 2",
+                         "content": "content 2",
+                         "score": 0.97
+                     },
+                     {
+                        "id": "1",
+                        "title": "title 1",
+                        "content": "content 1",
+                        "score": 0.96
+                     }
+                ]
+            }
+            """;
+
+        var thrownException = expectThrows(
+            IllegalStateException.class,
+            () -> GoogleVertexAiRerankResponseEntity.fromResponse(
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            is("Expected numeric value for record ID field in Google Vertex AI rerank response but received [abcd]")
+        );
+    }
 }