瀏覽代碼

[ML] Migrate Alibaba Senders to SenderExecutableAction (#134515)

- Remove AlibabaCloudSearch*Action classes in favour of using
  SenderExecutableAction in AlibabaCloudSearchActionCreator
- Consolidate redundant GoogleDiscoveryEngineRateLimitServiceSettings
  and GoogleVertexAiEmbeddingsRateLimitServiceSettings interfaces with
  GoogleVertexAiRateLimitServiceSettings
- Rename fields on CustomRequestManager.RateLimitGrouping and
  AlibabaCloudSearchRequestManager.RateLimitGrouping to better reflect
  what they actually are
- Add tests for AlibabaCloudSearchActionCreator
- Migrate some existing tests to
  AlibabaCloudSearchCompletionRequestManagerTests
Donal Evans 1 月之前
父節點
當前提交
abd8b324ee
共有 22 個文件被更改,包括 522 次插入和 449 次删除
  1. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SenderExecutableAction.java
  2. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java
  3. 12 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManager.java
  4. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRequestManager.java
  5. 24 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreator.java
  6. 0 72
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionAction.java
  7. 0 53
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchEmbeddingsAction.java
  8. 0 57
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchRerankAction.java
  9. 0 57
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchSparseAction.java
  10. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java
  11. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiRateLimitServiceSettings.java
  12. 3 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java
  13. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java
  14. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java
  15. 0 15
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRateLimitServiceSettings.java
  16. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java
  17. 0 14
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleDiscoveryEngineRateLimitServiceSettings.java
  18. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java
  19. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankServiceSettings.java
  20. 105 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java
  21. 361 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
  22. 0 162
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SenderExecutableAction.java

@@ -18,7 +18,7 @@ import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
 
-public class SenderExecutableAction implements ExecutableAction {
+public sealed class SenderExecutableAction implements ExecutableAction permits SingleInputSenderExecutableAction {
 
     private final Sender sender;
     private final RequestManager requestManager;

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java

@@ -18,7 +18,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 
 import java.util.Objects;
 
-public class SingleInputSenderExecutableAction extends SenderExecutableAction {
+public final class SingleInputSenderExecutableAction extends SenderExecutableAction {
     private final String requestTypeForInputValidationError;
 
     public SingleInputSenderExecutableAction(

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManager.java

@@ -9,8 +9,10 @@ package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
@@ -71,6 +73,16 @@ public class AlibabaCloudSearchCompletionRequestManager extends AlibabaCloudSear
         ActionListener<InferenceServiceResults> listener
     ) {
         List<String> input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs();
+        if (input.size() % 2 == 0) {
+            listener.onFailure(
+                new ElasticsearchStatusException(
+                    "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
+                        + "all preceding inputs are the completion history as pairs of user input and the assistant's response.",
+                    RestStatus.BAD_REQUEST
+                )
+            );
+            return;
+        }
         AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model);
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchRequestManager.java

@@ -18,7 +18,7 @@ abstract class AlibabaCloudSearchRequestManager extends BaseRequestManager {
         super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
     }
 
-    record RateLimitGrouping(int apiKeyHash) {
+    record RateLimitGrouping(int serviceSettingsHash) {
         public static RateLimitGrouping of(AlibabaCloudSearchModel model) {
             Objects.requireNonNull(model);
 

+ 24 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreator.java

@@ -8,8 +8,14 @@
 package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
 
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchAccount;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchCompletionRequestManager;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchEmbeddingsRequestManager;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchRerankRequestManager;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchSparseRequestManager;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
@@ -18,6 +24,8 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.Alib
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+
 /**
  * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the alibaba cloud search model type.
  */
@@ -33,28 +41,40 @@ public class AlibabaCloudSearchActionCreator implements AlibabaCloudSearchAction
     @Override
     public ExecutableAction create(AlibabaCloudSearchEmbeddingsModel model, Map<String, Object> taskSettings) {
         var overriddenModel = AlibabaCloudSearchEmbeddingsModel.of(model, taskSettings);
+        var account = new AlibabaCloudSearchAccount(overriddenModel.getSecretSettings().apiKey());
+        var requestManager = AlibabaCloudSearchEmbeddingsRequestManager.of(account, overriddenModel, serviceComponents.threadPool());
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search text embeddings");
 
-        return new AlibabaCloudSearchEmbeddingsAction(sender, overriddenModel, serviceComponents);
+        return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
     }
 
     @Override
     public ExecutableAction create(AlibabaCloudSearchSparseModel model, Map<String, Object> taskSettings) {
         var overriddenModel = AlibabaCloudSearchSparseModel.of(model, taskSettings);
+        var account = new AlibabaCloudSearchAccount(overriddenModel.getSecretSettings().apiKey());
+        var requestManager = AlibabaCloudSearchSparseRequestManager.of(account, overriddenModel, serviceComponents.threadPool());
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search sparse embeddings");
 
-        return new AlibabaCloudSearchSparseAction(sender, overriddenModel, serviceComponents);
+        return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
     }
 
     @Override
     public ExecutableAction create(AlibabaCloudSearchRerankModel model, Map<String, Object> taskSettings) {
         var overriddenModel = AlibabaCloudSearchRerankModel.of(model, taskSettings);
+        var account = new AlibabaCloudSearchAccount(overriddenModel.getSecretSettings().apiKey());
+        var requestManager = AlibabaCloudSearchRerankRequestManager.of(account, overriddenModel, serviceComponents.threadPool());
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search rerank");
 
-        return new AlibabaCloudSearchRerankAction(sender, overriddenModel, serviceComponents);
+        return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
     }
 
     @Override
     public ExecutableAction create(AlibabaCloudSearchCompletionModel model, Map<String, Object> taskSettings) {
         var overriddenModel = AlibabaCloudSearchCompletionModel.of(model, taskSettings);
+        var account = new AlibabaCloudSearchAccount(overriddenModel.getSecretSettings().apiKey());
+        var requestManager = AlibabaCloudSearchCompletionRequestManager.of(account, overriddenModel, serviceComponents.threadPool());
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search completion");
 
-        return new AlibabaCloudSearchCompletionAction(sender, overriddenModel, serviceComponents);
+        return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
     }
 }

+ 0 - 72
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionAction.java

@@ -1,72 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
-import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
-import org.elasticsearch.xpack.inference.external.http.sender.Sender;
-import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchAccount;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchCompletionRequestManager;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
-
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
-
-public class AlibabaCloudSearchCompletionAction implements ExecutableAction {
-    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchCompletionAction.class);
-
-    private final AlibabaCloudSearchAccount account;
-    private final AlibabaCloudSearchCompletionModel model;
-    private final String failedToSendRequestErrorMessage;
-    private final Sender sender;
-    private final AlibabaCloudSearchCompletionRequestManager requestCreator;
-
-    public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompletionModel model, ServiceComponents serviceComponents) {
-        this.model = Objects.requireNonNull(model);
-        this.sender = Objects.requireNonNull(sender);
-        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
-        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search completion");
-        this.requestCreator = AlibabaCloudSearchCompletionRequestManager.of(account, model, serviceComponents.threadPool());
-    }
-
-    @Override
-    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
-        var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
-        if (completionInput.getInputs().size() % 2 == 0) {
-            listener.onFailure(
-                new ElasticsearchStatusException(
-                    "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
-                        + "all preceding inputs are the completion history as pairs of user input and the assistant's response.",
-                    RestStatus.BAD_REQUEST
-                )
-            );
-            return;
-        }
-
-        ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
-            failedToSendRequestErrorMessage,
-            listener
-        );
-        try {
-            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
-        } catch (Exception e) {
-            wrappedListener.onFailure(e);
-        }
-    }
-}

+ 0 - 53
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchEmbeddingsAction.java

@@ -1,53 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
-
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
-import org.elasticsearch.xpack.inference.external.http.sender.Sender;
-import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchAccount;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchEmbeddingsRequestManager;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
-
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
-
-public class AlibabaCloudSearchEmbeddingsAction implements ExecutableAction {
-    private final AlibabaCloudSearchAccount account;
-    private final AlibabaCloudSearchEmbeddingsModel model;
-    private final String failedToSendRequestErrorMessage;
-    private final Sender sender;
-    private final AlibabaCloudSearchEmbeddingsRequestManager requestCreator;
-
-    public AlibabaCloudSearchEmbeddingsAction(Sender sender, AlibabaCloudSearchEmbeddingsModel model, ServiceComponents serviceComponents) {
-        this.model = Objects.requireNonNull(model);
-        this.sender = Objects.requireNonNull(sender);
-        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
-        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search text embeddings");
-        this.requestCreator = AlibabaCloudSearchEmbeddingsRequestManager.of(account, model, serviceComponents.threadPool());
-    }
-
-    @Override
-    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
-        ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
-            failedToSendRequestErrorMessage,
-            listener
-        );
-        try {
-            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
-        } catch (Exception e) {
-            wrappedListener.onFailure(e);
-        }
-    }
-}

+ 0 - 57
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchRerankAction.java

@@ -1,57 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
-import org.elasticsearch.xpack.inference.external.http.sender.Sender;
-import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchAccount;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchRerankRequestManager;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
-
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
-
-public class AlibabaCloudSearchRerankAction implements ExecutableAction {
-    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchRerankAction.class);
-
-    private final AlibabaCloudSearchAccount account;
-    private final AlibabaCloudSearchRerankModel model;
-    private final String failedToSendRequestErrorMessage;
-    private final Sender sender;
-    private final AlibabaCloudSearchRerankRequestManager requestCreator;
-
-    public AlibabaCloudSearchRerankAction(Sender sender, AlibabaCloudSearchRerankModel model, ServiceComponents serviceComponents) {
-        this.model = Objects.requireNonNull(model);
-        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
-        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search rerank");
-        this.sender = Objects.requireNonNull(sender);
-        this.requestCreator = AlibabaCloudSearchRerankRequestManager.of(account, model, serviceComponents.threadPool());
-    }
-
-    @Override
-    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
-        ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
-            failedToSendRequestErrorMessage,
-            listener
-        );
-        try {
-            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
-        } catch (Exception e) {
-            wrappedListener.onFailure(e);
-        }
-    }
-}

