Browse Source

[Inference API] Add Azure OpenAI completion support (#108352)

Tim Grein 1 year ago
parent
commit
616e71963e
37 changed files with 2283 additions and 99 deletions
  1. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 6 0
      test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java
  3. 7 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java
  4. 3 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java
  5. 67 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionAction.java
  6. 58 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java
  7. 70 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java
  8. 64 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java
  9. 1 26
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java
  10. 35 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java
  11. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java
  12. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/XContentUtils.java
  13. 114 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntity.java
  14. 41 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java
  15. 29 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java
  16. 17 12
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
  17. 121 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java
  18. 38 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java
  19. 183 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java
  20. 105 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java
  21. 20 17
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java
  22. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java
  23. 208 15
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java
  24. 200 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java
  25. 62 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequestTests.java
  26. 45 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java
  27. 100 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java
  28. 2 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestEntityTests.java
  29. 33 20
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTests.java
  30. 18 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/XContentUtilsTests.java
  31. 220 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java
  32. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java
  33. 142 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java
  34. 45 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java
  35. 92 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java
  36. 99 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java
  37. 30 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -195,6 +195,7 @@ public class TransportVersions {
     public static final TransportVersion INDEXING_PRESSURE_REQUEST_REJECTIONS_COUNT = def(8_652_00_0);
     public static final TransportVersion ROLLUP_USAGE = def(8_653_00_0);
     public static final TransportVersion SECURITY_ROLE_DESCRIPTION = def(8_654_00_0);
+    public static final TransportVersion ML_INFERENCE_AZURE_OPENAI_COMPLETIONS = def(8_655_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 6 - 0
test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java

@@ -64,6 +64,7 @@ import org.elasticsearch.common.logging.HeaderWarningAppender;
 import org.elasticsearch.common.logging.LogConfigurator;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.common.lucene.Lucene;
+import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.time.DateUtils;
@@ -1058,6 +1059,11 @@ public abstract class ESTestCase extends LuceneTestCase {
         return RandomizedTest.randomAsciiOfLength(codeUnits);
     }
 
+    public static SecureString randomSecureStringOfLength(int codeUnits) {
+        var randomAlpha = randomAlphaOfLength(codeUnits);
+        return new SecureString(randomAlpha.toCharArray());
+    }
+
     public static String randomNullOrAlphaOfLength(int codeUnits) {
         return randomBoolean() ? null : randomAlphaOfLength(codeUnits);
     }

+ 7 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.action.azureopenai;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
 
 import java.util.Map;
@@ -32,4 +33,10 @@ public class AzureOpenAiActionCreator implements AzureOpenAiActionVisitor {
         var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, taskSettings);
         return new AzureOpenAiEmbeddingsAction(sender, overriddenModel, serviceComponents);
     }
+
+    @Override
+    public ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
+        var overriddenModel = AzureOpenAiCompletionModel.of(model, taskSettings);
+        return new AzureOpenAiCompletionAction(sender, overriddenModel, serviceComponents);
+    }
 }

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java

@@ -8,10 +8,13 @@
 package org.elasticsearch.xpack.inference.external.action.azureopenai;
 
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
 
 import java.util.Map;
 
 public interface AzureOpenAiActionVisitor {
     ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings);
+
+    ExecutableAction create(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings);
 }

+ 67 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionAction.java

@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+public class AzureOpenAiCompletionAction implements ExecutableAction {
+
+    private final String errorMessage;
+    private final AzureOpenAiCompletionRequestManager requestCreator;
+    private final Sender sender;
+
+    public AzureOpenAiCompletionAction(Sender sender, AzureOpenAiCompletionModel model, ServiceComponents serviceComponents) {
+        Objects.requireNonNull(serviceComponents);
+        Objects.requireNonNull(model);
+        this.sender = Objects.requireNonNull(sender);
+        this.requestCreator = new AzureOpenAiCompletionRequestManager(model, serviceComponents.threadPool());
+        this.errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Azure OpenAI completion");
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
+        if (inferenceInputs instanceof DocumentsOnlyInput == false) {
+            listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR));
+            return;
+        }
+
+        var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
+        if (docsOnlyInput.getInputs().size() > 1) {
+            listener.onFailure(new ElasticsearchStatusException("Azure OpenAI completion only accepts 1 input", RestStatus.BAD_REQUEST));
+            return;
+        }
+
+        try {
+            ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
+
+            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            listener.onFailure(e);
+        } catch (Exception e) {
+            listener.onFailure(createInternalServerError(e, errorMessage));
+        }
+    }
+}

+ 58 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.http.client.protocol.HttpClientContext;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequest;
+import org.elasticsearch.xpack.inference.external.response.azureopenai.AzureOpenAiCompletionResponseEntity;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+public class AzureOpenAiCompletionRequestManager extends AzureOpenAiRequestManager {
+
+    private static final Logger logger = LogManager.getLogger(AzureOpenAiCompletionRequestManager.class);
+
+    private static final ResponseHandler HANDLER = createCompletionHandler();
+
+    private final AzureOpenAiCompletionModel model;
+
+    private static ResponseHandler createCompletionHandler() {
+        return new AzureOpenAiResponseHandler("azure openai completion", AzureOpenAiCompletionResponseEntity::fromResponse);
+    }
+
+    public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, ThreadPool threadPool) {
+        super(threadPool, model);
+        this.model = Objects.requireNonNull(model);
+    }
+
+    @Override
+    public Runnable create(
+        @Nullable String query,
+        List<String> input,
+        RequestSender requestSender,
+        Supplier<Boolean> hasRequestCompletedFunction,
+        HttpClientContext context,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model);
+        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+    }
+
+}

+ 70 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequest.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Objects;
+
+public class AzureOpenAiCompletionRequest implements AzureOpenAiRequest {
+
+    private final List<String> input;
+
+    private final URI uri;
+
+    private final AzureOpenAiCompletionModel model;
+
+    public AzureOpenAiCompletionRequest(List<String> input, AzureOpenAiCompletionModel model) {
+        this.input = input;
+        this.model = Objects.requireNonNull(model);
+        this.uri = model.getUri();
+    }
+
+    @Override
+    public HttpRequest createHttpRequest() {
+        var httpPost = new HttpPost(uri);
+        var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user()));
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
+        httpPost.setEntity(byteEntity);
+
+        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings());
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return this.uri;
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public Request truncate() {
+        // No truncation for Azure OpenAI completion
+        return this;
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        // No truncation for Azure OpenAI completion
+        return null;
+    }
+}

