Browse Source

[ML] Add support for dimensions in google vertex ai request (#132689)

* Add support for dimensions in request

* Update docs/changelog/132689.yaml
Jonathan Buttner 2 months ago
parent
commit
0165233a5f

+ 5 - 0
docs/changelog/132689.yaml

@@ -0,0 +1,5 @@
+pr: 132689
+summary: Add support for dimensions in google vertex ai request
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

@@ -67,7 +67,7 @@ public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
     }
 
     // Should only be used directly for testing
-    GoogleVertexAiEmbeddingsModel(
+    public GoogleVertexAiEmbeddingsModel(
         String inferenceEntityId,
         TaskType taskType,
         String service,

+ 8 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequest.java

@@ -49,8 +49,14 @@ public class GoogleVertexAiEmbeddingsRequest implements GoogleVertexAiRequest {
         HttpPost httpPost = new HttpPost(model.nonStreamingUri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), inputType, model.getTaskSettings()))
-                .getBytes(StandardCharsets.UTF_8)
+            Strings.toString(
+                new GoogleVertexAiEmbeddingsRequestEntity(
+                    truncationResult.input(),
+                    inputType,
+                    model.getTaskSettings(),
+                    model.getServiceSettings()
+                )
+            ).getBytes(StandardCharsets.UTF_8)
         );
 
         httpPost.setEntity(byteEntity);

+ 13 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntity.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 
 import java.io.IOException;
@@ -21,13 +22,15 @@ import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
 public record GoogleVertexAiEmbeddingsRequestEntity(
     List<String> inputs,
     InputType inputType,
-    GoogleVertexAiEmbeddingsTaskSettings taskSettings
+    GoogleVertexAiEmbeddingsTaskSettings taskSettings,
+    GoogleVertexAiEmbeddingsServiceSettings serviceSettings
 ) implements ToXContentObject {
 
     private static final String INSTANCES_FIELD = "instances";
     private static final String CONTENT_FIELD = "content";
     private static final String PARAMETERS_FIELD = "parameters";
     private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
+    private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality";
     private static final String TASK_TYPE_FIELD = "task_type";
 
     private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
@@ -38,6 +41,7 @@ public record GoogleVertexAiEmbeddingsRequestEntity(
     public GoogleVertexAiEmbeddingsRequestEntity {
         Objects.requireNonNull(inputs);
         Objects.requireNonNull(taskSettings);
+        Objects.requireNonNull(serviceSettings);
     }
 
     @Override
@@ -62,15 +66,19 @@ public record GoogleVertexAiEmbeddingsRequestEntity(
 
         builder.endArray();
 
-        if (taskSettings.autoTruncate() != null) {
-            builder.startObject(PARAMETERS_FIELD);
-            {
+        builder.startObject(PARAMETERS_FIELD);
+        {
+            if (taskSettings.autoTruncate() != null) {
                 builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
             }
-            builder.endObject();
+            if (serviceSettings.dimensionsSetByUser()) {
+                builder.field(OUTPUT_DIMENSIONALITY_FIELD, serviceSettings.dimensions());
+            }
         }
         builder.endObject();
 
+        builder.endObject();
+
         return builder;
     }
 

+ 54 - 12
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestEntityTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 
 import java.io.IOException;
@@ -26,7 +27,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
         var entity = new GoogleVertexAiEmbeddingsRequestEntity(
             List.of("abc"),
             null,
-            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
+            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -42,17 +44,19 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                     }
                 ],
                 "parameters": {
-                    "autoTruncate": true
+                    "autoTruncate": true,
+                    "outputDimensionality": 10
                 }
             }
             """));
     }
 
-    public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
+    public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields() throws IOException {
         var entity = new GoogleVertexAiEmbeddingsRequestEntity(
             List.of("abc"),
             InputType.INTERNAL_INGEST,
-            new GoogleVertexAiEmbeddingsTaskSettings(null, null)
+            new GoogleVertexAiEmbeddingsTaskSettings(null, null),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -66,13 +70,45 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                         "content": "abc",
                         "task_type": "RETRIEVAL_DOCUMENT"
                     }
-                ]
+                ],
+                "parameters": {
+                }
+            }
+            """));
+    }
+
+    public void testToXContent_SingleEmbeddingRequest_DoesNotWriteUndefinedFields_DimensionsSetByUserFalse() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc"),
+            InputType.INTERNAL_INGEST,
+            new GoogleVertexAiEmbeddingsTaskSettings(null, null),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, 10, null, null)
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "instances": [
+                    {
+                        "content": "abc",
+                        "task_type": "RETRIEVAL_DOCUMENT"
+                    }
+                ],
+                "parameters": {}
             }
             """));
     }
 
     public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
-        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null, new GoogleVertexAiEmbeddingsTaskSettings(false, null));
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc"),
+            null,
+            new GoogleVertexAiEmbeddingsTaskSettings(false, null),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -96,7 +132,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
         var entity = new GoogleVertexAiEmbeddingsRequestEntity(
             List.of("abc", "def"),
             InputType.INTERNAL_SEARCH,
-            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
+            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", true, null, 10, null, null)
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -116,7 +153,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                     }
                 ],
                 "parameters": {
-                    "autoTruncate": true
+                    "autoTruncate": true,
+                    "outputDimensionality": 10
                 }
             }
             """));
