Browse Source

Add Ibm Granite Completion and Chat Completion support (#129146)

* Add Ibm Granite Completion and Chat Completion support

* Apply suggestions

* remove ibm watsonx transport version constant

* update transport version
Evgenii-Kazannik 3 months ago
parent
commit
5d0c5e02bd
36 changed files with 1395 additions and 189 deletions
  1. 5 0
      docs/changelog/129146.yaml
  2. 1 1
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 4 2
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
  4. 8 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
  5. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java
  6. 51 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java
  7. 25 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxCompletionResponseHandler.java
  8. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java
  9. 13 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java
  10. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java
  11. 49 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java
  12. 49 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java
  13. 29 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java
  14. 143 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModel.java
  15. 193 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettings.java
  16. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java
  17. 79 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequest.java
  18. 47 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntity.java
  19. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java
  20. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java
  21. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java
  22. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java
  23. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java
  24. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java
  25. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java
  26. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java
  27. 154 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java
  28. 7 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
  29. 50 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxChatCompletionActionTests.java
  30. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java
  31. 107 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModelTests.java
  32. 173 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettingsTests.java
  33. 66 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntityTests.java
  34. 106 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestTests.java
  35. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java
  36. 14 149
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java

+ 5 - 0
docs/changelog/129146.yaml

@@ -0,0 +1,5 @@
+pr: 129146
+summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 1
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -328,7 +328,7 @@ public class TransportVersions {
     public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
     public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
     public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
-
+    public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
     /*
      * STOP! READ THIS FIRST! No, really,
      *        ____ _____ ___  ____  _        ____  _____    _    ____    _____ _   _ ___ ____    _____ ___ ____  ____ _____ _

+ 4 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

@@ -151,7 +151,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "completion_test_service",
                     "hugging_face",
                     "amazon_sagemaker",
-                    "mistral"
+                    "mistral",
+                    "watsonxai"
                 ).toArray()
             )
         );
@@ -169,7 +170,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
                     "hugging_face",
                     "amazon_sagemaker",
                     "googlevertexai",
-                    "mistral"
+                    "mistral",
+                    "watsonxai"
                 ).toArray()
             )
         );

+ 8 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
 import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@@ -469,6 +470,13 @@ public class InferenceNamedWriteablesProvider {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                IbmWatsonxChatCompletionServiceSettings.NAME,
+                IbmWatsonxChatCompletionServiceSettings::new
+            )
+        );
     }
 
     private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java

@@ -80,8 +80,8 @@ public class CohereRerankModel extends CohereModel {
 
     /**
      * Accepts a visitor to create an executable action. The returned action will not return documents in the response.
-     * @param visitor _
-     * @param taskSettings _
+     * @param visitor          Interface for creating {@link ExecutableAction} instances for Cohere models.
+     * @param taskSettings     Settings in the request to override the model's defaults
      * @return the rerank action
      */
     @Override

+ 51 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonUnifiedChatCompletionResponseHandler.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx;
+
+import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
+
+import java.util.Locale;
+
+/**
+ * Handles streaming chat completion responses and error parsing for Watsonx inference endpoints.
+ * Adapts the OpenAI handler to support Watsonx's error schema.
+ */
+public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
+
+    private static final String WATSONX_ERROR = "watsonx_error";
+
+    public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
+    }
+
+    @Override
+    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
+        assert request.isStreaming() : "Only streaming requests support this format";
+        var responseStatusCode = result.response().getStatusLine().getStatusCode();
+        if (request.isStreaming()) {
+            var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
+            var restStatus = toRestStatus(responseStatusCode);
+            return errorResponse instanceof IbmWatsonxErrorResponseEntity
+                ? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
+                : new UnifiedChatCompletionException(
+                    restStatus,
+                    errorMessage,
+                    createErrorType(errorResponse),
+                    restStatus.name().toLowerCase(Locale.ROOT)
+                );
+        } else {
+            return super.buildError(message, request, result, errorResponse);
+        }
+    }
+}

+ 25 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxCompletionResponseHandler.java

@@ -0,0 +1,25 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx;
+
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
+
+public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
+
+    /**
+     * Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
+     *
+     * @param requestType The type of request being handled (e.g., "IBM watsonx completions").
+     * @param parseFunction The function to parse the response.
+     */
+    public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
+    }
+}

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java

@@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
     private static final ResponseHandler HANDLER = createEmbeddingsHandler();
 
     private static ResponseHandler createEmbeddingsHandler() {
-        return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
+        return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
     }
 
     private final IbmWatsonxEmbeddingsModel model;

+ 13 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxModel.java

@@ -7,18 +7,19 @@
 
 package org.elasticsearch.xpack.inference.services.ibmwatsonx;
 
-import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
 
 import java.util.Map;
 import java.util.Objects;
 
-public abstract class IbmWatsonxModel extends Model {
+public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
 
     private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;
 
@@ -49,4 +50,14 @@ public abstract class IbmWatsonxModel extends Model {
     public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
         return rateLimitServiceSettings;
     }
+
+    @Override
+    public int rateLimitGroupingHash() {
+        return Objects.hash(this.rateLimitServiceSettings);
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return this.rateLimitServiceSettings().rateLimitSettings();
+    }
 }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxRerankRequestManager.java