+ 64 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiCompletionRequestEntity.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+public record AzureOpenAiCompletionRequestEntity(List<String> messages, @Nullable String user) implements ToXContentObject {
+
+    private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
+
+    private static final String MESSAGES_FIELD = "messages";
+
+    private static final String ROLE_FIELD = "role";
+
+    private static final String CONTENT_FIELD = "content";
+
+    private static final String USER_FIELD = "user";
+
+    public AzureOpenAiCompletionRequestEntity {
+        Objects.requireNonNull(messages);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.startArray(MESSAGES_FIELD);
+
+        {
+            for (String message : messages) {
+                builder.startObject();
+
+                {
+                    builder.field(ROLE_FIELD, USER_FIELD);
+                    builder.field(CONTENT_FIELD, message);
+                }
+
+                builder.endObject();
+            }
+        }
+
+        builder.endArray();
+
+        builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
+
+        if (Strings.isNullOrEmpty(user) == false) {
+            builder.field(USER_FIELD, user);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+}

+ 1 - 26
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java

@@ -7,13 +7,9 @@
 
 package org.elasticsearch.xpack.inference.external.request.azureopenai;
 
-import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.entity.ByteArrayEntity;
-import org.apache.http.message.BasicHeader;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
 import org.elasticsearch.xpack.inference.external.request.Request;
@@ -23,14 +19,7 @@ import java.net.URI;
 import java.nio.charset.StandardCharsets;
 import java.util.Objects;
 
-import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
-import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
-import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
-import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
-
 public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest {
-    private static final String MISSING_AUTHENTICATION_ERROR_MESSAGE =
-        "The request does not have any authentication methods set. One of [%s] or [%s] is required.";
 
     private final Truncator truncator;
     private final Truncator.TruncationResult truncationResult;
@@ -59,21 +48,7 @@ public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest {
         ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
         httpPost.setEntity(byteEntity);
 
-        httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
-
-        var entraId = model.getSecretSettings().entraId();
-        var apiKey = model.getSecretSettings().apiKey();
-
-        if (entraId != null && entraId.isEmpty() == false) {
-            httpPost.setHeader(createAuthBearerHeader(entraId));
-        } else if (apiKey != null && apiKey.isEmpty() == false) {
-            httpPost.setHeader(new BasicHeader(API_KEY_HEADER, apiKey.toString()));
-        } else {
-            // should never happen due to the checks on the secret settings, but just in case
-            ValidationException validationException = new ValidationException();
-            validationException.addValidationError(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID));
-            throw validationException;
-        }
+        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings());
 
         return new HttpRequest(httpPost, getInferenceEntityId());
     }

+ 35 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java

@@ -7,6 +7,40 @@
 
 package org.elasticsearch.xpack.inference.external.request.azureopenai;
 
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
 
-public interface AzureOpenAiRequest extends Request {}
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
+
+public interface AzureOpenAiRequest extends Request {
+
+    String MISSING_AUTHENTICATION_ERROR_MESSAGE =
+        "The request does not have any authentication methods set. One of [%s] or [%s] is required.";
+
+    static void decorateWithAuthHeader(HttpPost httpPost, AzureOpenAiSecretSettings secretSettings) {
+        httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
+
+        var entraId = secretSettings.entraId();
+        var apiKey = secretSettings.apiKey();
+
+        if (entraId != null && entraId.isEmpty() == false) {
+            httpPost.setHeader(createAuthBearerHeader(entraId));
+        } else if (apiKey != null && apiKey.isEmpty() == false) {
+            httpPost.setHeader(new BasicHeader(API_KEY_HEADER, apiKey.toString()));
+        } else {
+            // should never happen due to the checks on the secret settings, but just in case
+            ValidationException validationException = new ValidationException();
+            validationException.addValidationError(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID));
+            throw validationException;
+        }
+    }
+}

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java