@@ -126,7 +164,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
         var entity = new GoogleVertexAiEmbeddingsRequestEntity(
             List.of("abc", "def"),
             null,
-            new GoogleVertexAiEmbeddingsTaskSettings(true, null)
+            new GoogleVertexAiEmbeddingsTaskSettings(true, null),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -154,7 +193,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
         var entity = new GoogleVertexAiEmbeddingsRequestEntity(
             List.of("abc", "def"),
             null,
-            new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION)
+            new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION),
+            new GoogleVertexAiEmbeddingsServiceSettings("location", "projectId", "modelId", false, null, null, null, null)
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -172,12 +212,14 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                         "content": "def",
                         "task_type": "CLASSIFICATION"
                     }
-                ]
+                ],
+                "parameters": {
+                }
             }
             """));
     }
 
     public void testToXContent_ThrowsIfTaskSettingsIsNull() {
-        expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null));
+        expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null, null, null));
     }
 }

+ 104 - 13
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiEmbeddingsRequestTests.java

@@ -9,16 +9,21 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.InputTypeTests;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.common.TruncatorTests;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 
 import java.io.IOException;
 import java.util.List;
@@ -49,12 +54,15 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap, aMapWithSize(2));
         if (InputType.isSpecified(inputType)) {
             var convertedInputType = convertToString(inputType);
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
+            );
         } else {
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
+            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
         }
     }
 
@@ -96,6 +104,43 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         }
     }
 
+    public void testCreateRequest_WithDimensions() throws IOException {
+        var model = "model";
+        var input = "input";
+        var inputType = InputTypeTests.randomWithNull();
+
+        var request = createRequestWithDimensions(model, input, 10, inputType);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        if (InputType.isSpecified(inputType)) {
+            var convertedInputType = convertToString(inputType);
+            assertThat(
+                requestMap,
+                is(
+                    Map.of(
+                        "instances",
+                        List.of(Map.of("content", "input", "task_type", convertedInputType)),
+                        "parameters",
+                        Map.of("outputDimensionality", 10)
+                    )
+                )
+            );
+        } else {
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of("outputDimensionality", 10)))
+            );
+        }
+    }
+
     public void testCreateRequest_WithTaskSettingsInputTypeSet() throws IOException {
         var model = "model";
         var input = "input";
@@ -111,12 +156,15 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap, aMapWithSize(2));
         if (InputType.isSpecified(inputType)) {
             var convertedInputType = convertToString(inputType);
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
+            );
         } else {
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
+            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
         }
     }
 
@@ -136,15 +184,21 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap, aMapWithSize(2));
         if (InputType.isSpecified(requestInputType)) {
             var convertedInputType = convertToString(requestInputType);
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
+            );
         } else if (InputType.isSpecified(taskSettingsInputType)) {
             var convertedInputType = convertToString(taskSettingsInputType);
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
+            );
         } else {
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
+            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
         }
     }
 
@@ -164,13 +218,16 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap, aMapWithSize(2));
 
         if (InputType.isSpecified(inputType)) {
             var convertedInputType = convertToString(inputType);
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType)))));
+            assertThat(
+                requestMap,
+                is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType)), "parameters", Map.of()))
+            );
         } else {
-            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")))));
+            assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")), "parameters", Map.of())));
         }
     }
 
@@ -191,6 +248,40 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         );
     }
 
+    private static GoogleVertexAiEmbeddingsRequest createRequestWithDimensions(
+        String modelId,
+        String input,
+        int dimensions,
+        @Nullable InputType requestInputType
+    ) {
+
+        var embeddingsModel = new GoogleVertexAiEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new GoogleVertexAiEmbeddingsServiceSettings(
+                randomAlphaOfLength(8),
+                randomAlphaOfLength(8),
+                modelId,
+                true,
+                null,
+                dimensions,
+                null,
+                null
+            ),
+            new GoogleVertexAiEmbeddingsTaskSettings(null, null),
+            null,
+            new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray()))
+        );
+
+        return new GoogleVertexAiEmbeddingsWithoutAuthRequest(
+            TruncatorTests.createTruncator(),
+            new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
+            requestInputType,
+            embeddingsModel
+        );
+    }
+
     /**
      * We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
      */