Ver código fonte

Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs (#125023)

* Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs

* Update docs/changelog/125023.yaml

* Fix unit tests
Dan Rubinstein 6 meses atrás
pai
commit
52bc96240c

+ 5 - 0
docs/changelog/125023.yaml

@@ -0,0 +1,5 @@
+pr: 125023
+summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInputs`
+area: Machine Learning
+type: bug
+issues: []

+ 3 - 15
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java

@@ -14,12 +14,11 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
 import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -27,7 +26,6 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.
 
 import java.util.Objects;
 
-import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
@@ -51,18 +49,8 @@ public class AlibabaCloudSearchCompletionAction implements ExecutableAction {
 
     @Override
     public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
-        if (inferenceInputs instanceof EmbeddingsInput == false) {
-            listener.onFailure(
-                new ElasticsearchStatusException(
-                    format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
-                    RestStatus.INTERNAL_SERVER_ERROR
-                )
-            );
-            return;
-        }
-
-        var docsOnlyInput = (EmbeddingsInput) inferenceInputs;
-        if (docsOnlyInput.getInputs().size() % 2 == 0) {
+        var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
+        if (completionInput.getInputs().size() % 2 == 0) {
             listener.onFailure(
                 new ElasticsearchStatusException(
                     "Alibaba Completion's inputs must be an odd number. The last input is the current query, "

+ 165 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionActionTests.java

@@ -0,0 +1,165 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws IOException {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_Success() {
+        var sender = mock(Sender.class);
+
+        var resultString = randomAlphaOfLength(100);
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+        var action = createAction(threadPool, sender);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var result = listener.actionGet(TIMEOUT);
+        assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
+    }
+
+    public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
+        var sender = mock(Sender.class);
+        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createAction(threadPool, sender);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("error"));
+    }
+
+    public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
+        var sender = mock(Sender.class);
+        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createAction(threadPool, sender);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
+    }
+
+    public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
+        var action = createAction(threadPool, mock(Sender.class));
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        assertThrows(IllegalArgumentException.class, () -> {
+            action.execute(
+                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+        });
+    }
+
+    public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
+        var action = createAction(threadPool, mock(Sender.class));
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(
+            thrownException.getMessage(),
+            is(
+                "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
+                    + "all preceding inputs are the completion history as pairs of user input and the assistant's response."
+            )
+        );
+        assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
+    }
+
+    private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
+        var model = AlibabaCloudSearchCompletionModelTests.createModel(
+            "completion_test",
+            TaskType.COMPLETION,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
+            AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
+            getSecretSettingsMap("secret")
+        );
+
+        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
+        return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
+    }
+}