+ 0 - 57
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchSparseAction.java

@@ -1,57 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
-import org.elasticsearch.xpack.inference.external.http.sender.Sender;
-import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchAccount;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchSparseRequestManager;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
-
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
-import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
-
-public class AlibabaCloudSearchSparseAction implements ExecutableAction {
-    private static final Logger logger = LogManager.getLogger(AlibabaCloudSearchSparseAction.class);
-
-    private final AlibabaCloudSearchAccount account;
-    private final AlibabaCloudSearchSparseModel model;
-    private final String failedToSendRequestErrorMessage;
-    private final Sender sender;
-    private final AlibabaCloudSearchSparseRequestManager requestCreator;
-
-    public AlibabaCloudSearchSparseAction(Sender sender, AlibabaCloudSearchSparseModel model, ServiceComponents serviceComponents) {
-        this.model = Objects.requireNonNull(model);
-        this.account = new AlibabaCloudSearchAccount(this.model.getSecretSettings().apiKey());
-        this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("AlibabaCloud Search sparse embeddings");
-        this.sender = Objects.requireNonNull(sender);
-        requestCreator = AlibabaCloudSearchSparseRequestManager.of(account, model, serviceComponents.threadPool());
-    }
-
-    @Override
-    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
-        ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
-            failedToSendRequestErrorMessage,
-            listener
-        );
-        try {
-            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
-        } catch (Exception e) {
-            wrappedListener.onFailure(e);
-        }
-    }
-}

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java

