|
@@ -8,42 +8,28 @@
|
|
|
package org.elasticsearch.xpack.inference.services.mistral.action;
|
|
|
|
|
|
import org.apache.http.HttpHeaders;
|
|
|
-import org.elasticsearch.ElasticsearchException;
|
|
|
-import org.elasticsearch.ElasticsearchStatusException;
|
|
|
-import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
-import org.elasticsearch.common.settings.Settings;
|
|
|
-import org.elasticsearch.core.TimeValue;
|
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
|
-import org.elasticsearch.rest.RestStatus;
|
|
|
-import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.test.http.MockRequest;
|
|
|
import org.elasticsearch.test.http.MockResponse;
|
|
|
-import org.elasticsearch.test.http.MockWebServer;
|
|
|
-import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
|
|
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
|
|
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
|
-import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
|
|
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
|
|
+import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
|
|
|
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
|
|
|
-import org.junit.After;
|
|
|
-import org.junit.Before;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
+import java.net.URISyntaxException;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
-import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
|
|
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
|
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
|
|
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
|
@@ -56,64 +42,15 @@ import static org.elasticsearch.xpack.inference.services.mistral.completion.Mist
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
-import static org.mockito.ArgumentMatchers.any;
|
|
|
-import static org.mockito.Mockito.doAnswer;
|
|
|
-import static org.mockito.Mockito.doThrow;
|
|
|
-import static org.mockito.Mockito.mock;
|
|
|
-
|
|
|
-public class MistralChatCompletionActionTests extends ESTestCase {
|
|
|
- private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
|
|
- private final MockWebServer webServer = new MockWebServer();
|
|
|
- private ThreadPool threadPool;
|
|
|
- private HttpClientManager clientManager;
|
|
|
-
|
|
|
- @Before
|
|
|
- public void init() throws Exception {
|
|
|
- webServer.start();
|
|
|
- threadPool = createThreadPool(inferenceUtilityPool());
|
|
|
- clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
|
|
- }
|
|
|
-
|
|
|
- @After
|
|
|
- public void shutdown() throws IOException {
|
|
|
- clientManager.close();
|
|
|
- terminate(threadPool);
|
|
|
- webServer.close();
|
|
|
- }
|
|
|
|
|
|
- public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
|
|
+public class MistralChatCompletionActionTests extends ChatCompletionActionTests {
|
|
|
+ public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISyntaxException {
|
|
|
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
|
|
|
|
|
|
try (var sender = createSender(senderFactory)) {
|
|
|
sender.start();
|
|
|
|
|
|
- String responseJson = """
|
|
|
- {
|
|
|
- "id": "9d80f26810ac4e9582f927fcf0512ec7",
|
|
|
- "object": "chat.completion",
|
|
|
- "created": 1748596419,
|
|
|
- "model": "mistral-small-latest",
|
|
|
- "choices": [
|
|
|
- {
|
|
|
- "index": 0,
|
|
|
- "message": {
|
|
|
- "role": "assistant",
|
|
|
- "tool_calls": null,
|
|
|
- "content": "result content"
|
|
|
- },
|
|
|
- "finish_reason": "length",
|
|
|
- "logprobs": null
|
|
|
- }
|
|
|
- ],
|
|
|
- "usage": {
|
|
|
- "prompt_tokens": 10,
|
|
|
- "total_tokens": 11,
|
|
|
- "completion_tokens": 1
|
|
|
- }
|
|
|
- }
|
|
|
- """;
|
|
|
-
|
|
|
- webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
|
|
|
|
|
|
var action = createAction(getUrl(webServer), sender);
|
|
|
|
|
@@ -140,87 +77,7 @@ public class MistralChatCompletionActionTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testExecute_ThrowsElasticsearchException() {
|
|
|
- var sender = mock(Sender.class);
|
|
|
- doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
|
|
-
|
|
|
- var action = createAction(getUrl(webServer), sender);
|
|
|
-
|
|
|
- PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
- action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
|
|
-
|
|
|
- var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
|
|
-
|
|
|
- assertThat(thrownException.getMessage(), is("failed"));
|
|
|
- }
|
|
|
-
|
|
|
- public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
|
|
- var sender = mock(Sender.class);
|
|
|
-
|
|
|
- doAnswer(invocation -> {
|
|
|
- ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
|
|
- listener.onFailure(new IllegalStateException("failed"));
|
|
|
-
|
|
|
- return Void.TYPE;
|
|
|
- }).when(sender).send(any(), any(), any(), any());
|
|
|
-
|
|
|
- var action = createAction(getUrl(webServer), sender);
|
|
|
-
|
|
|
- PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
- action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
|
|
-
|
|
|
- var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
|
|
-
|
|
|
- assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
|
|
|
- }
|
|
|
-
|
|
|
- public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
|
|
|
- var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
-
|
|
|
- try (var sender = createSender(senderFactory)) {
|
|
|
- sender.start();
|
|
|
-
|
|
|
- String responseJson = """
|
|
|
- {
|
|
|
- "id": "9d80f26810ac4e9582f927fcf0512ec7",
|
|
|
- "object": "chat.completion",
|
|
|
- "created": 1748596419,
|
|
|
- "model": "mistral-small-latest",
|
|
|
- "choices": [
|
|
|
- {
|
|
|
- "index": 0,
|
|
|
- "message": {
|
|
|
- "role": "assistant",
|
|
|
- "tool_calls": null,
|
|
|
- "content": "result content"
|
|
|
- },
|
|
|
- "finish_reason": "length",
|
|
|
- "logprobs": null
|
|
|
- }
|
|
|
- ],
|
|
|
- "usage": {
|
|
|
- "prompt_tokens": 10,
|
|
|
- "total_tokens": 11,
|
|
|
- "completion_tokens": 1
|
|
|
- }
|
|
|
- }
|
|
|
- """;
|
|
|
-
|
|
|
- webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
-
|
|
|
- var action = createAction(getUrl(webServer), sender);
|
|
|
-
|
|
|
- PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
- action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
|
|
-
|
|
|
- var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
|
|
-
|
|
|
- assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input"));
|
|
|
- assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- private ExecutableAction createAction(String url, Sender sender) {
|
|
|
+ protected ExecutableAction createAction(String url, Sender sender) {
|
|
|
var model = createCompletionModel("secret", "model");
|
|
|
model.setURI(url);
|
|
|
var manager = new GenericRequestManager<>(
|
|
@@ -233,4 +90,12 @@ public class MistralChatCompletionActionTests extends ESTestCase {
|
|
|
var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
|
|
|
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
|
|
|
}
|
|
|
+
|
|
|
+ protected String getFailedToSendError() {
|
|
|
+ return "Failed to send mistral chat completions request. Cause: failed";
|
|
|
+ }
|
|
|
+
|
|
|
+ protected String getOneInputError() {
|
|
|
+ return "mistral chat completions only accepts 1 input";
|
|
|
+ }
|
|
|
}
|