@@ -13,6 +13,8 @@ public class AzureOpenAiUtils {
     public static final String OPENAI_PATH = "openai";
     public static final String DEPLOYMENTS_PATH = "deployments";
     public static final String EMBEDDINGS_PATH = "embeddings";
+    public static final String CHAT_PATH = "chat";
+    public static final String COMPLETIONS_PATH = "completions";
     public static final String API_VERSION_PARAMETER = "api-version";
     public static final String API_KEY_HEADER = "api-key";
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/XContentUtils.java

@@ -39,7 +39,7 @@ public class XContentUtils {
     public static void positionParserAtTokenAfterField(XContentParser parser, String field, String errorMsgTemplate) throws IOException {
         XContentParser.Token token = parser.nextToken();
 
-        while (token != null && token != XContentParser.Token.END_OBJECT) {
+        while (token != null) {
             if (token == XContentParser.Token.FIELD_NAME && parser.currentName().equals(field)) {
                 parser.nextToken();
                 return;

+ 114 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntity.java

@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureopenai;
+
+import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+
+public class AzureOpenAiCompletionResponseEntity {
+
+    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Azure OpenAI completions response";
+
+    /**
+     * Parses the Azure OpenAI completion response.
+     * For a request like:
+     *
+     * <pre>
+     *     <code>
+     *         {
+     *             "inputs": "Please summarize this text: some text"
+     *         }
+     *     </code>
+     * </pre>
+     *
+     * The response would look like:
+     *
+     * <pre>
+     *     <code>
+     *         {
+     *     "choices": [
+     *         {
+     *             "content_filter_results": {
+     *                 "hate": { ... },
+     *                 "self_harm": { ... },
+     *                 "sexual": { ... },
+     *                 "violence": { ... }
+     *             },
+     *             "finish_reason": "stop",
+     *             "index": 0,
+     *             "logprobs": null,
+     *             "message": {
+     *                 "content": "response",
+     *                 "role": "assistant"
+     *             }
+     *         }
+     *     ],
+     *     "created": 1714982782,
+     *     "id": "...",
+     *     "model": "gpt-4",
+     *     "object": "chat.completion",
+     *     "prompt_filter_results": [
+     *         {
+     *             "prompt_index": 0,
+     *             "content_filter_results": {
+     *                 "hate": { ... },
+     *                 "self_harm": { ... },
+     *                 "sexual": { ... },
+     *                 "violence": { ... }
+     *             }
+     *         }
+     *     ],
+     *     "system_fingerprint": null,
+     *     "usage": { ... }
+     * }
+     *     </code>
+     * </pre>
+     */
+    public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException {
+        var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
+        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
+            moveToFirstToken(jsonParser);
+
+            XContentParser.Token token = jsonParser.currentToken();
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+            positionParserAtTokenAfterField(jsonParser, "choices", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            jsonParser.nextToken();
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
+
+            positionParserAtTokenAfterField(jsonParser, "message", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            token = jsonParser.currentToken();
+
+            ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser);
+
+            positionParserAtTokenAfterField(jsonParser, "content", FAILED_TO_FIND_FIELD_TEMPLATE);
+
+            XContentParser.Token contentToken = jsonParser.currentToken();
+            ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser);
+            String content = jsonParser.text();
+
+            return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content)));
+        }
+    }
+
+}

+ 41 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.azureopenai;
 
+import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
@@ -14,11 +15,18 @@ import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
 
 import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
+import static org.elasticsearch.core.Strings.format;
+
 public abstract class AzureOpenAiModel extends Model {
 
     protected URI uri;
@@ -50,6 +58,30 @@ public abstract class AzureOpenAiModel extends Model {
 
     public abstract ExecutableAction accept(AzureOpenAiActionVisitor creator, Map<String, Object> taskSettings);
 
+    public final URI buildUriString() throws URISyntaxException {
+        return AzureOpenAiModel.buildUri(resourceName(), deploymentId(), apiVersion(), operationPathSegments());
+    }
+
+    // use only for testing directly
+    public static URI buildUri(String resourceName, String deploymentId, String apiVersion, String... pathSegments)
+        throws URISyntaxException {
+        String hostname = format("%s.%s", resourceName, AzureOpenAiUtils.HOST_SUFFIX);
+
+        return new URIBuilder().setScheme("https")
+            .setHost(hostname)
+            .setPathSegments(createPathSegmentsList(deploymentId, pathSegments))
+            .addParameter(AzureOpenAiUtils.API_VERSION_PARAMETER, apiVersion)
+            .build();
+    }
+
+    private static List<String> createPathSegmentsList(String deploymentId, String[] pathSegments) {
+        List<String> pathSegmentsList = new ArrayList<>(
+            List.of(AzureOpenAiUtils.OPENAI_PATH, AzureOpenAiUtils.DEPLOYMENTS_PATH, deploymentId)
+        );
+        pathSegmentsList.addAll(Arrays.asList(pathSegments));
+        return pathSegmentsList;
+    }
+
     public URI getUri() {
         return uri;
     }
@@ -62,4 +94,13 @@ public abstract class AzureOpenAiModel extends Model {
     public AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings() {
         return rateLimitServiceSettings;
     }
+
+    // TODO: can be inferred directly from modelConfigurations.getServiceSettings(); will be addressed with separate refactoring
+    public abstract String resourceName();
+
+    public abstract String deploymentId();
+
+    public abstract String apiVersion();
+
+    public abstract String[] operationPathSegments();
 }

+ 29 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java

@@ -25,12 +25,16 @@ import java.util.Objects;
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString;
 
-public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable SecureString entraId) implements SecretSettings {
+public class AzureOpenAiSecretSettings implements SecretSettings {
 
     public static final String NAME = "azure_openai_secret_settings";
     public static final String API_KEY = "api_key";
     public static final String ENTRA_ID = "entra_id";
 
+    private final SecureString entraId;
+
+    private final SecureString apiKey;
+
     public static AzureOpenAiSecretSettings fromMap(@Nullable Map<String, Object> map) {
         if (map == null) {
             return null;
@@ -59,14 +63,24 @@ public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable
         return new AzureOpenAiSecretSettings(secureApiToken, secureEntraId);
     }
 
-    public AzureOpenAiSecretSettings {
+    public AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable SecureString entraId) {
         Objects.requireNonNullElse(apiKey, entraId);
+        this.apiKey = apiKey;
+        this.entraId = entraId;
     }
 
     public AzureOpenAiSecretSettings(StreamInput in) throws IOException {
         this(in.readOptionalSecureString(), in.readOptionalSecureString());
     }
 
+    public SecureString apiKey() {
+        return apiKey;
+    }
+
+    public SecureString entraId() {
+        return entraId;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -98,4 +112,17 @@ public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable
         out.writeOptionalSecureString(apiKey);
         out.writeOptionalSecureString(entraId);
     }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        AzureOpenAiSecretSettings that = (AzureOpenAiSecretSettings) object;
+        return Objects.equals(entraId, that.entraId) && Objects.equals(apiKey, that.apiKey);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(entraId, apiKey);
+    }
 }

+ 17 - 12
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
 
@@ -121,19 +122,23 @@ public class AzureOpenAiService extends SenderService {
         String failureMessage,
         ConfigurationParseContext context
     ) {
-        if (taskType == TaskType.TEXT_EMBEDDING) {
-            return new AzureOpenAiEmbeddingsModel(
-                inferenceEntityId,
-                taskType,
-                NAME,
-                serviceSettings,
-                taskSettings,
-                secretSettings,
-                context
-            );
+        switch (taskType) {
+            case TEXT_EMBEDDING -> {
+                return new AzureOpenAiEmbeddingsModel(
+                    inferenceEntityId,
+                    taskType,
+                    NAME,
+                    serviceSettings,
+                    taskSettings,
+                    secretSettings,
+                    context
+                );
+            }
+            case COMPLETION -> {
+                return new AzureOpenAiCompletionModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
+            }
+            default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
         }
-
-        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
     }
 
     @Override

+ 121 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java

@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+
+import java.net.URISyntaxException;
+import java.util.Map;
+
+public class AzureOpenAiCompletionModel extends AzureOpenAiModel {
+
+    public static AzureOpenAiCompletionModel of(AzureOpenAiCompletionModel model, Map<String, Object> taskSettings) {
+        if (taskSettings == null || taskSettings.isEmpty()) {
+            return model;
+        }
+
+        var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap(taskSettings);
+        return new AzureOpenAiCompletionModel(model, AzureOpenAiCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
+    }
+
+    public AzureOpenAiCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings),
+            AzureOpenAiCompletionTaskSettings.fromMap(taskSettings),
+            AzureOpenAiSecretSettings.fromMap(secrets)
+        );
+    }
+
+    // Should only be used directly for testing
+    AzureOpenAiCompletionModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        AzureOpenAiCompletionServiceSettings serviceSettings,
+        AzureOpenAiCompletionTaskSettings taskSettings,
+        @Nullable AzureOpenAiSecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelSecrets(secrets),
+            serviceSettings
+        );
+        try {
+            this.uri = buildUriString();
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionServiceSettings serviceSettings) {
+        super(originalModel, serviceSettings);
+    }
+
+    private AzureOpenAiCompletionModel(AzureOpenAiCompletionModel originalModel, AzureOpenAiCompletionTaskSettings taskSettings) {
+        super(originalModel, taskSettings);
+    }
+
+    @Override
+    public AzureOpenAiCompletionServiceSettings getServiceSettings() {
+        return (AzureOpenAiCompletionServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AzureOpenAiCompletionTaskSettings getTaskSettings() {
+        return (AzureOpenAiCompletionTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public AzureOpenAiSecretSettings getSecretSettings() {
+        return (AzureOpenAiSecretSettings) super.getSecretSettings();
+    }
+
+    @Override
+    public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map<String, Object> taskSettings) {
+        return creator.create(this, taskSettings);
+    }
+
+    @Override
+    public String resourceName() {
+        return getServiceSettings().resourceName();
+    }
+
+    @Override
+    public String deploymentId() {
+        return getServiceSettings().deploymentId();
+    }
+
+    @Override
+    public String apiVersion() {
+        return getServiceSettings().apiVersion();
+    }
+
+    @Override
+    public String[] operationPathSegments() {
+        return new String[] { AzureOpenAiUtils.CHAT_PATH, AzureOpenAiUtils.COMPLETIONS_PATH };
+    }
+
+}

+ 38 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER;
+
+public record AzureOpenAiCompletionRequestTaskSettings(@Nullable String user) {
+
+    public static final AzureOpenAiCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiCompletionRequestTaskSettings(null);
+
+    public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiCompletionRequestTaskSettings(user);
+    }
+}

+ 183 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettings.java

@@ -0,0 +1,183 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiRateLimitServiceSettings;
+import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME;
+
+public class AzureOpenAiCompletionServiceSettings extends FilteredXContentObject
+    implements
+        ServiceSettings,
+        AzureOpenAiRateLimitServiceSettings {
+
+    public static final String NAME = "azure_openai_completions_service_settings";
+
+    /**
+     * Rate limit documentation can be found here:
+     *
+     * Limits per region per model id
+     * https://learn.microsoft.com/en-us/azure/ai-services/openai/quotas-limits
+     *
+     * How to change the limits
+     * https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/quota?tabs=rest
+     *
+     * Blog giving some examples
+     * https://techcommunity.microsoft.com/t5/fasttrack-for-azure/optimizing-azure-openai-a-guide-to-limits-quotas-and-best/ba-p/4076268
+     *
+     * According to the docs 1000 tokens per minute (TPM) = 6 requests per minute (RPM). The limits change depending on the region
+     * and model. The lowest chat completions limit is 20k TPM, so we'll default to that.
+     * Calculation: 20K TPM = 20 * 6 = 120 requests per minute (used `francecentral` and `gpt-4` as basis for the calculation).
+     */
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
+
+    public static AzureOpenAiCompletionServiceSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+
+        var settings = fromMap(map, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiCompletionServiceSettings(settings);
+    }
+
+    private static AzureOpenAiCompletionServiceSettings.CommonFields fromMap(
+        Map<String, Object> map,
+        ValidationException validationException
+    ) {
+        String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException);
+
+        return new AzureOpenAiCompletionServiceSettings.CommonFields(resourceName, deploymentId, apiVersion, rateLimitSettings);
+    }
+
+    private record CommonFields(String resourceName, String deploymentId, String apiVersion, RateLimitSettings rateLimitSettings) {}
+
+    private final String resourceName;
+    private final String deploymentId;
+    private final String apiVersion;
+
+    private final RateLimitSettings rateLimitSettings;
+
+    public AzureOpenAiCompletionServiceSettings(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        this.resourceName = resourceName;
+        this.deploymentId = deploymentId;
+        this.apiVersion = apiVersion;
+        this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+    }
+
+    public AzureOpenAiCompletionServiceSettings(StreamInput in) throws IOException {
+        resourceName = in.readString();
+        deploymentId = in.readString();
+        apiVersion = in.readString();
+        rateLimitSettings = new RateLimitSettings(in);
+    }
+
+    private AzureOpenAiCompletionServiceSettings(AzureOpenAiCompletionServiceSettings.CommonFields fields) {
+        this(fields.resourceName, fields.deploymentId, fields.apiVersion, fields.rateLimitSettings);
+    }
+
+    public String resourceName() {
+        return resourceName;
+    }
+
+    public String deploymentId() {
+        return deploymentId;
+    }
+
+    @Override
+    public RateLimitSettings rateLimitSettings() {
+        return DEFAULT_RATE_LIMIT_SETTINGS;
+    }
+
+    public String apiVersion() {
+        return apiVersion;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+        rateLimitSettings.toXContent(builder, params);
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        builder.field(RESOURCE_NAME, resourceName);
+        builder.field(DEPLOYMENT_ID, deploymentId);
+        builder.field(API_VERSION, apiVersion);
+
+        return builder;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_COMPLETIONS;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(resourceName);
+        out.writeString(deploymentId);
+        out.writeString(apiVersion);
+        rateLimitSettings.writeTo(out);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        AzureOpenAiCompletionServiceSettings that = (AzureOpenAiCompletionServiceSettings) object;
+        return Objects.equals(resourceName, that.resourceName)
+            && Objects.equals(deploymentId, that.deploymentId)
+            && Objects.equals(apiVersion, that.apiVersion)
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resourceName, deploymentId, apiVersion, rateLimitSettings);
+    }
+}

+ 105 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+
+public class AzureOpenAiCompletionTaskSettings implements TaskSettings {
+
+    public static final String NAME = "azure_openai_completion_task_settings";
+
+    public static final String USER = "user";
+
+    public static AzureOpenAiCompletionTaskSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiCompletionTaskSettings(user);
+    }
+
+    private final String user;
+
+    public static AzureOpenAiCompletionTaskSettings of(
+        AzureOpenAiCompletionTaskSettings originalSettings,
+        AzureOpenAiCompletionRequestTaskSettings requestSettings
+    ) {
+        var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
+        return new AzureOpenAiCompletionTaskSettings(userToUse);
+    }
+
+    public AzureOpenAiCompletionTaskSettings(@Nullable String user) {
+        this.user = user;
+    }
+
+    public AzureOpenAiCompletionTaskSettings(StreamInput in) throws IOException {
+        this.user = in.readOptionalString();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        {
+            if (user != null) {
+                builder.field(USER, user);
+            }
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    public String user() {
+        return user;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_COMPLETIONS;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(user);
+    }
+
+    @Override
+    public boolean equals(Object object) {
+        if (this == object) return true;
+        if (object == null || getClass() != object.getClass()) return false;
+        AzureOpenAiCompletionTaskSettings that = (AzureOpenAiCompletionTaskSettings) object;
+        return Objects.equals(user, that.user);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(user);
+    }
+}

+ 20 - 17
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java

@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
 
-import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
@@ -19,12 +18,9 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
 
-import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Map;
 
-import static org.elasticsearch.core.Strings.format;
-
 public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel {
 
     public static AzureOpenAiEmbeddingsModel of(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
@@ -70,7 +66,7 @@ public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel {
             serviceSettings
         );
         try {
-            this.uri = getEmbeddingsUri(serviceSettings.resourceName(), serviceSettings.deploymentId(), serviceSettings.apiVersion());
+            this.uri = buildUriString();
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
         }
@@ -104,17 +100,24 @@ public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel {
         return creator.create(this, taskSettings);
     }
 
-    public static URI getEmbeddingsUri(String resourceName, String deploymentId, String apiVersion) throws URISyntaxException {
-        String hostname = format("%s.%s", resourceName, AzureOpenAiUtils.HOST_SUFFIX);
-        return new URIBuilder().setScheme("https")
-            .setHost(hostname)
-            .setPathSegments(
-                AzureOpenAiUtils.OPENAI_PATH,
-                AzureOpenAiUtils.DEPLOYMENTS_PATH,
-                deploymentId,
-                AzureOpenAiUtils.EMBEDDINGS_PATH
-            )
-            .addParameter(AzureOpenAiUtils.API_VERSION_PARAMETER, apiVersion)
-            .build();
+    @Override
+    public String resourceName() {
+        return getServiceSettings().resourceName();
+    }
+
+    @Override
+    public String deploymentId() {
+        return getServiceSettings().deploymentId();
+    }
+
+    @Override
+    public String apiVersion() {
+        return getServiceSettings().apiVersion();
+    }
+
+    @Override
+    public String[] operationPathSegments() {
+        return new String[] { AzureOpenAiUtils.EMBEDDINGS_PATH };
     }
+
 }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java

@@ -63,7 +63,7 @@ public class AzureOpenAiEmbeddingsServiceSettings extends FilteredXContentObject
      *
      * According to the docs 1000 tokens per minute (TPM) = 6 requests per minute (RPM). The limits change depending on the region
      * and model. The lowest text embedding limit is 240K TPM, so we'll default to that.
-     * Calculation: 240K TPM = 240 * 6 = 1440 requests per minute
+     * Calculation: 240K TPM = 240 * 6 = 1440 requests per minute (used `eastus` and `Text-Embedding-Ada-002` as basis for the calculation).
      */
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_440);
 

+ 208 - 15
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -45,6 +46,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap;
 import static org.hamcrest.Matchers.equalTo;
@@ -54,6 +56,11 @@ import static org.mockito.Mockito.mock;
 
 public class AzureOpenAiActionCreatorTests extends ESTestCase {
     private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private static final Settings ZERO_TIMEOUT_SETTINGS = buildSettingsWithRetryFields(
+        TimeValue.timeValueMillis(1),
+        TimeValue.timeValueMinutes(1),
+        TimeValue.timeValueSeconds(0)
+    );
     private final MockWebServer webServer = new MockWebServer();
     private ThreadPool threadPool;
     private HttpClientManager clientManager;
@@ -116,7 +123,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
+            validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
         }
@@ -166,7 +173,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            validateRequestMapWithUser(requestMap, List.of("abc"), null);
+            validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), null);
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
         }
@@ -174,12 +181,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
 
     public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException {
         // timeout as zero for no retries
-        var settings = buildSettingsWithRetryFields(
-            TimeValue.timeValueMillis(1),
-            TimeValue.timeValueMinutes(1),
-            TimeValue.timeValueSeconds(0)
-        );
-        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
 
         try (var sender = senderFactory.createSender("test_service")) {
             sender.start();
@@ -226,7 +228,7 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
+            validateEmbeddingsRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
         }
@@ -295,13 +297,13 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
                 validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
                 var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-                validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
+                validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
             }
             {
                 validateRequestWithApiKey(webServer.requests().get(1), "apikey");
 
                 var requestMap = entityAsMap(webServer.requests().get(1).getBody());
-                validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
+                validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
             }
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
@@ -371,13 +373,13 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
                 validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
                 var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-                validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
+                validateEmbeddingsRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
             }
             {
                 validateRequestWithApiKey(webServer.requests().get(1), "apikey");
 
                 var requestMap = entityAsMap(webServer.requests().get(1).getBody());
-                validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
+                validateEmbeddingsRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
             }
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
@@ -429,13 +431,186 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
             validateRequestWithApiKey(webServer.requests().get(0), "apikey");
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            validateRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
+            validateEmbeddingsRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
         } catch (URISyntaxException e) {
             throw new RuntimeException(e);
         }
     }
 