@@ -36,7 +36,7 @@ import java.util.function.Supplier;
 public class CustomRequestManager extends BaseRequestManager {
     private static final Logger logger = LogManager.getLogger(CustomRequestManager.class);
 
-    record RateLimitGrouping(int apiKeyHash) {
+    record RateLimitGrouping(int serviceSettingsHash) {
         public static RateLimitGrouping of(CustomModel model) {
             Objects.requireNonNull(model);
 

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiRateLimitServiceSettings.java

@@ -13,4 +13,5 @@ public interface GoogleVertexAiRateLimitServiceSettings {
 
     RateLimitSettings rateLimitSettings();
 
+    String projectId();
 }

+ 3 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java

@@ -16,10 +16,10 @@ import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
 import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
-import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings;
 
 import java.net.URI;
 import java.net.URISyntaxException;
@@ -126,8 +126,8 @@ public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel {
     }
 
     @Override
-    public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() {
-        return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings();
+    public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
+        return super.rateLimitServiceSettings();
     }
 
     @Override

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionServiceSettings.java

@@ -20,8 +20,8 @@ import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
-import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleDiscoveryEngineRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
@@ -36,7 +36,7 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVe
 public class GoogleVertexAiChatCompletionServiceSettings extends FilteredXContentObject
     implements
         ServiceSettings,
-        GoogleDiscoveryEngineRateLimitServiceSettings {
+        GoogleVertexAiRateLimitServiceSettings {
 
     public static final String NAME = "google_vertex_ai_chatcompletion_service_settings";
 

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

@@ -16,6 +16,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
 import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
@@ -126,8 +127,8 @@ public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
     }
 
     @Override
-    public GoogleVertexAiEmbeddingsRateLimitServiceSettings rateLimitServiceSettings() {
-        return (GoogleVertexAiEmbeddingsRateLimitServiceSettings) super.rateLimitServiceSettings();
+    public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
+        return super.rateLimitServiceSettings();
     }
 
     @Override

+ 0 - 15
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRateLimitServiceSettings.java

@@ -1,15 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
-
-import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
-
-public interface GoogleVertexAiEmbeddingsRateLimitServiceSettings extends GoogleVertexAiRateLimitServiceSettings {
-
-    String projectId();
-}

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsServiceSettings.java

@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -42,7 +43,7 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVe
 public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObject
     implements
         ServiceSettings,
-        GoogleVertexAiEmbeddingsRateLimitServiceSettings {
+        GoogleVertexAiRateLimitServiceSettings {
 
     public static final String NAME = "google_vertex_ai_embeddings_service_settings";
 

+ 0 - 14
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleDiscoveryEngineRateLimitServiceSettings.java

@@ -1,14 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.googlevertexai.rerank;
-
-import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
-
-public interface GoogleDiscoveryEngineRateLimitServiceSettings extends GoogleVertexAiRateLimitServiceSettings {
-    String projectId();
-}

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java

@@ -15,6 +15,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.action.GoogleVertexAiActionVisitor;
 import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUtils;
@@ -111,8 +112,8 @@ public class GoogleVertexAiRerankModel extends GoogleVertexAiModel {
     }
 
     @Override
-    public GoogleDiscoveryEngineRateLimitServiceSettings rateLimitServiceSettings() {
-        return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings();
+    public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
+        return super.rateLimitServiceSettings();
     }
 
     @Override

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankServiceSettings.java

@@ -17,6 +17,7 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRateLimitServiceSettings;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService;
 import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -33,7 +34,7 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVe
 public class GoogleVertexAiRerankServiceSettings extends FilteredXContentObject
     implements
         ServiceSettings,
-        GoogleDiscoveryEngineRateLimitServiceSettings {
+        GoogleVertexAiRateLimitServiceSettings {
 
     public static final String NAME = "google_vertex_ai_rerank_service_settings";
 

+ 105 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.completion.AlibabaCloudSearchCompletionRequest;
+import org.mockito.ArgumentCaptor;
+
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
+import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+/**
+ * Unit tests for {@link AlibabaCloudSearchCompletionRequestManager#execute}. The request manager under test is built by
+ * {@code createRequestManagerWithMockExecutor} on top of a mocked {@link ThreadPool} that returns a caller-supplied executor.
+ */
+public class AlibabaCloudSearchCompletionRequestManagerTests extends ESTestCase {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); // upper bound when waiting on a listener
+    /**
+     * Executes the request manager with three chat completion inputs and checks that an {@link ExecutableInferenceRequest} is
+     * handed to the executor, that it wraps an {@link AlibabaCloudSearchCompletionRequest}, and that its response handler
+     * reports the request type "alibaba cloud search completion".
+     */
+    public void testExecute_executesRequest() {
+        var inputs = new ChatCompletionInput(List.of("input1", "input2", "input3"));
+        RequestSender mockSender = mock(RequestSender.class);
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        // The executor is a mock so that the task submitted by execute() can be captured below.
+        ExecutorService mockExecutorService = mock(ExecutorService.class);
+        var requestManager = createRequestManagerWithMockExecutor(mockExecutorService);
+        requestManager.execute(inputs, mockSender, () -> false, listener);
+        // Capture whatever execute() passed to the executor and inspect it.
+        ArgumentCaptor<ExecutableInferenceRequest> captor = ArgumentCaptor.forClass(ExecutableInferenceRequest.class);
+        verify(mockExecutorService).execute(captor.capture());
+
+        ExecutableInferenceRequest executableRequest = captor.getValue();
+        assertThat(executableRequest.request(), is(instanceOf(AlibabaCloudSearchCompletionRequest.class)));
+        assertThat(executableRequest.responseHandler().getRequestType(), is("alibaba cloud search completion"));
+    }
+    /**
+     * Executes the request manager with two inputs (an even count) and checks that the listener fails with an
+     * {@link ElasticsearchStatusException} carrying status {@code BAD_REQUEST} and a message mentioning that the inputs must be
+     * an odd number.
+     */
+    public void testExecute_throwsElasticsearchStatusException_whenNumberOfInputsIsEven() {
+        var inputs = new ChatCompletionInput(List.of("input1", "input2"));
+        RequestSender mockSender = mock(RequestSender.class);
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+        var requestManager = createRequestManagerWithMockExecutor(mock(ExecutorService.class));
+        requestManager.execute(inputs, mockSender, () -> false, listener);
+        // execute() is expected to return normally; the failure only surfaces when the listener's result is read.
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.status(), is(BAD_REQUEST));
+        assertThat(thrownException.getMessage(), containsString("Alibaba Completion's inputs must be an odd number"));
+    }
+    /**
+     * Passes an {@link EmbeddingsInput} instead of a {@link ChatCompletionInput} and checks that the call to {@code execute}
+     * itself throws an {@link IllegalArgumentException} whose message mentions the failed input conversion.
+     */
+    public void testExecute_throwsIllegalArgumentException_whenInputIsNotChatCompletion() {
+        var inputs = new EmbeddingsInput(List.of(new ChunkInferenceInput("input1")), InputType.SEARCH);
+        RequestSender mockSender = mock(RequestSender.class);
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+        var requestManager = createRequestManagerWithMockExecutor(mock(ExecutorService.class));
+        var thrownException = expectThrows(
+            IllegalArgumentException.class,
+            () -> requestManager.execute(inputs, mockSender, () -> false, listener)
+        );
+
+        assertThat(thrownException.getMessage(), containsString("Unable to convert inference inputs type"));
+    }
+    /**
+     * Builds the request manager under test for a completion model created from test-helper settings maps.
+     *
+     * @param mockExecutorService executor that the mocked {@link ThreadPool} returns for any executor name
+     * @return a request manager created via {@code AlibabaCloudSearchCompletionRequestManager.of} with the model's API key account
+     */
+    private AlibabaCloudSearchCompletionRequestManager createRequestManagerWithMockExecutor(ExecutorService mockExecutorService) {
+        ThreadPool mockThreadPool = mock(ThreadPool.class);
+        when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService);
+
+        var model = AlibabaCloudSearchCompletionModelTests.createModel(
+            "completion_test",
+            TaskType.COMPLETION,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
+            AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
+            getSecretSettingsMap("secret")
+        );
+        var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey());
+        return AlibabaCloudSearchCompletionRequestManager.of(account, model, mockThreadPool);
+    }
+}

+ 361 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java

@@ -0,0 +1,361 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.WeightedToken;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchUtils;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel;
+import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
+import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank;
+import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings;
+import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+/**
+ * Tests the {@link ExecutableAction}s that {@link AlibabaCloudSearchActionCreator} creates for text embedding, sparse embedding,
+ * rerank and completion models. The {@link Sender} is a Mockito mock, so every test decides what {@code send} does: answer the
+ * listener with a canned result, or throw.
+ */
+public class AlibabaCloudSearchActionCreatorTests extends ESTestCase {
+    /*
+     * NOTE(review): within this class, webServer and clientManager are only started/created in init() and closed in
+     * shutdown(); no test method reads them. Confirm whether they are still needed.
+     */
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); // upper bound when waiting on a listener
+    private final MockWebServer webServer = new MockWebServer();
+    private Sender sender;
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+    /**
+     * Creates the mocked sender, starts the mock web server, and builds the thread pool and HTTP client manager.
+     */
+    @Before
+    public void init() throws IOException {
+        sender = mock(Sender.class);
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityExecutors());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+    /**
+     * Closes the HTTP client manager, terminates the thread pool and closes the mock web server.
+     */
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+    /**
+     * The mocked sender answers the listener (argument index 3 of {@code send}) with one float embedding; the result read from
+     * the caller's listener must match the expectation built from the same values.
+     */
+    public void testExecute_withTextEmbeddingsAction_Success() {
+        float[] values = { 0.1111111f, 0.2222222f, 0.3333333f };
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values))));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+        var action = createTextEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var result = listener.actionGet(TIMEOUT);
+        assertThat(result.asMap(), is(buildExpectationFloat(List.of(values))));
+    }
+    /**
+     * When {@code send} throws an {@link ElasticsearchException}, the listener must fail with an {@link ElasticsearchException}
+     * whose message is unchanged ("error").
+     */
+    public void testExecute_withTextEmbeddingsAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
+        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createTextEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("error"));
+    }
+    /**
+     * When {@code send} throws a plain {@link RuntimeException}, the listener must fail with an
+     * {@link ElasticsearchStatusException} whose message names the text embeddings request and the cause. Only the exception
+     * type and message are asserted; the status suggested by the test name is not checked here.
+     */
+    public void testExecute_withTextEmbeddingsAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
+        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createTextEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search text embeddings request. Cause: error"));
+    }
+    /**
+     * The mocked sender answers the listener (argument index 3 of {@code send}) with one sparse embedding holding a single
+     * weighted token; the result read from the caller's listener must match the expectation built from the same token, weight
+     * and truncation flag.
+     */
+    public void testExecute_withSparseEmbeddingsAction_Success() {
+        String token = "token";
+        float weight = 0.1111111f;
+        boolean isTruncated = false;
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onResponse(
+                new SparseEmbeddingResults(
+                    List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(token, weight)), isTruncated))
+                )
+            );
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+        var action = createSparseEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var result = listener.actionGet(TIMEOUT);
+        assertThat(
+            result.asMap(),
+            is(
+                buildExpectationSparseEmbeddings(
+                    List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(token, weight), isTruncated))
+                )
+            )
+        );
+    }
+    /**
+     * When {@code send} throws an {@link ElasticsearchException}, the listener must fail with an {@link ElasticsearchException}
+     * whose message is unchanged ("error").
+     */
+    public void testExecute_withSparseEmbeddingsAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
+        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createSparseEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("error"));
+    }
+    /**
+     * When {@code send} throws a plain {@link RuntimeException}, the listener must fail with an
+     * {@link ElasticsearchStatusException} whose message names the sparse embeddings request and the cause. Only the exception
+     * type and message are asserted; the status suggested by the test name is not checked here.
+     */
+    public void testExecute_withSparseEmbeddingsAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
+        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createSparseEmbeddingsAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search sparse embeddings request. Cause: error"));
+    }
+    /**
+     * The mocked sender answers the listener (argument index 3 of {@code send}) with one ranked document; the result read from
+     * the caller's listener must match the expectation built from the same index and relevance score.
+     */
+    public void testExecute_withRerankAction_Success() {
+        int index = 0;
+        float relevanceScore = 0.1111111f;
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onResponse(new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(index, relevanceScore, null))));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+        var action = createRerankAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var result = listener.actionGet(TIMEOUT);
+        assertThat(
+            result.asMap(),
+            is(
+                buildExpectationRerank(
+                    List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", index, "relevance_score", relevanceScore)))
+                )
+            )
+        );
+    }
+    /**
+     * When {@code send} throws an {@link ElasticsearchException}, the listener must fail with an {@link ElasticsearchException}
+     * whose message is unchanged ("error").
+     */
+    public void testExecute_withRerankAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
+        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createRerankAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("error"));
+    }
+    /**
+     * When {@code send} throws a plain {@link RuntimeException}, the listener must fail with an
+     * {@link ElasticsearchStatusException} whose message names the rerank request and the cause. Only the exception type and
+     * message are asserted; the status suggested by the test name is not checked here.
+     */
+    public void testExecute_withRerankAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
+        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createRerankAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(
+            new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            listener
+        );
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search rerank request. Cause: error"));
+    }
+    /**
+     * The mocked sender answers the listener (argument index 3 of {@code send}) with one chat completion result holding a
+     * random string; the result read from the caller's listener must match the expectation built from the same string.
+     */
+    public void testExecute_withCompletionAction_Success() {
+        var resultString = randomAlphaOfLength(100);
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+        var action = createCompletionAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var result = listener.actionGet(TIMEOUT);
+        assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
+    }
+    /**
+     * When {@code send} throws an {@link ElasticsearchException}, the listener must fail with an {@link ElasticsearchException}
+     * whose message is unchanged ("error").
+     */
+    public void testExecute_withCompletionAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
+        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createCompletionAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("error"));
+    }
+    /**
+     * When {@code send} throws a plain {@link RuntimeException}, the listener must fail with an
+     * {@link ElasticsearchStatusException} whose message names the completion request and the cause. Only the exception type
+     * and message are asserted; the status suggested by the test name is not checked here.
+     */
+    public void testExecute_withCompletionAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
+        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
+        var action = createCompletionAction();
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search completion request. Cause: error"));
+    }
+    /**
+     * Builds a TEXT_EMBEDDING model and asks an {@link AlibabaCloudSearchActionCreator} wired to the mocked sender to create
+     * its action, passing an empty map as the second argument to {@code create}.
+     * NOTE(review): the service settings map comes from the completion settings test helper; presumably the map shape is
+     * shared across task types - confirm.
+     *
+     * @return the action created for the text embeddings model
+     */
+    private ExecutableAction createTextEmbeddingsAction() {
+        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
+        AlibabaCloudSearchEmbeddingsModel embeddingsModel = new AlibabaCloudSearchEmbeddingsModel(
+            "text_embedding_test",
+            TaskType.TEXT_EMBEDDING,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("text_embedding_test", "host", "default"),
+            null,
+            null,
+            DefaultSecretSettingsTests.getSecretSettingsMap("secret"),
+            null
+        );
+        var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents);
+        return actionCreator.create(embeddingsModel, Map.of());
+    }
+    /**
+     * Builds a SPARSE_EMBEDDING model and asks an {@link AlibabaCloudSearchActionCreator} wired to the mocked sender to create
+     * its action, passing an empty map as the second argument to {@code create}.
+     * NOTE(review): the service settings map comes from the completion settings test helper; presumably the map shape is
+     * shared across task types - confirm.
+     *
+     * @return the action created for the sparse embeddings model
+     */
+    private ExecutableAction createSparseEmbeddingsAction() {
+        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
+        AlibabaCloudSearchSparseModel sparseModel = new AlibabaCloudSearchSparseModel(
+            "sparse_embedding_test",
+            TaskType.SPARSE_EMBEDDING,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("sparse_embedding_test", "host", "default"),
+            null,
+            null,
+            DefaultSecretSettingsTests.getSecretSettingsMap("secret"),
+            null
+        );
+        var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents);
+        return actionCreator.create(sparseModel, Map.of());
+    }
+    /**
+     * Builds a RERANK model and asks an {@link AlibabaCloudSearchActionCreator} wired to the mocked sender to create its
+     * action, passing an empty map as the second argument to {@code create}.
+     * NOTE(review): the service settings map comes from the completion settings test helper; presumably the map shape is
+     * shared across task types - confirm.
+     *
+     * @return the action created for the rerank model
+     */
+    private ExecutableAction createRerankAction() {
+        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
+        AlibabaCloudSearchRerankModel rerankModel = new AlibabaCloudSearchRerankModel(
+            "rerank_test",
+            TaskType.RERANK,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("rerank_test", "host", "default"),
+            null,
+            DefaultSecretSettingsTests.getSecretSettingsMap("secret"),
+            null
+        );
+        var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents);
+        return actionCreator.create(rerankModel, Map.of());
+    }
+    /**
+     * Builds a COMPLETION model and asks an {@link AlibabaCloudSearchActionCreator} wired to the mocked sender to create its
+     * action, passing an empty map as the second argument to {@code create}.
+     *
+     * @return the action created for the completion model
+     */
+    private ExecutableAction createCompletionAction() {
+        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
+        AlibabaCloudSearchCompletionModel completionModel = new AlibabaCloudSearchCompletionModel(
+            "completion_test",
+            TaskType.COMPLETION,
+            AlibabaCloudSearchUtils.SERVICE_NAME,
+            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
+            null,
+            DefaultSecretSettingsTests.getSecretSettingsMap("secret"),
+            null
+        );
+        var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents);
+        return actionCreator.create(completionModel, Map.of());
+    }
+}