@@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
 
     private static ResponseHandler createIbmWatsonxResponseHandler() {
         return new IbmWatsonxResponseHandler(
-            "ibm watsonx rerank",
+            "IBM watsonx rerank",
             (request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
         );
     }

+ 49 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

@@ -30,7 +30,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -40,14 +43,18 @@ import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
 
 import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
@@ -56,7 +63,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersi
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
 import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
 import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
 import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
@@ -66,8 +72,16 @@ public class IbmWatsonxService extends SenderService {
 
     public static final String NAME = "watsonxai";
 
-    private static final String SERVICE_NAME = "IBM Watsonx";
-    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
+    private static final String SERVICE_NAME = "IBM watsonx";
+    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
+        TaskType.TEXT_EMBEDDING,
+        TaskType.COMPLETION,
+        TaskType.CHAT_COMPLETION
+    );
+    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
+        "IBM watsonx chat completions",
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
 
     public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
@@ -148,6 +162,14 @@ public class IbmWatsonxService extends SenderService {
                 secretSettings,
                 context
             );
+            case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                secretSettings,
+                context
+            );
             default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         };
     }
@@ -236,6 +258,11 @@ public class IbmWatsonxService extends SenderService {
         return TransportVersions.V_8_16_0;
     }
 
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
+    }
+
     @Override
     public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
@@ -291,7 +318,24 @@ public class IbmWatsonxService extends SenderService {
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        throwUnsupportedUnifiedCompletionOperation(NAME);
+        if (model instanceof IbmWatsonxChatCompletionModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
+        var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
+        var manager = new GenericRequestManager<>(
+            getServiceComponents().threadPool(),
+            overriddenModel,
+            UNIFIED_CHAT_COMPLETION_HANDLER,
+            unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
+            UnifiedChatInput.class
+        );
+        var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
+        var action = new SenderExecutableAction(getSender(), manager, errorMessage);
+
+        action.execute(inputs, timeout, listener);
     }
 
     @Override
@@ -331,7 +375,7 @@ public class IbmWatsonxService extends SenderService {
 
                 configurationMap.put(
                     API_VERSION,
-                    new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.")
+                    new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM watsonx API version ID to use.")
                         .setLabel("API Version")
                         .setRequired(true)
                         .setSensitive(false)

+ 49 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionCreator.java

@@ -7,26 +7,48 @@
 
 package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
 
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxCompletionResponseHandler;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxEmbeddingsRequestManager;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRerankRequestManager;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
+import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
 
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 
+/**
+ * IbmWatsonxActionCreator is responsible for creating executable actions for various models.
+ * It implements the IbmWatsonxActionVisitor interface to provide specific implementations.
+ */
 public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
     private final Sender sender;
     private final ServiceComponents serviceComponents;
 
+    static final String COMPLETION_REQUEST_TYPE = "IBM watsonx completions";
+    static final String USER_ROLE = "user";
+    static final ResponseHandler COMPLETION_HANDLER = new IbmWatsonxCompletionResponseHandler(
+        COMPLETION_REQUEST_TYPE,
+        OpenAiChatCompletionResponseEntity::fromResponse
+    );
+
     public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
         this.sender = Objects.requireNonNull(sender);
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
@@ -34,7 +56,7 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
 
     @Override
     public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
-        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
         return new SenderExecutableAction(
             sender,
             getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
@@ -46,10 +68,24 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
     public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
         var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
         var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
-        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank");
+        var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
         return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
     }
 
+    @Override
+    public ExecutableAction create(IbmWatsonxChatCompletionModel chatCompletionModel) {
+        var manager = new GenericRequestManager<>(
+            serviceComponents.threadPool(),
+            chatCompletionModel,
+            COMPLETION_HANDLER,
+            inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel),
+            ChatCompletionInput.class
+        );
+
+        var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId());
+        return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_REQUEST_TYPE);
+    }
+
     protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
         IbmWatsonxEmbeddingsModel model,
         Truncator truncator,
@@ -57,4 +93,15 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
     ) {
         return new IbmWatsonxEmbeddingsRequestManager(model, truncator, threadPool);
     }
+
+    /**
+     * Builds an error message for IBM watsonx actions.
+     *
+     * @param requestType The type of request (e.g. COMPLETION, EMBEDDING, RERANK).
+     * @param inferenceId The ID of the inference entity.
+     * @return A formatted error message.
+     */
+    public static String buildErrorMessage(TaskType requestType, String inferenceId) {
+        return format("Failed to send IBM watsonx %s request from inference entity id [%s]", requestType.toString(), inferenceId);
+    }
 }

+ 29 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxActionVisitor.java

@@ -8,13 +8,42 @@
 package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
 
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
 
 import java.util.Map;
 