-    private void validateRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
+    public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "choices": [
+                                {
+                                    "finish_reason": "stop",
+                                    "index": 0,
+                                    "logprobs": null,
+                                    "message": {
+                                        "content": "response",
+                                        "role": "assistant"
+                                        }
+                                    }
+                                ],
+                                "model": "gpt-4",
+                                "object": "chat.completion"
+                }""";
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var originalUser = "original_user";
+            var overriddenUser = "overridden_user";
+            var apiKey = "api_key";
+            var completionInput = "some input";
+
+            var model = createCompletionModel("resource", "deployment", "apiversion", originalUser, apiKey, null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var taskSettingsWithUserOverride = createRequestTaskSettingsMap(overriddenUser);
+            var action = (AzureOpenAiCompletionAction) actionCreator.create(model, taskSettingsWithUserOverride);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var request = webServer.requests().get(0);
+            var requestMap = entityAsMap(request.getBody());
+
+            assertThat(
+                result.asMap(),
+                is(Map.of(ChatCompletionResults.COMPLETION, List.of(Map.of(ChatCompletionResults.Result.RESULT, "response"))))
+            );
+            validateRequestWithApiKey(request, apiKey);
+            validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), overriddenUser);
+
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "choices": [
+                                {
+                                    "finish_reason": "stop",
+                                    "index": 0,
+                                    "logprobs": null,
+                                    "message": {
+                                        "content": "response",
+                                        "role": "assistant"
+                                        }
+                                    }
+                                ],
+                                "model": "gpt-4",
+                                "object": "chat.completion"
+                }""";
+
+            var completionInput = "some input";
+            var apiKey = "api key";
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createCompletionModel("resource", "deployment", "apiversion", null, apiKey, null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var requestTaskSettingsWithoutUser = createRequestTaskSettingsMap(null);
+            var action = (AzureOpenAiCompletionAction) actionCreator.create(model, requestTaskSettingsWithoutUser);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var request = webServer.requests().get(0);
+            var requestMap = entityAsMap(request.getBody());
+
+            assertThat(
+                result.asMap(),
+                is(Map.of(ChatCompletionResults.COMPLETION, List.of(Map.of(ChatCompletionResults.Result.RESULT, "response"))))
+            );
+            validateRequestWithApiKey(request, apiKey);
+            validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), null);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat() throws IOException {
+        // timeout as zero for no retries
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, ZERO_TIMEOUT_SETTINGS);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            // "choices" missing
+            String responseJson = """
+                {
+                    "not_choices": [
+                                   {
+                                    "finish_reason": "stop",
+                                    "index": 0,
+                                    "logprobs": null,
+                                    "message": {
+                                        "content": "response",
+                                        "role": "assistant"
+                                        }
+                                    }
+                                ],
+                                "model": "gpt-4",
+                                "object": "chat.completion"
+                }""";
+
+            var completionInput = "some input";
+            var apiKey = "api key";
+            var userOverride = "overridden_user";
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createCompletionModel("resource", "deployment", "apiversion", null, apiKey, null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var requestTaskSettingsWithoutUser = createRequestTaskSettingsMap(userOverride);
+            var action = (AzureOpenAiCompletionAction) actionCreator.create(model, requestTaskSettingsWithoutUser);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is(format("Failed to send Azure OpenAI completion request to [%s]", getUrl(webServer)))
+            );
+            assertThat(
+                thrownException.getCause().getMessage(),
+                is("Failed to find required field [choices] in Azure OpenAI completions response")
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+            validateRequestWithApiKey(webServer.requests().get(0), apiKey);
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            validateCompletionRequestMapWithUser(requestMap, List.of(completionInput), userOverride);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private void validateEmbeddingsRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
         var expectedSize = user == null ? 1 : 2;
 
         assertThat(requestMap.size(), is(expectedSize));
@@ -446,6 +621,24 @@ public class AzureOpenAiActionCreatorTests extends ESTestCase {
         }
     }
 