+ 0 - 162
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java

@@ -1,162 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action;
-
-import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.inference.InputType;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
-import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
-import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
-import org.elasticsearch.xpack.inference.external.http.sender.Sender;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
-import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
-import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
-import org.junit.After;
-import org.junit.Before;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.concurrent.TimeUnit;
-
-import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors;
-import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
-import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
-import static org.hamcrest.Matchers.is;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-
-public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {
-
-    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
-    private final MockWebServer webServer = new MockWebServer();
-    private ThreadPool threadPool;
-    private HttpClientManager clientManager;
-
-    @Before
-    public void init() throws IOException {
-        webServer.start();
-        threadPool = createThreadPool(inferenceUtilityExecutors());
-        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
-    }
-
-    @After
-    public void shutdown() throws IOException {
-        clientManager.close();
-        terminate(threadPool);
-        webServer.close();
-    }
-
-    public void testExecute_Success() {
-        var sender = mock(Sender.class);
-
-        var resultString = randomAlphaOfLength(100);
-        doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
-            listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));
-
-            return Void.TYPE;
-        }).when(sender).send(any(), any(), any(), any());
-        var action = createAction(threadPool, sender);
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-        var result = listener.actionGet(TIMEOUT);
-        assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
-    }
-
-    public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
-        var sender = mock(Sender.class);
-        doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
-        var action = createAction(threadPool, sender);
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
-        assertThat(thrownException.getMessage(), is("error"));
-    }
-
-    public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
-        var sender = mock(Sender.class);
-        doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
-        var action = createAction(threadPool, sender);
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
-        assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search completion request. Cause: error"));
-    }
-
-    public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
-        var action = createAction(threadPool, mock(Sender.class));
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        assertThrows(IllegalArgumentException.class, () -> {
-            action.execute(
-                new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
-        });
-    }
-
-    public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
-        var action = createAction(threadPool, mock(Sender.class));
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(
-            new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
-            InferenceAction.Request.DEFAULT_TIMEOUT,
-            listener
-        );
-
-        var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
-        assertThat(
-            thrownException.getMessage(),
-            is(
-                "Alibaba Completion's inputs must be an odd number. The last input is the current query, "
-                    + "all preceding inputs are the completion history as pairs of user input and the assistant's response."
-            )
-        );
-        assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
-    }
-
-    private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
-        var model = AlibabaCloudSearchCompletionModelTests.createModel(
-            "completion_test",
-            TaskType.COMPLETION,
-            AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
-            AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
-            getSecretSettingsMap("secret")
-        );
-
-        var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
-        return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
-    }
-}