+/**
+ * Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
+ * <p>
+ * This interface is used to create {@link ExecutableAction} instances for different types of IBM watsonx models, such as
+ * {@link IbmWatsonxEmbeddingsModel} and {@link IbmWatsonxRerankModel} and {@link IbmWatsonxChatCompletionModel}.
+ */
 public interface IbmWatsonxActionVisitor {
+
+    /**
+     * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxEmbeddingsModel}.
+     *
+     * @param model The model to create the action for.
+     * @param taskSettings    The task settings to use.
+     * @return An {@link ExecutableAction} for the given model.
+     */
     ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
 
+    /**
+     * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxRerankModel}.
+     *
+     * @param model The model to create the action for.
+     * @return An {@link ExecutableAction} for the given model.
+     */
     ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
+
+    /**
+     * Creates an {@link ExecutableAction} for the given {@link IbmWatsonxChatCompletionModel}.
+     *
+     * @param model The model to create the action for.
+     * @return An {@link ExecutableAction} for the given model.
+     */
+    ExecutableAction create(IbmWatsonxChatCompletionModel model);
 }

+ 143 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModel.java

@@ -0,0 +1,143 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
+
+import org.apache.http.client.utils.URIBuilder;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.COMPLETIONS;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.ML;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.TEXT;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.V1;
+
+public class IbmWatsonxChatCompletionModel extends IbmWatsonxModel {
+
+    /**
+     * Constructor for IbmWatsonxChatCompletionModel.
+     *
+     * @param inferenceEntityId The unique identifier for the inference entity.
+     * @param taskType The type of task this model is designed for.
+     * @param service The name of the service this model belongs to.
+     * @param serviceSettings The settings specific to the Ibm Granite chat completion service.
+     * @param secrets The secrets required for accessing the service.
+     * @param context The context for parsing configuration settings.
+     */
+    public IbmWatsonxChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            IbmWatsonxChatCompletionServiceSettings.fromMap(serviceSettings, context),
+            DefaultSecretSettings.fromMap(secrets)
+        );
+    }
+
+    /**
+     * Creates a new IbmWatsonxChatCompletionModel with overridden service settings.
+     *
+     * @param model The original IbmWatsonxChatCompletionModel.
+     * @param request The UnifiedCompletionRequest containing the model override.
+     * @return A new IbmWatsonxChatCompletionModel with the overridden model ID.
+     */
+    public static IbmWatsonxChatCompletionModel of(IbmWatsonxChatCompletionModel model, UnifiedCompletionRequest request) {
+        if (request.model() == null) {
+            // If no model is specified in the request, return the original model
+            return model;
+        }
+
+        var originalModelServiceSettings = model.getServiceSettings();
+        var overriddenServiceSettings = new IbmWatsonxChatCompletionServiceSettings(
+            originalModelServiceSettings.uri(),
+            originalModelServiceSettings.apiVersion(),
+            request.model(),
+            originalModelServiceSettings.projectId(),
+            originalModelServiceSettings.rateLimitSettings()
+        );
+
+        return new IbmWatsonxChatCompletionModel(
+            model.getInferenceEntityId(),
+            model.getTaskType(),
+            model.getConfigurations().getService(),
+            overriddenServiceSettings,
+            model.getSecretSettings()
+        );
+    }
+
+    // should only be used for testing
+    IbmWatsonxChatCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        IbmWatsonxChatCompletionServiceSettings serviceSettings,
+        @Nullable DefaultSecretSettings secretSettings
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings
+        );
+    }
+
+    @Override
+    public IbmWatsonxChatCompletionServiceSettings getServiceSettings() {
+        return (IbmWatsonxChatCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public DefaultSecretSettings getSecretSettings() {
+        return (DefaultSecretSettings) super.getSecretSettings();
+    }
+
+    public URI uri() {
+        URI uri;
+        try {
+            uri = buildUri(this.getServiceSettings().uri().toString(), this.getServiceSettings().apiVersion());
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+
+        return uri;
+    }
+
+    /**
+     * Accepts a visitor to create an executable action. The returned action will not return documents in the response.
+     * @param visitor          Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
+     * @return the completion action
+     */
+    public ExecutableAction accept(IbmWatsonxActionVisitor visitor, Map<String, Object> taskSettings) {
+        return visitor.create(this);
+    }
+
+    public static URI buildUri(String uri, String apiVersion) throws URISyntaxException {
+        return new URIBuilder().setScheme("https")
+            .setHost(uri)
+            .setPathSegments(ML, V1, TEXT, COMPLETIONS)
+            .setParameter("version", apiVersion)
+            .build();
+    }
+}

+ 193 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettings.java

@@ -0,0 +1,193 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID;
+
+public class IbmWatsonxChatCompletionServiceSettings extends FilteredXContentObject
+    implements
+        ServiceSettings,
+        IbmWatsonxRateLimitServiceSettings {
+    public static final String NAME = "ibm_watsonx_completion_service_settings";
+
+    /**
+     * Rate limits are defined at
+     * <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
+     * For the Lite plan, the limit is 120 requests per minute.
+     */
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
+
+    public static IbmWatsonxChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(
+            map,
+            DEFAULT_RATE_LIMIT_SETTINGS,
+            validationException,
+            IbmWatsonxService.NAME,
+            context
+        );
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, rateLimitSettings);
+    }
+
+    private final URI uri;
+
+    private final String apiVersion;
+
+    private final String modelId;
+
+    private final String projectId;
+
+    private final RateLimitSettings rateLimitSettings;
+
+    public IbmWatsonxChatCompletionServiceSettings(
+        URI uri,
+        String apiVersion,
+        String modelId,
+        String projectId,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        this.uri = uri;
+        this.apiVersion = apiVersion;
+        this.projectId = projectId;
+        this.modelId = modelId;
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    public IbmWatsonxChatCompletionServiceSettings(StreamInput in) throws IOException {
+        this.uri = createUri(in.readString());
+        this.apiVersion = in.readString();
+        this.modelId = in.readString();
+        this.projectId = in.readString();
+        this.rateLimitSettings = new RateLimitSettings(in);
+
+    }
+
+    public URI uri() {
+        return uri;
+    }
+
+    public String apiVersion() {
+        return apiVersion;
+    }
+
+    @Override
+    public String modelId() {
+        return modelId;
+    }
+
+    public String projectId() {
+        return projectId;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return rateLimitSettings;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(URL, uri.toString());
+
+        builder.field(API_VERSION, apiVersion);
+
+        builder.field(MODEL_ID, modelId);
+
+        builder.field(PROJECT_ID, projectId);
+
+        rateLimitSettings.toXContent(builder, params);
+
+        return builder;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(uri.toString());
+        out.writeString(apiVersion);
+
+        out.writeString(modelId);
+        out.writeString(projectId);
+
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        IbmWatsonxChatCompletionServiceSettings that = (IbmWatsonxChatCompletionServiceSettings) object;
+        return Objects.equals(uri, that.uri)
+            && Objects.equals(apiVersion, that.apiVersion)
+            && Objects.equals(modelId, that.modelId)
+            && Objects.equals(projectId, that.projectId)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(uri, apiVersion, modelId, projectId, rateLimitSettings);
+    }
+}

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java

@@ -52,7 +52,7 @@ public class IbmWatsonxEmbeddingsServiceSettings extends FilteredXContentObject
     /**
      * Rate limits are defined at
      * <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
-     * For Lite plan, you've 120 requests per minute.
+     * For the Lite plan, the limit is 120 requests per minute.
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
 

+ 79 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequest.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+public class IbmWatsonxChatCompletionRequest implements IbmWatsonxRequest {
+    private final IbmWatsonxChatCompletionModel model;
+    private final UnifiedChatInput chatInput;
+
+    public IbmWatsonxChatCompletionRequest(UnifiedChatInput chatInput, IbmWatsonxChatCompletionModel model) {
+        this.chatInput = Objects.requireNonNull(chatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(model.uri());
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(
+            Strings.toString(new IbmWatsonxChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
+        );
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
+
+        decorateWithAuth(httpPost);
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return model.uri();
+    }
+
+    public void decorateWithAuth(HttpPost httpPost) {
+        IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
+    }
+
+    @Override
+    public Request truncate() {
+        // No truncation for IBM watsonx chat completions
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // No truncation for IBM watsonx chat completions
+        return null;
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public boolean isStreaming() {
+        return chatInput.stream();
+    }
+}

+ 47 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntity.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
+
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * IbmWatsonxChatCompletionRequestEntity is responsible for creating the request entity for Watsonx chat completion.
+ * It implements ToXContentObject to allow serialization to XContent format.
+ */
+public class IbmWatsonxChatCompletionRequestEntity implements ToXContentObject {
+
+    private final IbmWatsonxChatCompletionModel model;
+    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
+
+    private static final String PROJECT_ID_FIELD = "project_id";
+
+    public IbmWatsonxChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, IbmWatsonxChatCompletionModel model) {
+        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(PROJECT_ID_FIELD, model.getServiceSettings().projectId());
+        unifiedRequestEntity.toXContent(
+            builder,
+            UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params)
+        );
+        builder.endObject();
+        return builder;
+    }
+}

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxUtils.java

