|
@@ -9,16 +9,21 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.request;
|
|
|
|
|
|
import org.apache.http.HttpHeaders;
|
|
|
import org.apache.http.client.methods.HttpPost;
|
|
|
+import org.elasticsearch.common.settings.SecureString;
|
|
|
import org.elasticsearch.core.Nullable;
|
|
|
import org.elasticsearch.inference.InputType;
|
|
|
+import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.inference.InputTypeTests;
|
|
|
import org.elasticsearch.xpack.inference.common.Truncator;
|
|
|
import org.elasticsearch.xpack.inference.common.TruncatorTests;
|
|
|
import org.elasticsearch.xpack.inference.external.request.Request;
|
|
|
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
|
|
|
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
|
|
|
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModelTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
|
|
|
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.List;
|
|
@@ -49,12 +54,15 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
|
|
|
|
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
|
|
- assertThat(requestMap, aMapWithSize(1));
|
|
|
+ assertThat(requestMap, aMapWithSize(2));
|
|
|
if (InputType.isSpecified(inputType)) {
|
|
|
var convertedInputType = convertToString(inputType);
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
|
|
|
+ );
|
|
|
} else {
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
|
|
|
+ assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -96,6 +104,43 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testCreateRequest_WithDimensions() throws IOException {
|
|
|
+ var model = "model";
|
|
|
+ var input = "input";
|
|
|
+ var inputType = InputTypeTests.randomWithNull();
|
|
|
+
|
|
|
+ var request = createRequestWithDimensions(model, input, 10, inputType);
|
|
|
+ var httpRequest = request.createHttpRequest();
|
|
|
+
|
|
|
+ assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
+ var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
|
|
+
|
|
|
+ assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
|
|
|
+ assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
|
|
+
|
|
|
+ var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
|
|
+ assertThat(requestMap, aMapWithSize(2));
|
|
|
+ if (InputType.isSpecified(inputType)) {
|
|
|
+ var convertedInputType = convertToString(inputType);
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(
|
|
|
+ Map.of(
|
|
|
+ "instances",
|
|
|
+ List.of(Map.of("content", "input", "task_type", convertedInputType)),
|
|
|
+ "parameters",
|
|
|
+ Map.of("outputDimensionality", 10)
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ } else {
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of("outputDimensionality", 10)))
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testCreateRequest_WithTaskSettingsInputTypeSet() throws IOException {
|
|
|
var model = "model";
|
|
|
var input = "input";
|
|
@@ -111,12 +156,15 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
|
|
|
|
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
|
|
- assertThat(requestMap, aMapWithSize(1));
|
|
|
+ assertThat(requestMap, aMapWithSize(2));
|
|
|
if (InputType.isSpecified(inputType)) {
|
|
|
var convertedInputType = convertToString(inputType);
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
|
|
|
+ );
|
|
|
} else {
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
|
|
|
+ assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -136,15 +184,21 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
|
|
|
|
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
|
|
- assertThat(requestMap, aMapWithSize(1));
|
|
|
+ assertThat(requestMap, aMapWithSize(2));
|
|
|
if (InputType.isSpecified(requestInputType)) {
|
|
|
var convertedInputType = convertToString(requestInputType);
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
|
|
|
+ );
|
|
|
} else if (InputType.isSpecified(taskSettingsInputType)) {
|
|
|
var convertedInputType = convertToString(taskSettingsInputType);
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)))));
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "input", "task_type", convertedInputType)), "parameters", Map.of()))
|
|
|
+ );
|
|
|
} else {
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")))));
|
|
|
+ assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of())));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -164,13 +218,16 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
|
|
|
|
|
|
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
|
|
- assertThat(requestMap, aMapWithSize(1));
|
|
|
+ assertThat(requestMap, aMapWithSize(2));
|
|
|
|
|
|
if (InputType.isSpecified(inputType)) {
|
|
|
var convertedInputType = convertToString(inputType);
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType)))));
|
|
|
+ assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("instances", List.of(Map.of("content", "ab", "task_type", convertedInputType)), "parameters", Map.of()))
|
|
|
+ );
|
|
|
} else {
|
|
|
- assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")))));
|
|
|
+ assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")), "parameters", Map.of())));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -191,6 +248,40 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ private static GoogleVertexAiEmbeddingsRequest createRequestWithDimensions(
|
|
|
+ String modelId,
|
|
|
+ String input,
|
|
|
+ int dimensions,
|
|
|
+ @Nullable InputType requestInputType
|
|
|
+ ) {
|
|
|
+
|
|
|
+ var embeddingsModel = new GoogleVertexAiEmbeddingsModel(
|
|
|
+ "id",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ "service",
|
|
|
+ new GoogleVertexAiEmbeddingsServiceSettings(
|
|
|
+ randomAlphaOfLength(8),
|
|
|
+ randomAlphaOfLength(8),
|
|
|
+ modelId,
|
|
|
+ true,
|
|
|
+ null,
|
|
|
+ dimensions,
|
|
|
+ null,
|
|
|
+ null
|
|
|
+ ),
|
|
|
+ new GoogleVertexAiEmbeddingsTaskSettings(null, null),
|
|
|
+ null,
|
|
|
+ new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray()))
|
|
|
+ );
|
|
|
+
|
|
|
+ return new GoogleVertexAiEmbeddingsWithoutAuthRequest(
|
|
|
+ TruncatorTests.createTruncator(),
|
|
|
+ new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
|
|
|
+ requestInputType,
|
|
|
+ embeddingsModel
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* We use this class to fake the auth implementation to avoid static mocking of {@link GoogleVertexAiRequest}
|
|
|
*/
|