+    private void validateCompletionRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
+        assertThat("input for completions can only be of size 1", input.size(), equalTo(1));
+
+        var expectedSize = user == null ? 2 : 3;
+
+        assertThat(requestMap.size(), is(expectedSize));
+        assertThat(getContentOfMessageInRequestMap(requestMap), is(input.get(0)));
+
+        if (user != null) {
+            assertThat(requestMap.get("user"), is(user));
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    public static String getContentOfMessageInRequestMap(Map<String, Object> requestMap) {
+        return ((Map<String, Object>) ((List<Object>) requestMap.get("messages")).get(0)).get("content").toString();
+    }
+
     private void validateRequestWithApiKey(MockRequest request, String apiKey) {
         assertNull(request.getUri().getQuery());
         assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));

+ 200 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java

@@ -0,0 +1,200 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class AzureOpenAiCompletionActionTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                    "choices": [
+                                {
+                                    "finish_reason": "stop",
+                                    "index": 0,
+                                    "logprobs": null,
+                                    "message": {
+                                        "content": "response",
+                                        "role": "assistant"
+                                        }
+                                    }
+                                ],
+                                "model": "gpt-4",
+                                "object": "chat.completion"
+                                ]
+                }""";
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var user = "user";
+            var apiKey = "api_key";
+            var completionInput = "some input";
+
+            var action = createAction("resource", "deployment", "apiversion", user, apiKey, sender, "id");
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var request = webServer.requests().get(0);
+            assertNull(request.getUri().getQuery());
+            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType()));
+            assertThat(request.getHeader(AzureOpenAiUtils.API_KEY_HEADER), is(apiKey));
+
+            assertThat(
+                result.asMap(),
+                is(Map.of(ChatCompletionResults.COMPLETION, List.of(Map.of(ChatCompletionResults.Result.RESULT, "response"))))
+            );
+
+            var requestMap = entityAsMap(request.getBody());
+            assertThat(requestMap.size(), is(3));
+            assertThat(getContentOfMessageInRequestMap(requestMap), is(completionInput));
+            assertThat(requestMap.get("user"), is(user));
+            assertThat(requestMap.get("n"), is(1));
+        }
+    }
+
+    public void testExecute_ThrowsElasticsearchException() {
+        var sender = mock(Sender.class);
+        doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("failed"));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[1];
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI completion request to [%s]", getUrl(webServer))));
+    }
+
+    public void testExecute_ThrowsException() {
+        var sender = mock(Sender.class);
+        doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI completion request to [%s]", getUrl(webServer))));
+    }
+
+    private AzureOpenAiCompletionAction createAction(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable String user,
+        String apiKey,
+        Sender sender,
+        String inferenceEntityId
+    ) {
+        try {
+            var model = createCompletionModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId);
+            model.setUri(new URI(getUrl(webServer)));
+            return new AzureOpenAiCompletionAction(sender, model, createWithEmptySettings(threadPool));
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}

+ 62 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequestTests.java

@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiRequest.MISSING_AUTHENTICATION_ERROR_MESSAGE;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
+import static org.hamcrest.Matchers.equalTo;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AzureOpenAiRequestTests extends ESTestCase {
+
+    public void testDecorateWithAuthHeader_apiKeyPresent() {
+        var apiKey = randomSecureStringOfLength(10);
+        var httpPost = new HttpPost();
+        var secretSettings = new AzureOpenAiSecretSettings(apiKey, null);
+
+        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, secretSettings);
+        var apiKeyHeader = httpPost.getFirstHeader(API_KEY_HEADER);
+
+        assertThat(apiKeyHeader.getValue(), equalTo(apiKey.toString()));
+    }
+
+    public void testDecorateWithAuthHeader_entraIdPresent() {
+        var entraId = randomSecureStringOfLength(10);
+        var httpPost = new HttpPost();
+        var secretSettings = new AzureOpenAiSecretSettings(null, entraId);
+
+        AzureOpenAiRequest.decorateWithAuthHeader(httpPost, secretSettings);
+        var authHeader = httpPost.getFirstHeader(HttpHeaders.AUTHORIZATION);
+
+        assertThat(authHeader.getValue(), equalTo("Bearer " + entraId));
+    }
+
+    public void testDecorateWithAuthHeader_entraIdAndApiKeyMissing_throwMissingAuthValidationException() {
+        var httpPost = new HttpPost();
+        var secretSettingsMock = mock(AzureOpenAiSecretSettings.class);
+
+        when(secretSettingsMock.entraId()).thenReturn(null);
+        when(secretSettingsMock.apiKey()).thenReturn(null);
+
+        ValidationException exception = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiRequest.decorateWithAuthHeader(httpPost, secretSettingsMock)
+        );
+        assertTrue(exception.getMessage().contains(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID)));
+    }
+}

+ 45 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestEntityTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequestEntity;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class AzureOpenAiCompletionRequestEntityTests extends ESTestCase {
+
+    public void testXContent_WritesSingleMessage_DoesNotWriteUserWhenItIsNull() throws IOException {
+        var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"messages":[{"role":"user","content":"input"}],"n":1}"""));
+    }
+
+    public void testXContent_WritesSingleMessage_WriteUserWhenItIsNull() throws IOException {
+        var entity = new AzureOpenAiCompletionRequestEntity(List.of("input"), "user");
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"messages":[{"role":"user","content":"input"}],"n":1,"user":"user"}"""));
+    }
+}

+ 100 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/completion/AzureOpenAiCompletionRequestTests.java

@@ -0,0 +1,100 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai.completion;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiCompletionRequest;
+import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreatorTests.getContentOfMessageInRequestMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiCompletionRequestTests extends ESTestCase {
+
+    public void testCreateRequest_WithApiKeyDefined() throws IOException {
+        var input = "input";
+        var user = "user";
+        var apiKey = randomAlphaOfLength(10);
+
+        var request = createRequest("resource", "deployment", "2024", apiKey, null, input, user);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(
+            httpPost.getURI().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2024")
+        );
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(getContentOfMessageInRequestMap(requestMap), is(input));
+        assertThat(requestMap.get("user"), is(user));
+        assertThat(requestMap.get("n"), is(1));
+    }
+
+    public void testCreateRequest_WithEntraIdDefined() throws IOException {
+        var input = "input";
+        var user = "user";
+        var entraId = randomAlphaOfLength(10);
+
+        var request = createRequest("resource", "deployment", "2024", null, entraId, input, user);
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(
+            httpPost.getURI().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2024")
+        );
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + entraId));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(getContentOfMessageInRequestMap(requestMap), is(input));
+        assertThat(requestMap.get("user"), is(user));
+        assertThat(requestMap.get("n"), is(1));
+    }
+
+    protected AzureOpenAiCompletionRequest createRequest(
+        String resource,
+        String deployment,
+        String apiVersion,
+        String apiKey,
+        String entraId,
+        String input,
+        String user
+    ) {
+        var completionModel = AzureOpenAiCompletionModelTests.createCompletionModel(
+            resource,
+            deployment,
+            apiVersion,
+            user,
+            apiKey,
+            entraId,
+            "id"
+        );
+
+        return new AzureOpenAiCompletionRequest(List.of(input), completionModel);
+    }
+
+}

+ 2 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java → x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestEntityTests.java

@@ -5,13 +5,14 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.external.request.azureopenai;
+package org.elasticsearch.xpack.inference.external.request.azureopenai.embeddings;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequestEntity;
 
 import java.io.IOException;
 import java.util.List;

+ 33 - 20
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java → x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTests.java

@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.inference.external.request.azureopenai;
+package org.elasticsearch.xpack.inference.external.request.azureopenai.embeddings;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
@@ -14,56 +14,69 @@ import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.common.TruncatorTests;
-import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest;
 import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests;
 
 import java.io.IOException;
-import java.net.URISyntaxException;
 import java.util.List;
 
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
 import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 
 public class AzureOpenAiEmbeddingsRequestTests extends ESTestCase {
-    public void testCreateRequest_WithApiKeyDefined() throws IOException, URISyntaxException {
-        var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abc", "user");
+
+    public void testCreateRequest_WithApiKeyDefined() throws IOException {
+        var input = "input";
+        var user = "user";
+        var apiKey = randomAlphaOfLength(10);
+
+        var request = createRequest("resource", "deployment", "2024", apiKey, null, input, user);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
         var httpPost = (HttpPost) httpRequest.httpRequestBase();
 
-        var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString();
-        assertThat(httpPost.getURI().toString(), is(expectedUri));
+        assertThat(
+            httpPost.getURI().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/embeddings?api-version=2024")
+        );
 
         assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
-        assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is("apikey"));
+        assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is(apiKey));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(2));
-        assertThat(requestMap.get("input"), is(List.of("abc")));
-        assertThat(requestMap.get("user"), is("user"));
+        assertThat(requestMap.size(), equalTo(2));
+        assertThat(requestMap.get("input"), is(List.of(input)));
+        assertThat(requestMap.get("user"), is(user));
     }
 
-    public void testCreateRequest_WithEntraIdDefined() throws IOException, URISyntaxException {
-        var request = createRequest("resource", "deployment", "apiVersion", null, "entraId", "abc", "user");
+    public void testCreateRequest_WithEntraIdDefined() throws IOException {
+        var input = "input";
+        var user = "user";
+        var entraId = randomAlphaOfLength(10);
+
+        var request = createRequest("resource", "deployment", "2024", null, entraId, input, user);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
         var httpPost = (HttpPost) httpRequest.httpRequestBase();
 
-        var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString();
-        assertThat(httpPost.getURI().toString(), is(expectedUri));
+        assertThat(
+            httpPost.getURI().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/embeddings?api-version=2024")
+        );
 
         assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
-        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer entraId"));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + entraId));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
-        assertThat(requestMap, aMapWithSize(2));
-        assertThat(requestMap.get("input"), is(List.of("abc")));
-        assertThat(requestMap.get("user"), is("user"));
+        assertThat(requestMap.size(), equalTo(2));
+        assertThat(requestMap.get("input"), is(List.of(input)));
+        assertThat(requestMap.get("user"), is(user));
     }
 
     public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
@@ -87,7 +100,7 @@ public class AzureOpenAiEmbeddingsRequestTests extends ESTestCase {
         assertTrue(truncatedRequest.getTruncationInfo()[0]);
     }
 
-    public static AzureOpenAiEmbeddingsRequest createRequest(
+    public AzureOpenAiEmbeddingsRequest createRequest(
         String resourceName,
         String deploymentId,
         String apiVersion,

+ 18 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/XContentUtilsTests.java

@@ -106,6 +106,24 @@ public class XContentUtilsTests extends ESTestCase {
         }
     }
 
+    public void testPositionParserAtTokenAfterField_ConsumesUntilEnd() throws IOException {
+        var json = """
+            {
+              "key": {
+                "foo": "bar"
+              },
+              "target": "value"
+            }
+            """;
+
+        var errorFormat = "Error: %s";
+
+        try (XContentParser parser = createParser(XContentType.JSON.xContent(), json)) {
+            XContentUtils.positionParserAtTokenAfterField(parser, "target", errorFormat);
+            assertEquals("value", parser.text());
+        }
+    }
+
     public void testConsumeUntilObjectEnd() throws IOException {
         var json = """
             {

+ 220 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/azureopenai/AzureOpenAiCompletionResponseEntityTests.java

@@ -0,0 +1,220 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.response.azureopenai;
+
+import org.apache.http.HttpResponse;
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class AzureOpenAiCompletionResponseEntityTests extends ESTestCase {
+
+    public void testFromResponse_CreatesResultsForASingleItem() throws IOException {
+        String responseJson = """
+            {
+                 "choices": [
+                     {
+                         "content_filter_results": {
+                             "hate": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "self_harm": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "sexual": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "violence": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             }
+                         },
+                         "finish_reason": "stop",
+                         "index": 0,
+                         "logprobs": null,
+                         "message": {
+                             "content": "response",
+                             "role": "assistant"
+                         }
+                     }
+                 ],
+                 "model": "gpt-4",
+                 "object": "chat.completion",
+                 "prompt_filter_results": [
+                     {
+                         "prompt_index": 0,
+                         "content_filter_results": {
+                             "hate": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "self_harm": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "sexual": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             },
+                             "violence": {
+                                 "filtered": false,
+                                 "severity": "safe"
+                             }
+                         }
+                     }
+                 ],
+                 "usage": {
+                     "completion_tokens": 138,
+                     "prompt_tokens": 11,
+                     "total_tokens": 149
+                 }
+             }""";
+
+        ChatCompletionResults chatCompletionResults = AzureOpenAiCompletionResponseEntity.fromResponse(
+            mock(Request.class),
+            new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+        );
+
+        assertThat(chatCompletionResults.getResults().size(), equalTo(1));
+
+        ChatCompletionResults.Result result = chatCompletionResults.getResults().get(0);
+        assertThat(result.asMap().get(result.getResultsField()), is("response"));
+    }
+
+    public void testFromResponse_FailsWhenChoicesFieldIsNotPresent() {
+        String responseJson = """
+            {
+                "not_choices": [
+                            {
+                                "finish_reason": "stop",
+                                "index": 0,
+                                "logprobs": null,
+                                "message": {
+                                    "content": "response",
+                                    "role": "assistant"
+                                    }
+                                }
+                            ],
+                            "model": "gpt-4",
+                            "object": "chat.completion"
+            }""";
+
+        var thrownException = expectThrows(
+            IllegalStateException.class,
+            () -> AzureOpenAiCompletionResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        assertThat(thrownException.getMessage(), is("Failed to find required field [choices] in Azure OpenAI completions response"));
+    }
+
+    public void testFromResponse_FailsWhenChoicesFieldIsNotAnArray() {
+        String responseJson = """
+            {
+                "choices": {
+                                "finish_reason": "stop",
+                                "index": 0,
+                                "logprobs": null,
+                                "message": {
+                                    "content": "response",
+                                    "role": "assistant"
+                                    }
+                            },
+                            "model": "gpt-4",
+                            "object": "chat.completion"
+                            ]
+            }""";
+
+        var thrownException = expectThrows(
+            ParsingException.class,
+            () -> AzureOpenAiCompletionResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            is("Failed to parse object: expecting token of type [START_OBJECT] but found [FIELD_NAME]")
+        );
+    }
+
+    public void testFromResponse_FailsWhenMessageDoesNotExist() {
+        String responseJson = """
+            {
+                "choices": [
+                            {
+                                "finish_reason": "stop",
+                                "index": 0,
+                                "logprobs": null,
+                                "not_message": {
+                                    "content": "response",
+                                    "role": "assistant"
+                                    }
+                                }
+                            ],
+                            "model": "gpt-4",
+                            "object": "chat.completion"
+            }""";
+
+        var thrownException = expectThrows(
+            IllegalStateException.class,
+            () -> AzureOpenAiCompletionResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        assertThat(thrownException.getMessage(), is("Failed to find required field [message] in Azure OpenAI completions response"));
+    }
+
+    public void testFromResponse_FailsWhenMessageValueIsAString() {
+        String responseJson = """
+            {
+                "choices": [
+                            {
+                                "finish_reason": "stop",
+                                "index": 0,
+                                "logprobs": null,
+                                "message": "string"
+                                }
+                            ],
+                            "model": "gpt-4",
+                            "object": "chat.completion"
+                            ]
+            }""";
+
+        var thrownException = expectThrows(
+            ParsingException.class,
+            () -> AzureOpenAiCompletionResponseEntity.fromResponse(
+                mock(Request.class),
+                new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
+            )
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            is("Failed to parse object: expecting token of type [START_OBJECT] but found [VALUE_STRING]")
+        );
+    }
+
+}

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/openai/OpenAiChatCompletionResponseEntityTests.java

@@ -74,7 +74,7 @@ public class OpenAiChatCompletionResponseEntityTests extends ESTestCase {
                   },
                   "logprobs": null,
                   "finish_reason": "stop"
-                },
+                }
               ],
               "usage": {
                 "prompt_tokens": 46,
@@ -112,7 +112,7 @@ public class OpenAiChatCompletionResponseEntityTests extends ESTestCase {
                   },
                   "logprobs": null,
                   "finish_reason": "stop"
-                },
+                }
               },
               "usage": {
                 "prompt_tokens": 46,
@@ -153,7 +153,7 @@ public class OpenAiChatCompletionResponseEntityTests extends ESTestCase {
                   },
                   "logprobs": null,
                   "finish_reason": "stop"
-                },
+                }
               ],
               "usage": {
                 "prompt_tokens": 46,

+ 142 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java

@@ -0,0 +1,142 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class AzureOpenAiCompletionModelTests extends ESTestCase {
+
+    public void testOverrideWith_UpdatedTaskSettings_OverridesUser() {
+        var resource = "resource";
+        var deploymentId = "deployment";
+        var apiVersion = "api version";
+        var apiKey = "api key";
+        var entraId = "entra id";
+        var inferenceEntityId = "inference entity id";
+
+        var user = "user";
+        var userOverride = "user override";
+
+        var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId);
+        var requestTaskSettingsMap = taskSettingsMap(userOverride);
+        var overriddenModel = AzureOpenAiCompletionModel.of(model, requestTaskSettingsMap);
+
+        assertThat(
+            overriddenModel,
+            equalTo(createCompletionModel(resource, deploymentId, apiVersion, userOverride, apiKey, entraId, inferenceEntityId))
+        );
+    }
+
+    public void testOverrideWith_EmptyMap_OverridesNothing() {
+        var model = createCompletionModel("resource", "deployment", "api version", "user", "api key", "entra id", "inference entity id");
+        var requestTaskSettingsMap = Map.<String, Object>of();
+        var overriddenModel = AzureOpenAiCompletionModel.of(model, requestTaskSettingsMap);
+
+        assertThat(overriddenModel, sameInstance(model));
+    }
+
+    public void testOverrideWith_NullMap_OverridesNothing() {
+        var model = createCompletionModel("resource", "deployment", "api version", "user", "api key", "entra id", "inference entity id");
+        var overriddenModel = AzureOpenAiCompletionModel.of(model, null);
+
+        assertThat(overriddenModel, sameInstance(model));
+    }
+
+    public void testOverrideWith_UpdatedServiceSettings_OverridesApiVersion() {
+        var resource = "resource";
+        var deploymentId = "deployment";
+        var apiKey = "api key";
+        var user = "user";
+        var entraId = "entra id";
+        var inferenceEntityId = "inference entity id";
+
+        var apiVersion = "api version";
+        var updatedApiVersion = "updated api version";
+
+        var updatedServiceSettings = new AzureOpenAiCompletionServiceSettings(resource, deploymentId, updatedApiVersion, null);
+
+        var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId);
+        var overriddenModel = new AzureOpenAiCompletionModel(model, updatedServiceSettings);
+
+        assertThat(
+            overriddenModel,
+            is(createCompletionModel(resource, deploymentId, updatedApiVersion, user, apiKey, entraId, inferenceEntityId))
+        );
+    }
+
+    public void testBuildUriString() throws URISyntaxException {
+        var resource = "resource";
+        var deploymentId = "deployment";
+        var apiKey = "api key";
+        var user = "user";
+        var entraId = "entra id";
+        var inferenceEntityId = "inference entity id";
+        var apiVersion = "2024";
+
+        var model = createCompletionModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId);
+
+        assertThat(
+            model.buildUriString().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/chat/completions?api-version=2024")
+        );
+    }
+
+    public static AzureOpenAiCompletionModel createModelWithRandomValues() {
+        return createCompletionModel(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10)
+        );
+    }
+
+    public static AzureOpenAiCompletionModel createCompletionModel(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        String user,
+        @Nullable String apiKey,
+        @Nullable String entraId,
+        String inferenceEntityId
+    ) {
+        var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null;
+        var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null;
+
+        return new AzureOpenAiCompletionModel(
+            inferenceEntityId,
+            TaskType.COMPLETION,
+            "service",
+            new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null),
+            new AzureOpenAiCompletionTaskSettings(user),
+            new AzureOpenAiSecretSettings(secureApiKey, secureEntraId)
+        );
+    }
+
+    private Map<String, Object> taskSettingsMap(String user) {
+        Map<String, Object> taskSettingsMap = new HashMap<>();
+        taskSettingsMap.put(AzureOpenAiServiceFields.USER, user);
+        return taskSettingsMap;
+    }
+
+}