@@ -14,6 +14,7 @@ public class IbmWatsonxUtils {
     public static final String TEXT = "text";
     public static final String EMBEDDINGS = "embeddings";
     public static final String RERANKS = "reranks";
+    public static final String COMPLETIONS = "chat";
 
     private IbmWatsonxUtils() {}
 

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java

@@ -100,8 +100,8 @@ public class IbmWatsonxRerankModel extends IbmWatsonxModel {
 
     /**
      * Accepts a visitor to create an executable action. The returned action will not return documents in the response.
-     * @param visitor _
-     * @param taskSettings _
+     * @param visitor          Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
+     * @param taskSettings     Settings in the request to override the model's defaults
      * @return the rerank action
      */
     @Override

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java

@@ -41,7 +41,7 @@ public class IbmWatsonxRerankServiceSettings extends FilteredXContentObject impl
     /**
      * Rate limits are defined at
      * <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
-     * For Lite plan, you've 120 requests per minute.
+     * For the Lite plan, the limit is 120 requests per minute.
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntity.java

@@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.
 
 public class IbmWatsonxEmbeddingsResponseEntity {
 
-    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM Watsonx embeddings response";
+    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response";
 
     public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
         var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxRankedResponseEntity.java

@@ -32,7 +32,7 @@ public class IbmWatsonxRankedResponseEntity {
     private static final Logger logger = LogManager.getLogger(IbmWatsonxRankedResponseEntity.class);
 
     /**
-     * Parses the Ibm Watsonx ranked response.
+     * Parses the IBM watsonx ranked response.
      *
      * For a request like:
      *     "model": "rerank-english-v2.0",
@@ -71,7 +71,7 @@ public class IbmWatsonxRankedResponseEntity {
      *   ],
      *   }
      *
-     * @param response the http response from ibm watsonx
+     * @param response the http response from IBM watsonx
      * @return the parsed response
      * @throws IOException if there is an error parsing the response
      */

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/rerank/JinaAIRerankModel.java

@@ -84,8 +84,8 @@ public class JinaAIRerankModel extends JinaAIModel {
 
     /**
      * Accepts a visitor to create an executable action. The returned action will not return documents in the response.
-     * @param visitor _
-     * @param taskSettings _
+     * @param visitor          Interface for creating {@link ExecutableAction} instances for Jina AI models.
+     * @param taskSettings     Settings in the request to override the model's defaults
      * @return the rerank action
      */
     @Override

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java

@@ -37,8 +37,8 @@ import static org.elasticsearch.core.Strings.format;
 public class MistralActionCreator implements MistralActionVisitor {
 
     public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
-    static final String USER_ROLE = "user";
-    static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
+    public static final String USER_ROLE = "user";
+    public static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
         "mistral completions",
         OpenAiChatCompletionResponseEntity::fromResponse
     );

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/rerank/VoyageAIRerankModel.java

@@ -109,8 +109,8 @@ public class VoyageAIRerankModel extends VoyageAIModel {
 
     /**
      * Accepts a visitor to create an executable action. The returned action will not return documents in the response.
-     * @param visitor _
-     * @param taskSettings _
+     * @param visitor          Interface for creating {@link ExecutableAction} instances for Voyage AI models.
+     * @param taskSettings     Settings in the request to override the model's defaults
      * @return the rerank action
      */
     @Override

+ 154 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ChatCompletionActionTests.java

@@ -0,0 +1,154 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public abstract class ChatCompletionActionTests extends ESTestCase {
+    protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    protected final MockWebServer webServer = new MockWebServer();
+    protected HttpClientManager clientManager;
+    protected ThreadPool threadPool;
+
+    protected abstract ExecutableAction createAction(String url, Sender sender) throws URISyntaxException;
+
+    protected abstract String getOneInputError();
+
+    protected abstract String getFailedToSendError();
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ThrowsElasticsearchException() throws URISyntaxException {
+        var sender = mock(Sender.class);
+        doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction(getUrl(webServer), sender);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("failed"));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() throws URISyntaxException {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction(getUrl(webServer), sender);
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(getFailedToSendError()));
+    }
+
+    public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
+
+            var action = createAction(getUrl(webServer), sender);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+
+            assertThat(thrownException.getMessage(), is(getOneInputError()));
+            assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
+        }
+    }
+
+    protected String getResponseJson() {
+        return """
+            {
+                 "id": "9d80f26810ac4e9582f927fcf0512ec7",
+                 "object": "chat.completion",
+                 "created": 1748596419,
+                 "model": "modelId",
+                 "choices": [
+                     {
+                         "index": 0,
+                         "message": {
+                             "role": "assistant",
+                             "tool_calls": null,
+                             "content": "result content"
+                         },
+                         "finish_reason": "length",
+                         "logprobs": null
+                     }
+                 ],
+                 "usage": {
+                     "prompt_tokens": 10,
+                     "total_tokens": 11,
+                     "completion_tokens": 1
+                 }
+             }
+            """;
+    }
+}

+ 7 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

@@ -918,8 +918,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             String content = XContentHelper.stripWhitespace("""
                 {
                        "service": "watsonxai",
-                       "name": "IBM Watsonx",
-                       "task_types": ["text_embedding"],
+                       "name": "IBM watsonx",
+                       "task_types": ["text_embedding", "completion", "chat_completion"],
                        "configurations": {
                            "project_id": {
                                "description": "",
@@ -928,7 +928,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["text_embedding"]
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
                            },
                            "model_id": {
                                "description": "The name of the model to use for the inference task.",
@@ -937,16 +937,16 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["text_embedding"]
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
                            },
                            "api_version": {
-                               "description": "The IBM Watsonx API version ID to use.",
+                               "description": "The IBM watsonx API version ID to use.",
                                "label": "API Version",
                                "required": true,
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["text_embedding"]
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
                            },
                            "max_input_tokens": {
                                "description": "Allows you to specify the maximum number of tokens per input.",
@@ -964,7 +964,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                                "sensitive": false,
                                "updatable": false,
                                "type": "str",
-                               "supported_task_types": ["text_embedding"]
+                               "supported_task_types": ["text_embedding", "completion", "chat_completion"]
                            }
                        }
                    }

+ 50 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxChatCompletionActionTests.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.COMPLETION_HANDLER;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.USER_ROLE;
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
+
+public class IbmWatsonxChatCompletionActionTests extends ChatCompletionActionTests {
+    public static final URI TEST_URI = URI.create("abc.com");
+
+    protected ExecutableAction createAction(String url, Sender sender) throws URISyntaxException {
+        var model = createModel(TEST_URI, randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8));
+        var manager = new GenericRequestManager<>(
+            threadPool,
+            model,
+            COMPLETION_HANDLER,
+            inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
+            ChatCompletionInput.class
+        );
+        var errorMessage = constructFailedToSendRequestMessage("watsonx chat completions");
+        return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "watsonx chat completions");
+    }
+
+    protected String getFailedToSendError() {
+        return "Failed to send watsonx chat completions request. Cause: failed";
+    }
+
+    protected String getOneInputError() {
+        return "watsonx chat completions only accepts 1 input";
+    }
+}

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java

@@ -180,7 +180,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
-        assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
+        assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
     }
 
     public void testExecute_ThrowsException() {
@@ -204,7 +204,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
 
         var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
 
-        assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
+        assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
     }
 
     private ExecutableAction createAction(
@@ -218,7 +218,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
     ) {
         var model = createModel(modelName, projectId, uri, apiVersion, apiKey, url);
         var requestManager = new IbmWatsonxEmbeddingsRequestManagerWithoutAuth(model, TruncatorTests.createTruncator(), threadPool);
-        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM Watsonx embeddings");
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
         return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
     }
 

+ 107 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionModelTests.java

@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+
+public class IbmWatsonxChatCompletionModelTests extends ESTestCase {
+    private static final URI TEST_URI = URI.create("abc.com");
+
+    public static IbmWatsonxChatCompletionModel createModel(URI uri, String apiVersion, String modelId, String projectId, String apiKey)
+        throws URISyntaxException {
+        return new IbmWatsonxChatCompletionModel(
+            "id",
+            TaskType.COMPLETION,
+            "service",
+            new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, null),
+            new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
+        );
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() throws URISyntaxException {
+        var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            "different_model",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() throws URISyntaxException {
+        var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            "different_model",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() throws URISyntaxException {
+        var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
+
+        assertNull(overriddenModel.getServiceSettings().modelId());
+    }
+
+    public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() throws URISyntaxException {
+        var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
+            null, // not overriding model
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("modelId"));
+    }
+}

+ 173 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/completion/IbmWatsonxChatCompletionServiceSettingsTests.java

@@ -0,0 +1,173 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+
+public class IbmWatsonxChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<IbmWatsonxChatCompletionServiceSettings> {
+    private static final URI TEST_URI = URI.create("abc.com");
+
+    private static IbmWatsonxChatCompletionServiceSettings createRandom() {
+        return new IbmWatsonxChatCompletionServiceSettings(
+            TEST_URI,
+            randomAlphaOfLength(8),
+            randomAlphaOfLength(8),
+            randomAlphaOfLength(8),
+            randomFrom(RateLimitSettingsTests.createRandom(), null)
+        );
+    }
+
+    private IbmWatsonxChatCompletionServiceSettings getServiceSettings(Map<String, String> map) {
+        return IbmWatsonxChatCompletionServiceSettings.fromMap(new HashMap<>(map), ConfigurationParseContext.PERSISTENT);
+    }
+
+    public void testFromMap_WithAllParameters_CreatesSettingsCorrectly() {
+        var model = randomAlphaOfLength(8);
+        var projectId = randomAlphaOfLength(8);
+        var apiVersion = randomAlphaOfLength(8);
+
+        var serviceSettings = getServiceSettings(
+            Map.of(
+                ServiceFields.URL,
+                TEST_URI.toString(),
+                IbmWatsonxServiceFields.API_VERSION,
+                apiVersion,
+                ServiceFields.MODEL_ID,
+                model,
+                IbmWatsonxServiceFields.PROJECT_ID,
+                projectId
+            )
+        );
+        assertThat(serviceSettings, is(new IbmWatsonxChatCompletionServiceSettings(TEST_URI, apiVersion, model, projectId, null)));
+    }
+
+    public void testFromMap_Fails_WithoutRequiredParam_Url() {
+        var ex = expectThrows(
+            ValidationException.class,
+            () -> getServiceSettings(
+                Map.of(
+                    IbmWatsonxServiceFields.API_VERSION,
+                    randomAlphaOfLength(8),
+                    ServiceFields.MODEL_ID,
+                    randomAlphaOfLength(8),
+                    IbmWatsonxServiceFields.PROJECT_ID,
+                    randomAlphaOfLength(8)
+                )
+            )
+        );
+        assertThat(ex.getMessage(), equalTo(generateErrorMessage("url")));
+    }
+
+    public void testFromMap_Fails_WithoutRequiredParam_ApiVersion() {
+        var ex = expectThrows(
+            ValidationException.class,
+            () -> getServiceSettings(
+                Map.of(
+                    ServiceFields.URL,
+                    TEST_URI.toString(),
+                    ServiceFields.MODEL_ID,
+                    randomAlphaOfLength(8),
+                    IbmWatsonxServiceFields.PROJECT_ID,
+                    randomAlphaOfLength(8)
+                )
+            )
+        );
+        assertThat(ex.getMessage(), equalTo(generateErrorMessage("api_version")));
+    }
+
+    public void testFromMap_Fails_WithoutRequiredParam_ModelId() {
+        var ex = expectThrows(
+            ValidationException.class,
+            () -> getServiceSettings(
+                Map.of(
+                    ServiceFields.URL,
+                    TEST_URI.toString(),
+                    IbmWatsonxServiceFields.API_VERSION,
+                    randomAlphaOfLength(8),
+                    IbmWatsonxServiceFields.PROJECT_ID,
+                    randomAlphaOfLength(8)
+                )
+            )
+        );
+        assertThat(ex.getMessage(), equalTo(generateErrorMessage("model_id")));
+    }
+
+    public void testFromMap_Fails_WithoutRequiredParam_ProjectId() {
+        var ex = expectThrows(
+            ValidationException.class,
+            () -> getServiceSettings(
+                Map.of(
+                    ServiceFields.URL,
+                    TEST_URI.toString(),
+                    IbmWatsonxServiceFields.API_VERSION,
+                    randomAlphaOfLength(8),
+                    ServiceFields.MODEL_ID,
+                    randomAlphaOfLength(8)
+                )
+            )
+        );
+        assertThat(ex.getMessage(), equalTo(generateErrorMessage("project_id")));
+    }
+
+    private String generateErrorMessage(String field) {
+        return "Validation Failed: 1: [service_settings] does not contain the required setting [" + field + "];";
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var entity = new IbmWatsonxChatCompletionServiceSettings(TEST_URI, "2024-05-02", "model", "project_id", null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "url":"abc.com",
+                "api_version":"2024-05-02",
+                "model_id":"model",
+                "project_id":"project_id",
+                "rate_limit": {
+                    "requests_per_minute":120
+                }
+            }"""));
+    }
+
+    @Override
+    protected Writeable.Reader<IbmWatsonxChatCompletionServiceSettings> instanceReader() {
+        return IbmWatsonxChatCompletionServiceSettings::new;
+    }
+
+    @Override
+    protected IbmWatsonxChatCompletionServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected IbmWatsonxChatCompletionServiceSettings mutateInstance(IbmWatsonxChatCompletionServiceSettings instance) throws IOException {
+        return randomValueOtherThan(instance, IbmWatsonxChatCompletionServiceSettingsTests::createRandom);
+    }
+}

