|
@@ -22,6 +22,7 @@ import org.elasticsearch.test.http.MockWebServer;
|
|
|
import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
@@ -45,6 +46,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
|
|
|
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
|
|
+import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
|
|
|
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
|
|
|
import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
@@ -54,6 +56,11 @@ import static org.mockito.Mockito.mock;
|
|
|
|
|
|
public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
|
|
// Shared retry settings for the tests that expect a failing request not to be retried; the call
// sites describe it as "timeout as zero for no retries".
// NOTE(review): presumably the third argument (zero seconds) is the retry timeout — confirm
// against buildSettingsWithRetryFields.
+ private static final Settings ZERO_TIMEOUT_SETTINGS = buildSettingsWithRetryFields(
|
|
|
+ TimeValue.timeValueMillis(1),
|
|
|
+ TimeValue.timeValueMinutes(1),
|
|
|
+ TimeValue.timeValueSeconds(0)
|
|
|
+ );
|
|
|
private final MockWebServer webServer = new MockWebServer();
|
|
|
private ThreadPool threadPool;
|
|
|
private HttpClientManager clientManager;
|
|
@@ -116,7 +123,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
@@ -166,7 +173,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("abc"), null);
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), null);
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
@@ -174,12 +181,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
|
|
|
public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException {
|
|
|
// timeout as zero for no retries
|
|
|
- var settings = buildSettingsWithRetryFields(
|
|
|
- TimeValue.timeValueMillis(1),
|
|
|
- TimeValue.timeValueMinutes(1),
|
|
|
- TimeValue.timeValueSeconds(0)
|
|
|
- );
|
|
|
- var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
|
|
|
|
|
|
try (var sender = senderFactory.createSender("test_service")) {
|
|
|
sender.start();
|
|
@@ -226,7 +228,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
@@ -295,13 +297,13 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
|
|
|
}
|
|
|
{
|
|
|
validateRequestWithApiKey(webServer.requests().get(1), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(1).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
|
|
|
}
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
@@ -371,13 +373,13 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
|
|
|
}
|
|
|
{
|
|
|
validateRequestWithApiKey(webServer.requests().get(1), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(1).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
|
|
|
}
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
@@ -429,13 +431,186 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
validateRequestWithApiKey(webServer.requests().get(0), "apikey");
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- validateRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
|
|
|
+ validateEmbeddingsRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
|
|
|
} catch (URISyntaxException e) {
|
|
|
throw new RuntimeException(e);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private void validateRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
|
|
|
+ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var sender = senderFactory.createSender("test_service")) {
|
|
|
+ sender.start();
|
|
|
+
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "choices": [
|
|
|
+ {
|
|
|
+ "finish_reason": "stop",
|
|
|
+ "index": 0,
|
|
|
+ "logprobs": null,
|
|
|
+ "message": {
|
|
|
+ "content": "response",
|
|
|
+ "role": "assistant"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "model": "gpt-4",
|
|
|
+ "object": "chat.completion"
|
|
|
+ }""";
|
|
|
+
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ var originalUser = "original_user";
|
|
|
+ var overriddenUser = "overridden_user";
|
|
|
+ var apiKey = "api_key";
|
|
|
+ var completionInput = "some input";
|
|
|
+
|
|
|
+ var model = createCompletionModel("resource", "deployment", "apiversion", originalUser, apiKey, null, "id");
|
|
|
+ model.setUri(new URI(getUrl(webServer)));
|
|
|
+ var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
|
|
|
+ var taskSettingsWithUserOverride = createRequestTaskSettingsMap(overriddenUser);
|
|
|
+ var action = (AzureOpenAiCompletionAction) actionCreator.create(model, taskSettingsWithUserOverride);
|
|
|
+
|
|
|
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
+ action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
|
|
+
|
|
|
+ var result = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
+
|
|
|
+ var request = webServer.requests().get(0);
|
|
|
+ var requestMap = entityAsMap(request.getBody());
|
|
|
+
|
|
|
+ assertThat(
|
|
|
+ result.asMap(),
|
|
|
+ is(Map.of(ChatCompletionResults.COMPLETION, List.of(Map.of(ChatCompletionResults.Result.RESULT, "response"))))
|
|
|
+ );
|
|
|
+ validateRequestWithApiKey(request, apiKey);
|
|
|
+ validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), overriddenUser);
|
|
|
+
|
|
|
+ } catch (URISyntaxException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var sender = senderFactory.createSender("test_service")) {
|
|
|
+ sender.start();
|
|
|
+
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "choices": [
|
|
|
+ {
|
|
|
+ "finish_reason": "stop",
|
|
|
+ "index": 0,
|
|
|
+ "logprobs": null,
|
|
|
+ "message": {
|
|
|
+ "content": "response",
|
|
|
+ "role": "assistant"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "model": "gpt-4",
|
|
|
+ "object": "chat.completion"
|
|
|
+ }""";
|
|
|
+
|
|
|
+ var completionInput = "some input";
|
|
|
+ var apiKey = "api key";
|
|
|
+
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ var model = createCompletionModel("resource", "deployment", "apiversion", null, apiKey, null, "id");
|
|
|
+ model.setUri(new URI(getUrl(webServer)));
|
|
|
+ var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
|
|
|
+ var requestTaskSettingsWithoutUser = createRequestTaskSettingsMap(null);
|
|
|
+ var action = (AzureOpenAiCompletionAction) actionCreator.create(model, requestTaskSettingsWithoutUser);
|
|
|
+
|
|
|
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
+ action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
|
|
+
|
|
|
+ var result = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
+
|
|
|
+ var request = webServer.requests().get(0);
|
|
|
+ var requestMap = entityAsMap(request.getBody());
|
|
|
+
|
|
|
+ assertThat(
|
|
|
+ result.asMap(),
|
|
|
+ is(Map.of(ChatCompletionResults.COMPLETION, List.of(Map.of(ChatCompletionResults.Result.RESULT, "response"))))
|
|
|
+ );
|
|
|
+ validateRequestWithApiKey(request, apiKey);
|
|
|
+ validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), null);
|
|
|
+ } catch (URISyntaxException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/**
 * Executes a completion action against a mock response that lacks the required {@code choices}
 * field and expects an {@link ElasticsearchStatusException} whose cause names the missing field.
 * Also checks that exactly one request was sent and that it carried the API key and the user
 * supplied through the request task settings.
 */
public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat() throws IOException {
    // timeout as zero for no retries
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);

    try (var sender = senderFactory.createSender("test_service")) {
        sender.start();

        // "choices" missing
        String responseJson = """
            {
                "not_choices": [
                    {
                        "finish_reason": "stop",
                        "index": 0,
                        "logprobs": null,
                        "message": {
                            "content": "response",
                            "role": "assistant"
                        }
                    }
                ],
                "model": "gpt-4",
                "object": "chat.completion"
            }""";

        var completionInput = "some input";
        var apiKey = "api key";
        var userOverride = "overridden_user";

        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));

        // The model itself has no user; the user only arrives through the request task settings.
        var model = createCompletionModel("resource", "deployment", "apiversion", null, apiKey, null, "id");
        model.setUri(new URI(getUrl(webServer)));
        var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
        var taskSettingsWithUserOverride = createRequestTaskSettingsMap(userOverride);
        var action = (AzureOpenAiCompletionAction) actionCreator.create(model, taskSettingsWithUserOverride);

        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
        action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
        assertThat(
            thrownException.getMessage(),
            is(format("Failed to send Azure OpenAI completion request to [%s]", getUrl(webServer)))
        );
        assertThat(
            thrownException.getCause().getMessage(),
            is("Failed to find required field [choices] in Azure OpenAI completions response")
        );

        // The request was still sent once before response parsing failed.
        assertThat(webServer.requests(), hasSize(1));
        validateRequestWithApiKey(webServer.requests().get(0), apiKey);

        var requestMap = entityAsMap(webServer.requests().get(0).getBody());
        validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), userOverride);
    } catch (URISyntaxException e) {
        throw new RuntimeException(e);
    }
}
|
|
|
+
|
|
|
+ private void validateEmbeddingsRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
|
|
|
var expectedSize = user == null ? 1 : 2;
|
|
|
|
|
|
assertThat(requestMap.size(), is(expectedSize));
|
|
@@ -446,6 +621,24 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void validateCompletionRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
|
|
|
+ assertThat("input for completions can only be of size 1", input.size(), equalTo(1));
|
|
|
+
|
|
|
+ var expectedSize = user == null ? 2 : 3;
|
|
|
+
|
|
|
+ assertThat(requestMap.size(), is(expectedSize));
|
|
|
+ assertThat(getContentOfMessageInRequestMap(requestMap), is(input.get(0)));
|
|
|
+
|
|
|
+ if (user != null) {
|
|
|
+ assertThat(requestMap.get("user"), is(user));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public static String getContentOfMessageInRequestMap(Map<String, Object> requestMap) {
|
|
|
+ return ((Map<String, Object>) ((List<Object>) requestMap.get("messages")).get(0)).get("content").toString();
|
|
|
+ }
|
|
|
+
|
|
|
private void validateRequestWithApiKey(MockRequest request, String apiKey) {
|
|
|
assertNull(request.getUri().getQuery());
|
|
|
assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|