+ 45 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiCompletionRequestTaskSettingsTests extends ESTestCase {
+
+    public void testFromMap_ReturnsEmptySettings_WhenMapIsEmpty() {
+        var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS));
+    }
+
+    public void testFromMap_ReturnsEmptySettings_WhenMapDoesNotContainKnownFields() {
+        var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model")));
+        assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS));
+    }
+
+    public void testFromMap_ReturnsUser() {
+        var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+        assertThat(settings.user(), is("user"));
+    }
+
+    public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")))
+        );
+
+        assertThat(exception.getMessage(), containsString("[user] must be a non-empty string"));
+    }
+}

+ 92 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionServiceSettingsTests.java

@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiCompletionServiceSettings> {
+
+    private static AzureOpenAiCompletionServiceSettings createRandom() {
+        var resourceName = randomAlphaOfLength(8);
+        var deploymentId = randomAlphaOfLength(8);
+        var apiVersion = randomAlphaOfLength(8);
+
+        return new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null);
+    }
+
+    public void testFromMap_Request_CreatesSettingsCorrectly() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var serviceSettings = AzureOpenAiCompletionServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion
+                )
+            )
+        );
+
+        assertThat(serviceSettings, is(new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null)));
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"2024","rate_limit":{"requests_per_minute":120}}"""));
+    }
+
+    public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException {
+        var entity = new AzureOpenAiCompletionServiceSettings("resource", "deployment", "2024", null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        var filteredXContent = entity.getFilteredXContentObject();
+        filteredXContent.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"2024"}"""));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiCompletionServiceSettings> instanceReader() {
+        return AzureOpenAiCompletionServiceSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiCompletionServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureOpenAiCompletionServiceSettings mutateInstance(AzureOpenAiCompletionServiceSettings instance) throws IOException {
+        return createRandom();
+    }
+}