+ 66 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestEntityTests.java

@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+
+import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
+
+public class IbmWatsonxChatCompletionRequestEntityTests extends ESTestCase {
+
+    private static final String ROLE = "user";
+
+    public void testModelUserFieldsSerialization() throws IOException, URISyntaxException {
+        UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
+            new UnifiedCompletionRequest.ContentString("test content"),
+            ROLE,
+            null,
+            null
+        );
+        var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
+        messageList.add(message);
+
+        var unifiedRequest = UnifiedCompletionRequest.of(messageList);
+
+        UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
+        IbmWatsonxChatCompletionModel model = createModel(new URI("abc.com"), "apiVersion", "modelId", "projectId", "apiKey");
+
+        IbmWatsonxChatCompletionRequestEntity entity = new IbmWatsonxChatCompletionRequestEntity(unifiedChatInput, model);
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        String expectedJson = """
+            {
+                "project_id": "projectId",
+                "messages": [
+                    {
+                        "content": "test content",
+                        "role": "user"
+                    }
+                ],
+                "model": "modelId",
+                "n": 1,
+                "stream": true
+            }
+            """;
+        assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
+    }
+}

+ 106 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/request/IbmWatsonxChatCompletionRequestTests.java

@@ -0,0 +1,106 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class IbmWatsonxChatCompletionRequestTests extends ESTestCase {
+    private static final String AUTH_HEADER_VALUE = "foo";
+    private static final String API_COMPLETIONS_PATH = "https://abc.com/ml/v1/text/chat?version=apiVersion";
+
+    public void testCreateRequest_WithStreaming() throws IOException, URISyntaxException {
+        assertCreateRequestWithStreaming(true);
+    }
+
+    public void testCreateRequest_WithNoStreaming() throws IOException, URISyntaxException {
+        assertCreateRequestWithStreaming(false);
+    }
+
+    public void testTruncate_DoesNotReduceInputTextSize() throws IOException, URISyntaxException {
+        String input = randomAlphaOfLength(5);
+        String model = randomAlphaOfLength(5);
+
+        var request = createRequest(randomAlphaOfLength(5), input, model, true);
+        var truncatedRequest = request.truncate();
+        assertThat(request.getURI().toString(), is(API_COMPLETIONS_PATH));
+
+        var httpRequest = truncatedRequest.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(5));
+
+        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
+        assertThat(requestMap.get("model"), is(model));
+        assertThat(requestMap.get("n"), is(1));
+        assertTrue((Boolean) requestMap.get("stream"));
+        assertNull(requestMap.get("stream_options"));
+    }
+
+    public void testTruncationInfo_ReturnsNull() throws URISyntaxException {
+        var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), true);
+        assertNull(request.getTruncationInfo());
+    }
+
+    public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model)
+        throws URISyntaxException {
+        return createRequest(apiKey, input, model, false);
+    }
+
+    public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream)
+        throws URISyntaxException {
+        var chatCompletionModel = IbmWatsonxChatCompletionModelTests.createModel(
+            new URI("abc.com"),
+            "apiVersion",
+            model,
+            randomAlphaOfLength(5),
+            apiKey
+        );
+        return new IbmWatsonxChatCompletionWithoutAuthRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
+    }
+
+    private static class IbmWatsonxChatCompletionWithoutAuthRequest extends IbmWatsonxChatCompletionRequest {
+        IbmWatsonxChatCompletionWithoutAuthRequest(UnifiedChatInput input, IbmWatsonxChatCompletionModel model) {
+            super(input, model);
+        }
+
+        @Override
+        public void decorateWithAuth(HttpPost httpPost) {
+            httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
+        }
+    }
+
+    private void assertCreateRequestWithStreaming(boolean isStreaming) throws URISyntaxException, IOException {
+        var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), isStreaming);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap.get("stream"), is(isStreaming));
+    }
+}

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/response/IbmWatsonxEmbeddingsResponseEntityTests.java