+ 99 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java

@@ -0,0 +1,99 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiCompletionTaskSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiCompletionTaskSettings> {
+
+    public static AzureOpenAiCompletionTaskSettings createRandomWithUser() {
+        return new AzureOpenAiCompletionTaskSettings(randomAlphaOfLength(15));
+    }
+
+    public static AzureOpenAiCompletionTaskSettings createRandom() {
+        var user = randomBoolean() ? randomAlphaOfLength(15) : null;
+        return new AzureOpenAiCompletionTaskSettings(user);
+    }
+
+    public void testFromMap_WithUser() {
+        var user = "user";
+
+        assertThat(
+            new AzureOpenAiCompletionTaskSettings(user),
+            is(AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user))))
+        );
+    }
+
+    public void testFromMap_UserIsEmptyString() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
+        );
+    }
+
+    public void testFromMap_MissingUser_DoesNotThrowException() {
+        var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(taskSettings.user());
+    }
+
+    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
+        var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+
+        var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of(
+            taskSettings,
+            AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS
+        );
+        assertThat(overriddenTaskSettings, is(taskSettings));
+    }
+
+    public void testOverrideWith_UsesOverriddenSettings() {
+        var user = "user";
+        var userOverride = "user override";
+
+        var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user)));
+
+        var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap(
+            new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, userOverride))
+        );
+
+        var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of(taskSettings, requestTaskSettings);
+        assertThat(overriddenTaskSettings, is(new AzureOpenAiCompletionTaskSettings(userOverride)));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiCompletionTaskSettings> instanceReader() {
+        return AzureOpenAiCompletionTaskSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiCompletionTaskSettings createTestInstance() {
+        return createRandomWithUser();
+    }
+
+    @Override
+    protected AzureOpenAiCompletionTaskSettings mutateInstance(AzureOpenAiCompletionTaskSettings instance) throws IOException {
+        return createRandomWithUser();
+    }
+}

+ 30 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
 
+import java.net.URISyntaxException;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap;
@@ -65,6 +66,35 @@ public class AzureOpenAiEmbeddingsModelTests extends ESTestCase {
         assertThat(overridenModel, is(createModel("resource", "deployment", "override_apiversion", "user", "api_key", null, "id")));
     }
 
+    public void testBuildUriString() throws URISyntaxException {
+        var resource = "resource";
+        var deploymentId = "deployment";
+        var apiKey = "api key";
+        var user = "user";
+        var entraId = "entra id";
+        var inferenceEntityId = "inference entity id";
+        var apiVersion = "2024";
+
+        var model = createModel(resource, deploymentId, apiVersion, user, apiKey, entraId, inferenceEntityId);
+
+        assertThat(
+            model.buildUriString().toString(),
+            is("https://resource.openai.azure.com/openai/deployments/deployment/embeddings?api-version=2024")
+        );
+    }
+
+    public static AzureOpenAiEmbeddingsModel createModelWithRandomValues() {
+        return createModel(
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10),
+            randomAlphaOfLength(10)
+        );
+    }
+
     public static AzureOpenAiEmbeddingsModel createModel(
         String resourceName,
         String deploymentId,