@@ -112,6 +112,6 @@ public class IbmWatsonxEmbeddingsResponseEntityTests extends ESTestCase {
             )
         );
 
-        assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM Watsonx embeddings response"));
+        assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM watsonx embeddings response"));
     }
 }

+ 14 - 149
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java

@@ -8,42 +8,28 @@
 package org.elasticsearch.xpack.inference.services.mistral.action;
 
 import org.apache.http.HttpHeaders;
-import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.settings.Settings;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
-import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockRequest;
 import org.elasticsearch.test.http.MockResponse;
-import org.elasticsearch.test.http.MockWebServer;
-import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
-import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
 import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
-import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
-import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
 import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
-import org.junit.After;
-import org.junit.Before;
 
 import java.io.IOException;
+import java.net.URISyntaxException;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
-import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
@@ -56,64 +42,15 @@ import static org.elasticsearch.xpack.inference.services.mistral.completion.Mist
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.mock;
-
-public class MistralChatCompletionActionTests extends ESTestCase {
-    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
-    private final MockWebServer webServer = new MockWebServer();
-    private ThreadPool threadPool;
-    private HttpClientManager clientManager;
-
-    @Before
-    public void init() throws Exception {
-        webServer.start();
-        threadPool = createThreadPool(inferenceUtilityPool());
-        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
-    }
-
-    @After
-    public void shutdown() throws IOException {
-        clientManager.close();
-        terminate(threadPool);
-        webServer.close();
-    }
 
-    public void testExecute_ReturnsSuccessfulResponse() throws IOException {
+public class MistralChatCompletionActionTests extends ChatCompletionActionTests {
+    public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISyntaxException {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
 
         try (var sender = createSender(senderFactory)) {
             sender.start();
 
-            String responseJson = """
-                {
-                     "id": "9d80f26810ac4e9582f927fcf0512ec7",
-                     "object": "chat.completion",
-                     "created": 1748596419,
-                     "model": "mistral-small-latest",
-                     "choices": [
-                         {
-                             "index": 0,
-                             "message": {
-                                 "role": "assistant",
-                                 "tool_calls": null,
-                                 "content": "result content"
-                             },
-                             "finish_reason": "length",
-                             "logprobs": null
-                         }
-                     ],
-                     "usage": {
-                         "prompt_tokens": 10,
-                         "total_tokens": 11,
-                         "completion_tokens": 1
-                     }
-                 }
-                """;
-
-            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
 
             var action = createAction(getUrl(webServer), sender);
 
@@ -140,87 +77,7 @@ public class MistralChatCompletionActionTests extends ESTestCase {
         }
     }
 
-    public void testExecute_ThrowsElasticsearchException() {
-        var sender = mock(Sender.class);
-        doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
-
-        var action = createAction(getUrl(webServer), sender);
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
-
-        assertThat(thrownException.getMessage(), is("failed"));
-    }
-
-    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
-        var sender = mock(Sender.class);
-
-        doAnswer(invocation -> {
-            ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
-            listener.onFailure(new IllegalStateException("failed"));
-
-            return Void.TYPE;
-        }).when(sender).send(any(), any(), any(), any());
-
-        var action = createAction(getUrl(webServer), sender);
-
-        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-        action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
-
-        assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
-    }
-
-    public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
-        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-
-        try (var sender = createSender(senderFactory)) {
-            sender.start();
-
-            String responseJson = """
-                {
-                     "id": "9d80f26810ac4e9582f927fcf0512ec7",
-                     "object": "chat.completion",
-                     "created": 1748596419,
-                     "model": "mistral-small-latest",
-                     "choices": [
-                         {
-                             "index": 0,
-                             "message": {
-                                 "role": "assistant",
-                                 "tool_calls": null,
-                                 "content": "result content"
-                             },
-                             "finish_reason": "length",
-                             "logprobs": null
-                         }
-                     ],
-                     "usage": {
-                         "prompt_tokens": 10,
-                         "total_tokens": 11,
-                         "completion_tokens": 1
-                     }
-                 }
-                """;
-
-            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
-
-            var action = createAction(getUrl(webServer), sender);
-
-            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
-            action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
-
-            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
-
-            assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input"));
-            assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
-        }
-    }
-
-    private ExecutableAction createAction(String url, Sender sender) {
+    protected ExecutableAction createAction(String url, Sender sender) {
         var model = createCompletionModel("secret", "model");
         model.setURI(url);
         var manager = new GenericRequestManager<>(
@@ -233,4 +90,12 @@ public class MistralChatCompletionActionTests extends ESTestCase {
         var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
         return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
     }
+
+    protected String getFailedToSendError() {
+        return "Failed to send mistral chat completions request. Cause: failed";
+    }
+
+    protected String getOneInputError() {
+        return "mistral chat completions only accepts 1 input";
+    }
 }