浏览代码

Add Azure OpenAI Embeddings Inference Service (#107178)

* initial start to Azure OpenAI Embeddings

* some cleanups; adding more tests; breaking

* cleanups; all test so far passing;

* cleanups; checkstyle; finish tests

* checkstyle cleanups; spotless apply

* remove String.format usage

* smoke tested and working; some cleanups

* cleanup unneeded comments

* cleanup wayward comment

* finalize tests; set model as URI holder

* fixups after rebase; notably add timeout param

* cleanups; remove AzureResponse in favour of OpenAI

* ensure dimensions_set_by_user cannot be in request

* move AzureOpenAiSecretSettings to azureopenai pkg

* fix lint

* add similarity for service settings; cleanups;

* allow request similarity;correct secret validation
Mark J. Hoy 1 年之前
父节点
当前提交
aedc07da4e
共有 35 个文件被更改,包括 4499 次插入2 次删除
  1. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 27 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
  3. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  4. 35 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java
  5. 17 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java
  6. 53 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java
  7. 40 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java
  8. 52 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java
  9. 63 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java
  10. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java
  11. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java
  12. 110 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java
  13. 49 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java
  14. 12 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java
  15. 20 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java
  16. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  17. 49 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java
  18. 101 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java
  19. 296 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
  20. 16 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java
  21. 116 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java
  22. 54 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java
  23. 282 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java
  24. 114 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java
  25. 454 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java
  26. 219 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java
  27. 88 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java
  28. 77 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java
  29. 118 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java
  30. 160 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java
  31. 1180 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
  32. 121 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java
  33. 56 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java
  34. 389 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java
  35. 107 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -172,6 +172,7 @@ public class TransportVersions {
     public static final TransportVersion ML_INFERENCE_RERANK_NEW_RESPONSE_FORMAT = def(8_631_00_0);
     public static final TransportVersion HIGHLIGHTERS_TAGS_ON_FIELD_LEVEL = def(8_632_00_0);
     public static final TransportVersion TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS = def(8_633_00_0);
+    public static final TransportVersion ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS = def(8_634_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 27 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings;
 import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings;
 import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
@@ -202,6 +205,30 @@ public class InferenceNamedWriteablesProvider {
             new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new)
         );
 
+        // Azure OpenAI
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                AzureOpenAiSecretSettings.class,
+                AzureOpenAiSecretSettings.NAME,
+                AzureOpenAiSecretSettings::new
+            )
+        );
+
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                AzureOpenAiEmbeddingsServiceSettings.NAME,
+                AzureOpenAiEmbeddingsServiceSettings::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                TaskSettings.class,
+                AzureOpenAiEmbeddingsTaskSettings.NAME,
+                AzureOpenAiEmbeddingsTaskSettings::new
+            )
+        );
+
         return namedWriteables;
     }
 }

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -56,6 +56,7 @@ import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
 import org.elasticsearch.xpack.inference.rest.RestInferenceAction;
 import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
 import org.elasticsearch.xpack.inference.services.cohere.CohereService;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
 import org.elasticsearch.xpack.inference.services.elser.ElserInternalService;
@@ -176,6 +177,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
             context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
             context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
             context -> new CohereService(httpFactory.get(), serviceComponents.get()),
+            context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()),
             ElasticsearchInternalService::new
         );
     }

+ 35 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type.
+ */
+public class AzureOpenAiActionCreator implements AzureOpenAiActionVisitor {
+    private final Sender sender;
+    private final ServiceComponents serviceComponents;
+
+    public AzureOpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
+        this.sender = Objects.requireNonNull(sender);
+        this.serviceComponents = Objects.requireNonNull(serviceComponents);
+    }
+
+    @Override
+    public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
+        var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, taskSettings);
+        return new AzureOpenAiEmbeddingsAction(sender, overriddenModel, serviceComponents);
+    }
+}

+ 17 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java

@@ -0,0 +1,17 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.Map;
+
+public interface AzureOpenAiActionVisitor {
+    ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings);
+}

+ 53 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiEmbeddingsExecutableRequestCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
+import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
+
+public class AzureOpenAiEmbeddingsAction implements ExecutableAction {
+
+    private final String errorMessage;
+    private final AzureOpenAiEmbeddingsExecutableRequestCreator requestCreator;
+    private final Sender sender;
+
+    public AzureOpenAiEmbeddingsAction(Sender sender, AzureOpenAiEmbeddingsModel model, ServiceComponents serviceComponents) {
+        Objects.requireNonNull(serviceComponents);
+        Objects.requireNonNull(model);
+        this.sender = Objects.requireNonNull(sender);
+        requestCreator = new AzureOpenAiEmbeddingsExecutableRequestCreator(model, serviceComponents.truncator());
+        errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Azure OpenAI embeddings");
+    }
+
+    @Override
+    public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
+        try {
+            ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
+
+            sender.send(requestCreator, inferenceInputs, timeout, wrappedListener);
+        } catch (ElasticsearchException e) {
+            listener.onFailure(e);
+        } catch (Exception e) {
+            listener.onFailure(createInternalServerError(e, errorMessage));
+        }
+    }
+}

+ 40 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.azureopenai;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.Objects;
+
+public record AzureOpenAiAccount(
+    String resourceName,
+    String deploymentId,
+    String apiVersion,
+    @Nullable SecureString apiKey,
+    @Nullable SecureString entraId
+) {
+
+    public AzureOpenAiAccount {
+        Objects.requireNonNull(resourceName);
+        Objects.requireNonNull(deploymentId);
+        Objects.requireNonNull(apiVersion);
+        Objects.requireNonNullElse(apiKey, entraId);
+    }
+
+    public static AzureOpenAiAccount fromModel(AzureOpenAiEmbeddingsModel model) {
+        return new AzureOpenAiAccount(
+            model.getServiceSettings().resourceName(),
+            model.getServiceSettings().deploymentId(),
+            model.getServiceSettings().apiVersion(),
+            model.getSecretSettings().apiKey(),
+            model.getSecretSettings().entraId()
+        );
+    }
+}

+ 52 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
+import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
+import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
+
+public class AzureOpenAiResponseHandler extends OpenAiResponseHandler {
+
+    /**
+     * These headers for Azure OpenAi are mostly the same as the OpenAi ones with the major exception
+     * that there is no information returned about the request limit or the tokens limit
+     *
+     * Microsoft does not seem to have any published information in their docs about this, but more
+     * information can be found in the following Medium article and accompanying code:
+     *   - https://pablo-81685.medium.com/azure-openais-api-headers-unpacked-6dbe881e732a
+     *   - https://github.com/pablosalvador10/gbbai-azure-ai-aoai
+     */
+    static final String REMAINING_REQUESTS = "x-ratelimit-remaining-requests";
+    // The remaining number of tokens that are permitted before exhausting the rate limit.
+    static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens";
+
+    public AzureOpenAiResponseHandler(String requestType, ResponseParser parseFunction) {
+        super(requestType, parseFunction);
+    }
+
+    @Override
+    protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
+        return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
+    }
+
+    static String buildRateLimitErrorMessage(HttpResult result) {
+        var response = result.response();
+        var remainingTokens = getFirstHeaderOrUnknown(response, REMAINING_TOKENS);
+        var remainingRequests = getFirstHeaderOrUnknown(response, REMAINING_REQUESTS);
+        var usageMessage = Strings.format("Remaining tokens [%s]. Remaining requests [%s].", remainingTokens, remainingRequests);
+
+        return RATE_LIMIT + ". " + usageMessage;
+    }
+
+}

+ 63 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java

@@ -0,0 +1,63 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.http.sender;
+
+import org.apache.http.client.protocol.HttpClientContext;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount;
+import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler;
+import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
+import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest;
+import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.inference.common.Truncator.truncate;
+
+public class AzureOpenAiEmbeddingsExecutableRequestCreator implements ExecutableRequestCreator {
+
+    private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsExecutableRequestCreator.class);
+
+    private static final ResponseHandler HANDLER = createEmbeddingsHandler();
+
+    private static ResponseHandler createEmbeddingsHandler() {
+        return new AzureOpenAiResponseHandler("azure openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse);
+    }
+
+    private final Truncator truncator;
+    private final AzureOpenAiEmbeddingsModel model;
+    private final AzureOpenAiAccount account;
+
+    public AzureOpenAiEmbeddingsExecutableRequestCreator(AzureOpenAiEmbeddingsModel model, Truncator truncator) {
+        this.model = Objects.requireNonNull(model);
+        this.account = AzureOpenAiAccount.fromModel(model);
+        this.truncator = Objects.requireNonNull(truncator);
+    }
+
+    @Override
+    public Runnable create(
+        String query,
+        List<String> input,
+        RequestSender requestSender,
+        Supplier<Boolean> hasRequestCompletedFunction,
+        HttpClientContext context,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens());
+        AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model);
+        return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener);
+    }
+}

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java

@@ -18,7 +18,7 @@ public class OpenAiChatCompletionResponseHandler extends OpenAiResponseHandler {
     }
 
     @Override
-    RetryException buildExceptionHandling429(Request request, HttpResult result) {
+    protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
         // We don't retry, if the chat completion input is too large
         return new RetryException(false, buildError(RATE_LIMIT, request, result));
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java

@@ -83,7 +83,7 @@ public class OpenAiResponseHandler extends BaseResponseHandler {
         }
     }
 
-    RetryException buildExceptionHandling429(Request request, HttpResult result) {
+    protected RetryException buildExceptionHandling429(Request request, HttpResult result) {
         return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result));
     }
 

+ 110 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java

@@ -0,0 +1,110 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.apache.http.entity.ByteArrayEntity;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount;
+import org.elasticsearch.xpack.inference.external.request.HttpRequest;
+import org.elasticsearch.xpack.inference.external.request.Request;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
+
+public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest {
+    private static final String MISSING_AUTHENTICATION_ERROR_MESSAGE =
+        "The request does not have any authentication methods set. One of [%s] or [%s] is required.";
+
+    private final Truncator truncator;
+    private final AzureOpenAiAccount account;
+    private final Truncator.TruncationResult truncationResult;
+    private final URI uri;
+    private final AzureOpenAiEmbeddingsModel model;
+
+    public AzureOpenAiEmbeddingsRequest(
+        Truncator truncator,
+        AzureOpenAiAccount account,
+        Truncator.TruncationResult input,
+        AzureOpenAiEmbeddingsModel model
+    ) {
+        this.truncator = Objects.requireNonNull(truncator);
+        this.account = Objects.requireNonNull(account);
+        this.truncationResult = Objects.requireNonNull(input);
+        this.model = Objects.requireNonNull(model);
+        this.uri = model.getUri();
+    }
+
+    public HttpRequest createHttpRequest() {
+        HttpPost httpPost = new HttpPost(uri);
+
+        String requestEntity = Strings.toString(
+            new AzureOpenAiEmbeddingsRequestEntity(
+                truncationResult.input(),
+                model.getTaskSettings().user(),
+                model.getServiceSettings().dimensions(),
+                model.getServiceSettings().dimensionsSetByUser()
+            )
+        );
+
+        ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
+        httpPost.setEntity(byteEntity);
+
+        httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));
+
+        var entraId = model.getSecretSettings().entraId();
+        var apiKey = model.getSecretSettings().apiKey();
+
+        if (entraId != null && entraId.isEmpty() == false) {
+            httpPost.setHeader(createAuthBearerHeader(entraId));
+        } else if (apiKey != null && apiKey.isEmpty() == false) {
+            httpPost.setHeader(new BasicHeader(API_KEY_HEADER, apiKey.toString()));
+        } else {
+            // should never happen due to the checks on the secret settings, but just in case
+            ValidationException validationException = new ValidationException();
+            validationException.addValidationError(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID));
+            throw validationException;
+        }
+
+        return new HttpRequest(httpPost, getInferenceEntityId());
+    }
+
+    @Override
+    public URI getURI() {
+        return this.uri;
+    }
+
+    @Override
+    public String getInferenceEntityId() {
+        return model.getInferenceEntityId();
+    }
+
+    @Override
+    public Request truncate() {
+        var truncatedInput = truncator.truncate(truncationResult.input());
+
+        return new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model);
+    }
+
+    @Override
+    public boolean[] getTruncationInfo() {
+        return truncationResult.truncated().clone();
+    }
+}

+ 49 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+public record AzureOpenAiEmbeddingsRequestEntity(
+    List<String> input,
+    @Nullable String user,
+    @Nullable Integer dimensions,
+    boolean dimensionsSetByUser
+) implements ToXContentObject {
+
+    private static final String INPUT_FIELD = "input";
+    private static final String USER_FIELD = "user";
+    private static final String DIMENSIONS_FIELD = "dimensions";
+
+    public AzureOpenAiEmbeddingsRequestEntity {
+        Objects.requireNonNull(input);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(INPUT_FIELD, input);
+
+        if (user != null) {
+            builder.field(USER_FIELD, user);
+        }
+
+        if (dimensionsSetByUser && dimensions != null) {
+            builder.field(DIMENSIONS_FIELD, dimensions);
+        }
+
+        builder.endObject();
+        return builder;
+    }
+}

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java

@@ -0,0 +1,12 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.elasticsearch.xpack.inference.external.request.Request;
+
+public interface AzureOpenAiRequest extends Request {}

+ 20 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java

@@ -0,0 +1,20 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+public class AzureOpenAiUtils {
+
+    public static final String HOST_SUFFIX = "openai.azure.com";
+    public static final String OPENAI_PATH = "openai";
+    public static final String DEPLOYMENTS_PATH = "deployments";
+    public static final String EMBEDDINGS_PATH = "embeddings";
+    public static final String API_VERSION_PARAMETER = "api-version";
+    public static final String API_KEY_HEADER = "api-key";
+
+    private AzureOpenAiUtils() {}
+}

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -139,6 +139,10 @@ public class ServiceUtils {
         );
     }
 
+    public static String invalidSettingError(String settingName, String scope) {
+        return Strings.format("[%s] does not allow the setting [%s]", scope, settingName);
+    }
+
     // TODO improve URI validation logic
     public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) {
         try {
@@ -186,6 +190,21 @@ public class ServiceUtils {
         return new SecureString(Objects.requireNonNull(requiredField).toCharArray());
     }
 
+    public static SecureString extractOptionalSecureString(
+        Map<String, Object> map,
+        String settingName,
+        String scope,
+        ValidationException validationException
+    ) {
+        String optionalField = extractOptionalString(map, settingName, scope, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false || optionalField == null) {
+            return null;
+        }
+
+        return new SecureString(optionalField.toCharArray());
+    }
+
     public static SimilarityMeasure extractSimilarity(Map<String, Object> map, String scope, ValidationException validationException) {
         return extractOptionalEnum(
             map,

+ 49 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
+
+import java.net.URI;
+import java.util.Map;
+
+public abstract class AzureOpenAiModel extends Model {
+
+    protected URI uri;
+
+    public AzureOpenAiModel(ModelConfigurations configurations, ModelSecrets secrets) {
+        super(configurations, secrets);
+    }
+
+    protected AzureOpenAiModel(AzureOpenAiModel model, TaskSettings taskSettings) {
+        super(model, taskSettings);
+        this.uri = model.getUri();
+    }
+
+    protected AzureOpenAiModel(AzureOpenAiModel model, ServiceSettings serviceSettings) {
+        super(model, serviceSettings);
+        this.uri = model.getUri();
+    }
+
+    public abstract ExecutableAction accept(AzureOpenAiActionVisitor creator, Map<String, Object> taskSettings);
+
+    public URI getUri() {
+        return uri;
+    }
+
+    // Needed for testing
+    public void setUri(URI newUri) {
+        this.uri = newUri;
+    }
+}

+ 101 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java

@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString;
+
+public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable SecureString entraId) implements SecretSettings {
+
+    public static final String NAME = "azure_openai_secret_settings";
+    public static final String API_KEY = "api_key";
+    public static final String ENTRA_ID = "entra_id";
+
+    public static AzureOpenAiSecretSettings fromMap(@Nullable Map<String, Object> map) {
+        if (map == null) {
+            return null;
+        }
+
+        ValidationException validationException = new ValidationException();
+        SecureString secureApiToken = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException);
+        SecureString secureEntraId = extractOptionalSecureString(map, ENTRA_ID, ModelSecrets.SECRET_SETTINGS, validationException);
+
+        if (secureApiToken == null && secureEntraId == null) {
+            validationException.addValidationError(
+                format("[secret_settings] must have either the [%s] or the [%s] key set", API_KEY, ENTRA_ID)
+            );
+        }
+
+        if (secureApiToken != null && secureEntraId != null) {
+            validationException.addValidationError(
+                format("[secret_settings] must have only one of the [%s] or the [%s] key set", API_KEY, ENTRA_ID)
+            );
+        }
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiSecretSettings(secureApiToken, secureEntraId);
+    }
+
+    public AzureOpenAiSecretSettings {
+        Objects.requireNonNullElse(apiKey, entraId);
+    }
+
+    public AzureOpenAiSecretSettings(StreamInput in) throws IOException {
+        this(in.readOptionalSecureString(), in.readOptionalSecureString());
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        if (apiKey != null) {
+            builder.field(API_KEY, apiKey.toString());
+        }
+
+        if (entraId != null) {
+            builder.field(ENTRA_ID, entraId.toString());
+        }
+
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalSecureString(apiKey);
+        out.writeOptionalSecureString(entraId);
+    }
+}

+ 296 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -0,0 +1,296 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.ServiceComponents;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+public class AzureOpenAiService extends SenderService {
+    public static final String NAME = "azureopenai";
+
+    public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
+        super(factory, serviceComponents);
+    }
+
+    @Override
+    public String name() {
+        return NAME;
+    }
+
+    @Override
+    public void parseRequestConfig(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Set<String> platformArchitectures,
+        ActionListener<Model> parsedModelListener
+    ) {
+        try {
+            Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+            Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+            AzureOpenAiModel model = createModel(
+                inferenceEntityId,
+                taskType,
+                serviceSettingsMap,
+                taskSettingsMap,
+                serviceSettingsMap,
+                TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
+                ConfigurationParseContext.REQUEST
+            );
+
+            throwIfNotEmptyMap(config, NAME);
+            throwIfNotEmptyMap(serviceSettingsMap, NAME);
+            throwIfNotEmptyMap(taskSettingsMap, NAME);
+
+            parsedModelListener.onResponse(model);
+        } catch (Exception e) {
+            parsedModelListener.onFailure(e);
+        }
+    }
+
+    private static AzureOpenAiModel createModelFromPersistent(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secretSettings,
+        String failureMessage
+    ) {
+        return createModel(
+            inferenceEntityId,
+            taskType,
+            serviceSettings,
+            taskSettings,
+            secretSettings,
+            failureMessage,
+            ConfigurationParseContext.PERSISTENT
+        );
+    }
+
+    private static AzureOpenAiModel createModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secretSettings,
+        String failureMessage,
+        ConfigurationParseContext context
+    ) {
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            return new AzureOpenAiEmbeddingsModel(
+                inferenceEntityId,
+                taskType,
+                NAME,
+                serviceSettings,
+                taskSettings,
+                secretSettings,
+                context
+            );
+        }
+
+        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
+    }
+
+    @Override
+    public AzureOpenAiModel parsePersistedConfigWithSecrets(
+        String inferenceEntityId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Map<String, Object> secrets
+    ) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
+        Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
+
+        return createModelFromPersistent(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            secretSettingsMap,
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+        );
+    }
+
+    @Override
+    public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
+        Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+        Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
+
+        return createModelFromPersistent(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            null,
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+        );
+    }
+
+    @Override
+    protected void doInfer(
+        Model model,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        if (model instanceof AzureOpenAiModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
+        var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
+
+        var action = azureOpenAiModel.accept(actionCreator, taskSettings);
+        action.execute(new DocumentsOnlyInput(input), timeout, listener);
+    }
+
+    @Override
+    protected void doInfer(
+        Model model,
+        String query,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        TimeValue timeout,
+        ActionListener<InferenceServiceResults> listener
+    ) {
+        throw new UnsupportedOperationException("Azure OpenAI service does not support inference with query input");
+    }
+
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        String query,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        ChunkingOptions chunkingOptions,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
+    ) {
+        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
+            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
+        );
+
+        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
+    }
+
+    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
+        List<String> inputs,
+        InferenceServiceResults inferenceResults
+    ) {
+        if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) {
+            return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults);
+        } else if (inferenceResults instanceof ErrorInferenceResults error) {
+            return List.of(new ErrorChunkedInferenceResults(error.getException()));
+        } else {
+            throw createInvalidChunkedResultException(inferenceResults.getWriteableName());
+        }
+    }
+
+    /**
+     * For text embedding models get the embedding size and
+     * update the service settings.
+     *
+     * @param model The new model
+     * @param listener The listener
+     */
+    @Override
+    public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) {
+            ServiceUtils.getEmbeddingSize(
+                model,
+                this,
+                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size)))
+            );
+        } else {
+            listener.onResponse(model);
+        }
+    }
+
+    private AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEmbeddingsModel model, int embeddingSize) {
+        if (model.getServiceSettings().dimensionsSetByUser()
+            && model.getServiceSettings().dimensions() != null
+            && model.getServiceSettings().dimensions() != embeddingSize) {
+            throw new ElasticsearchStatusException(
+                Strings.format(
+                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+                        + "Please recreate the [%s] configuration with the correct dimensions",
+                    embeddingSize,
+                    model.getServiceSettings().dimensions(),
+                    model.getConfigurations().getInferenceEntityId()
+                ),
+                RestStatus.BAD_REQUEST
+            );
+        }
+
+        var similarityFromModel = model.getServiceSettings().similarity();
+        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+        AzureOpenAiEmbeddingsServiceSettings serviceSettings = new AzureOpenAiEmbeddingsServiceSettings(
+            model.getServiceSettings().resourceName(),
+            model.getServiceSettings().deploymentId(),
+            model.getServiceSettings().apiVersion(),
+            embeddingSize,
+            model.getServiceSettings().dimensionsSetByUser(),
+            model.getServiceSettings().maxInputTokens(),
+            similarityToUse
+        );
+
+        return new AzureOpenAiEmbeddingsModel(model, serviceSettings);
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS;
+    }
+}

+ 16 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java

@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+public class AzureOpenAiServiceFields {
+
+    public static final String RESOURCE_NAME = "resource_name";
+    public static final String DEPLOYMENT_ID = "deployment_id";
+    public static final String API_VERSION = "api_version";
+    public static final String USER = "user";
+}

+ 116 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java

@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.apache.http.client.utils.URIBuilder;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
+import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Map;
+
+import static org.elasticsearch.core.Strings.format;
+
+public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel {
+
+    public static AzureOpenAiEmbeddingsModel of(AzureOpenAiEmbeddingsModel model, Map<String, Object> taskSettings) {
+        if (taskSettings == null || taskSettings.isEmpty()) {
+            return model;
+        }
+
+        var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings);
+        return new AzureOpenAiEmbeddingsModel(model, AzureOpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
+    }
+
+    public AzureOpenAiEmbeddingsModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceEntityId,
+            taskType,
+            service,
+            AzureOpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context),
+            AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings),
+            AzureOpenAiSecretSettings.fromMap(secrets)
+        );
+    }
+
+    // Should only be used directly for testing
+    AzureOpenAiEmbeddingsModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        AzureOpenAiEmbeddingsServiceSettings serviceSettings,
+        AzureOpenAiEmbeddingsTaskSettings taskSettings,
+        @Nullable AzureOpenAiSecretSettings secrets
+    ) {
+        super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets));
+        try {
+            this.uri = getEmbeddingsUri(serviceSettings.resourceName(), serviceSettings.deploymentId(), serviceSettings.apiVersion());
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsTaskSettings taskSettings) {
+        super(originalModel, taskSettings);
+    }
+
+    public AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsServiceSettings serviceSettings) {
+        super(originalModel, serviceSettings);
+    }
+
+    @Override
+    public AzureOpenAiEmbeddingsServiceSettings getServiceSettings() {
+        return (AzureOpenAiEmbeddingsServiceSettings) super.getServiceSettings();
+    }
+
+    @Override
+    public AzureOpenAiEmbeddingsTaskSettings getTaskSettings() {
+        return (AzureOpenAiEmbeddingsTaskSettings) super.getTaskSettings();
+    }
+
+    @Override
+    public AzureOpenAiSecretSettings getSecretSettings() {
+        return (AzureOpenAiSecretSettings) super.getSecretSettings();
+    }
+
+    @Override
+    public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map<String, Object> taskSettings) {
+        return creator.create(this, taskSettings);
+    }
+
+    public static URI getEmbeddingsUri(String resourceName, String deploymentId, String apiVersion) throws URISyntaxException {
+        String hostname = format("%s.%s", resourceName, AzureOpenAiUtils.HOST_SUFFIX);
+        return new URIBuilder().setScheme("https")
+            .setHost(hostname)
+            .setPathSegments(
+                AzureOpenAiUtils.OPENAI_PATH,
+                AzureOpenAiUtils.DEPLOYMENTS_PATH,
+                deploymentId,
+                AzureOpenAiUtils.EMBEDDINGS_PATH
+            )
+            .addParameter(AzureOpenAiUtils.API_VERSION_PARAMETER, apiVersion)
+            .build();
+    }
+}

+ 54 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER;
+
+/**
+ * This class handles extracting Azure OpenAI task settings from a request. The difference between this class and
+ * {@link AzureOpenAiEmbeddingsTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field
+ * is missing. This allows overriding persistent task settings.
+ * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse
+ */
+public record AzureOpenAiEmbeddingsRequestTaskSettings(@Nullable String user) {
+    private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsRequestTaskSettings.class);
+
+    public static final AzureOpenAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiEmbeddingsRequestTaskSettings(null);
+
+    /**
+     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
+     * does not throw an error.
+     *
+     * @param map the settings received from a request
+     * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings}
+     */
+    public static AzureOpenAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiEmbeddingsRequestTaskSettings(user);
+    }
+}

+ 282 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java

@@ -0,0 +1,282 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceUtils;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME;
+
+/**
+ * Defines the service settings for interacting with OpenAI's text embedding models.
+ */
+public class AzureOpenAiEmbeddingsServiceSettings implements ServiceSettings {
+
+    public static final String NAME = "azure_openai_embeddings_service_settings";
+
+    static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user";
+
+    public static AzureOpenAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
+        ValidationException validationException = new ValidationException();
+
+        var settings = fromMap(map, validationException, context);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiEmbeddingsServiceSettings(settings);
+    }
+
+    private static CommonFields fromMap(
+        Map<String, Object> map,
+        ValidationException validationException,
+        ConfigurationParseContext context
+    ) {
+        String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
+        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+        Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class);
+        SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
+        Boolean dimensionsSetByUser = extractOptionalBoolean(
+            map,
+            DIMENSIONS_SET_BY_USER,
+            ModelConfigurations.SERVICE_SETTINGS,
+            validationException
+        );
+
+        switch (context) {
+            case REQUEST -> {
+                if (dimensionsSetByUser != null) {
+                    validationException.addValidationError(
+                        ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
+                    );
+                }
+                dimensionsSetByUser = dims != null;
+            }
+            case PERSISTENT -> {
+                if (dimensionsSetByUser == null) {
+                    validationException.addValidationError(
+                        ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS)
+                    );
+                }
+            }
+        }
+
+        return new CommonFields(
+            resourceName,
+            deploymentId,
+            apiVersion,
+            dims,
+            Boolean.TRUE.equals(dimensionsSetByUser),
+            maxTokens,
+            similarity
+        );
+    }
+
+    private record CommonFields(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        Boolean dimensionsSetByUser,
+        @Nullable Integer maxInputTokens,
+        @Nullable SimilarityMeasure similarity
+    ) {}
+
+    private final String resourceName;
+    private final String deploymentId;
+    private final String apiVersion;
+    private final Integer dimensions;
+    private final Boolean dimensionsSetByUser;
+    private final Integer maxInputTokens;
+    private final SimilarityMeasure similarity;
+
+    public AzureOpenAiEmbeddingsServiceSettings(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        Boolean dimensionsSetByUser,
+        @Nullable Integer maxInputTokens,
+        @Nullable SimilarityMeasure similarity
+    ) {
+        this.resourceName = resourceName;
+        this.deploymentId = deploymentId;
+        this.apiVersion = apiVersion;
+        this.dimensions = dimensions;
+        this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser);
+        this.maxInputTokens = maxInputTokens;
+        this.similarity = similarity;
+    }
+
+    public AzureOpenAiEmbeddingsServiceSettings(StreamInput in) throws IOException {
+        resourceName = in.readString();
+        deploymentId = in.readString();
+        apiVersion = in.readString();
+        dimensions = in.readOptionalVInt();
+        dimensionsSetByUser = in.readBoolean();
+        maxInputTokens = in.readOptionalVInt();
+        similarity = in.readOptionalEnum(SimilarityMeasure.class);
+    }
+
+    private AzureOpenAiEmbeddingsServiceSettings(CommonFields fields) {
+        this(
+            fields.resourceName,
+            fields.deploymentId,
+            fields.apiVersion,
+            fields.dimensions,
+            fields.dimensionsSetByUser,
+            fields.maxInputTokens,
+            fields.similarity
+        );
+    }
+
+    public String resourceName() {
+        return resourceName;
+    }
+
+    public String deploymentId() {
+        return deploymentId;
+    }
+
+    public String apiVersion() {
+        return apiVersion;
+    }
+
+    @Override
+    public Integer dimensions() {
+        return dimensions;
+    }
+
+    public Boolean dimensionsSetByUser() {
+        return dimensionsSetByUser;
+    }
+
+    public Integer maxInputTokens() {
+        return maxInputTokens;
+    }
+
+    @Override
+    public SimilarityMeasure similarity() {
+        return similarity;
+    }
+
+    @Override
+    public DenseVectorFieldMapper.ElementType elementType() {
+        return DenseVectorFieldMapper.ElementType.FLOAT;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        toXContentFragmentOfExposedFields(builder, params);
+
+        builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser);
+
+        builder.endObject();
+        return builder;
+    }
+
+    private void toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
+        builder.field(RESOURCE_NAME, resourceName);
+        builder.field(DEPLOYMENT_ID, deploymentId);
+        builder.field(API_VERSION, apiVersion);
+
+        if (dimensions != null) {
+            builder.field(DIMENSIONS, dimensions);
+        }
+        if (maxInputTokens != null) {
+            builder.field(MAX_INPUT_TOKENS, maxInputTokens);
+        }
+        if (similarity != null) {
+            builder.field(SIMILARITY, similarity);
+        }
+    }
+
+    @Override
+    public ToXContentObject getFilteredXContentObject() {
+        return (builder, params) -> {
+            builder.startObject();
+
+            toXContentFragmentOfExposedFields(builder, params);
+
+            builder.endObject();
+            return builder;
+        };
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(resourceName);
+        out.writeString(deploymentId);
+        out.writeString(apiVersion);
+        out.writeOptionalVInt(dimensions);
+        out.writeBoolean(dimensionsSetByUser);
+        out.writeOptionalVInt(maxInputTokens);
+        out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion()));
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AzureOpenAiEmbeddingsServiceSettings that = (AzureOpenAiEmbeddingsServiceSettings) o;
+
+        return Objects.equals(resourceName, that.resourceName)
+            && Objects.equals(deploymentId, that.deploymentId)
+            && Objects.equals(apiVersion, that.apiVersion)
+            && Objects.equals(dimensions, that.dimensions)
+            && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser)
+            && Objects.equals(maxInputTokens, that.maxInputTokens)
+            && Objects.equals(similarity, that.similarity);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(resourceName, deploymentId, apiVersion, dimensions, dimensionsSetByUser, maxInputTokens, similarity);
+    }
+}

+ 114 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java

@@ -0,0 +1,114 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER;
+
+/**
+ * Defines the task settings for the openai service.
+ *
+ * User is an optional unique identifier representing the end-user, which can help OpenAI to monitor and detect abuse
+ *  <a href="https://platform.openai.com/docs/api-reference/embeddings/create">see the openai docs for more details</a>
+ */
+public class AzureOpenAiEmbeddingsTaskSettings implements TaskSettings {
+
+    public static final String NAME = "azure_openai_embeddings_task_settings";
+
+    public static AzureOpenAiEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiEmbeddingsTaskSettings(user);
+    }
+
+    /**
+     * Creates a new {@link AzureOpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones
+     * passed in via requestSettings if the fields are not null.
+     * @param originalSettings the original {@link AzureOpenAiEmbeddingsTaskSettings} from the inference entity configuration from storage
+     * @param requestSettings the {@link AzureOpenAiEmbeddingsTaskSettings} from the request
+     * @return a new {@link AzureOpenAiEmbeddingsTaskSettings}
+     */
+    public static AzureOpenAiEmbeddingsTaskSettings of(
+        AzureOpenAiEmbeddingsTaskSettings originalSettings,
+        AzureOpenAiEmbeddingsRequestTaskSettings requestSettings
+    ) {
+        var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
+        return new AzureOpenAiEmbeddingsTaskSettings(userToUse);
+    }
+
+    private final String user;
+
+    public AzureOpenAiEmbeddingsTaskSettings(@Nullable String user) {
+        this.user = user;
+    }
+
+    public AzureOpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
+        this.user = in.readOptionalString();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (user != null) {
+            builder.field(USER, user);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    public String user() {
+        return user;
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(user);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AzureOpenAiEmbeddingsTaskSettings that = (AzureOpenAiEmbeddingsTaskSettings) o;
+        return Objects.equals(user, that.user);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(user);
+    }
+}

+ 454 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java

@@ -0,0 +1,454 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockRequest;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+
+public class AzureOpenAiActionCreatorTests extends ESTestCase {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createModel("resource", "deployment", "apiversion", "orig_user", "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user");
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap(null);
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            validateRequestMapWithUser(requestMap, List.of("abc"), null);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException {
+        // timeout as zero for no retries
+        var settings = buildSettingsWithRetryFields(
+            TimeValue.timeValueMillis(1),
+            TimeValue.timeValueMinutes(1),
+            TimeValue.timeValueSeconds(0)
+        );
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data_does_not_exist": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user");
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer)))
+            );
+            assertThat(thrownException.getCause().getMessage(), is("Failed to find required field [data] in OpenAI embeddings response"));
+
+            assertThat(webServer.requests(), hasSize(1));
+            validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user");
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            // note - there is no complete documentation on Azure's error messages
+            // but this error and response has been verified manually via CURL
+            var contentTooLargeErrorMessage =
+                "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;"
+                    + "0 for the completion). Please reduce your prompt; or completion length.";
+
+            String responseJsonContentTooLarge = Strings.format("""
+                    {
+                        "error": {
+                            "message": "%s",
+                            "type": "invalid_request_error",
+                            "param": null,
+                            "code": null
+                        }
+                    }
+                """, contentTooLargeErrorMessage);
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge));
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user");
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(2));
+            {
+                validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+                var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+                validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
+            }
+            {
+                validateRequestWithApiKey(webServer.requests().get(1), "apikey");
+
+                var requestMap = entityAsMap(webServer.requests().get(1).getBody());
+                validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
+            }
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            // note - there is no complete documentation on Azure's error messages
+            // but this error and response has been verified manually via CURL
+            var contentTooLargeErrorMessage =
+                "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;"
+                    + "0 for the completion). Please reduce your prompt; or completion length.";
+
+            String responseJsonContentTooLarge = Strings.format("""
+                    {
+                        "error": {
+                            "message": "%s",
+                            "type": "invalid_request_error",
+                            "param": null,
+                            "code": null
+                        }
+                    }
+                """, contentTooLargeErrorMessage);
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge));
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user");
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(2));
+            {
+                validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+                var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+                validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user");
+            }
+            {
+                validateRequestWithApiKey(webServer.requests().get(1), "apikey");
+
+                var requestMap = entityAsMap(webServer.requests().get(1).getBody());
+                validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user");
+            }
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void testExecute_TruncatesInputBeforeSending() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            // truncated to 1 token = 3 characters
+            var model = createModel("resource", "deployment", "apiversion", null, false, 1, null, null, "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool));
+            var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user");
+            var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings);
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("super long input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            validateRequestWithApiKey(webServer.requests().get(0), "apikey");
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            validateRequestMapWithUser(requestMap, List.of("sup"), "overridden_user");
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private void validateRequestMapWithUser(Map<String, Object> requestMap, List<String> input, @Nullable String user) {
+        var expectedSize = user == null ? 1 : 2;
+
+        assertThat(requestMap.size(), is(expectedSize));
+        assertThat(requestMap.get("input"), is(input));
+
+        if (user != null) {
+            assertThat(requestMap.get("user"), is(user));
+        }
+    }
+
+    private void validateRequestWithApiKey(MockRequest request, String apiKey) {
+        assertNull(request.getUri().getQuery());
+        assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+        assertThat(request.getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo(apiKey));
+    }
+}

+ 219 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java

@@ -0,0 +1,219 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.action.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+public class AzureOpenAiEmbeddingsActionTests extends ESTestCase {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testExecute_ReturnsSuccessfulResponse() throws IOException {
+        var senderFactory = new HttpRequestSender.Factory(
+            ServiceComponentsTests.createWithEmptySettings(threadPool),
+            clientManager,
+            mockClusterServiceEmpty()
+        );
+
+        try (var sender = senderFactory.createSender("test_service")) {
+            sender.start();
+
+            String responseJson = """
+                {
+                  "object": "list",
+                  "data": [
+                      {
+                          "object": "embedding",
+                          "index": 0,
+                          "embedding": [
+                              0.0123,
+                              -0.0123
+                          ]
+                      }
+                  ],
+                  "model": "text-embedding-ada-002-v2",
+                  "usage": {
+                      "prompt_tokens": 8,
+                      "total_tokens": 8
+                  }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+            assertThat(webServer.requests().get(0).getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo("apikey"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), is(2));
+            assertThat(requestMap.get("input"), is(List.of("abc")));
+            assertThat(requestMap.get("user"), is("user"));
+        }
+    }
+
+    public void testExecute_ThrowsElasticsearchException() {
+        var sender = mock(Sender.class);
+        doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is("failed"));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[1];
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() {
+        var sender = mock(Sender.class);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[1];
+            listener.onFailure(new IllegalStateException("failed"));
+
+            return Void.TYPE;
+        }).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    public void testExecute_ThrowsException() {
+        var sender = mock(Sender.class);
+        doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any());
+
+        var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id");
+
+        PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+        action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
+
+        var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+
+        assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))));
+    }
+
+    private AzureOpenAiEmbeddingsAction createAction(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable String user,
+        String apiKey,
+        Sender sender,
+        String inferenceEntityId
+    ) {
+        AzureOpenAiEmbeddingsModel model = null;
+        try {
+            model = createModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId);
+            model.setUri(new URI(getUrl(webServer)));
+            var action = new AzureOpenAiEmbeddingsAction(sender, model, createWithEmptySettings(threadPool));
+            return action;
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+}

+ 88 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.azureopenai;
+
+import org.apache.http.HttpResponse;
+import org.apache.http.StatusLine;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.HttpResult;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AzureOpenAiResponseHandlerTests extends ESTestCase {
+
+    public void testBuildRateLimitErrorMessage() {
+        int statusCode = 429;
+        var statusLine = mock(StatusLine.class);
+        when(statusLine.getStatusCode()).thenReturn(statusCode);
+        var response = mock(HttpResponse.class);
+        when(response.getStatusLine()).thenReturn(statusLine);
+        var httpResult = new HttpResult(response, new byte[] {});
+
+        {
+            when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn(
+                new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999")
+            );
+            when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(
+                new BasicHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS, "99800")
+            );
+
+            var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult);
+            assertThat(error, containsString("Remaining tokens [99800]. Remaining requests [2999]"));
+        }
+
+        {
+            when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null);
+            var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult);
+            assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]"));
+        }
+
+        {
+            when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn(
+                new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999")
+            );
+            when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null);
+            var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult);
+            assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]"));
+        }
+    }
+
+    private static HttpResult createContentTooLargeResult(int statusCode) {
+        return createResult(
+            statusCode,
+            "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;"
+                + "0 for the completion). Please reduce your prompt; or completion length."
+        );
+    }
+
+    private static HttpResult createResult(int statusCode, String message) {
+        var statusLine = mock(StatusLine.class);
+        when(statusLine.getStatusCode()).thenReturn(statusCode);
+        var httpResponse = mock(HttpResponse.class);
+        when(httpResponse.getStatusLine()).thenReturn(statusLine);
+
+        String responseJson = Strings.format("""
+                {
+                    "error": {
+                        "message": "%s",
+                        "type": "content_too_large",
+                        "param": null,
+                        "code": null
+                    }
+                }
+            """, message);
+
+        return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8));
+    }
+}

+ 77 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.CoreMatchers.is;
+
+public class AzureOpenAiEmbeddingsRequestEntityTests extends ESTestCase {
+
+    public void testXContent_WritesUserWhenDefined() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), "testuser", null, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"],"user":"testuser"}"""));
+    }
+
+    public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, false);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, true);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"]}"""));
+    }
+
+    public void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, true);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is("""
+            {"input":["abc"],"dimensions":100}"""));
+    }
+}

+ 118 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java

@@ -0,0 +1,118 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.common.Truncator;
+import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.List;
+
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.hamcrest.Matchers.aMapWithSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsRequestTests extends ESTestCase {
+    public void testCreateRequest_WithApiKeyDefined() throws IOException, URISyntaxException {
+        var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abc", "user");
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString();
+        assertThat(httpPost.getURI().toString(), is(expectedUri));
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is("apikey"));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("input"), is(List.of("abc")));
+        assertThat(requestMap.get("user"), is("user"));
+    }
+
+    public void testCreateRequest_WithEntraIdDefined() throws IOException, URISyntaxException {
+        var request = createRequest("resource", "deployment", "apiVersion", null, "entraId", "abc", "user");
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString();
+        assertThat(httpPost.getURI().toString(), is(expectedUri));
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer entraId"));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(2));
+        assertThat(requestMap.get("input"), is(List.of("abc")));
+        assertThat(requestMap.get("user"), is("user"));
+    }
+
+    public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
+        var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null);
+        var truncatedRequest = request.truncate();
+
+        var httpRequest = truncatedRequest.createHttpRequest();
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+        var requestMap = entityAsMap(httpPost.getEntity().getContent());
+        assertThat(requestMap, aMapWithSize(1));
+        assertThat(requestMap.get("input"), is(List.of("ab")));
+    }
+
+    public void testIsTruncated_ReturnsTrue() {
+        var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null);
+        assertFalse(request.getTruncationInfo()[0]);
+
+        var truncatedRequest = request.truncate();
+        assertTrue(truncatedRequest.getTruncationInfo()[0]);
+    }
+
+    public static AzureOpenAiEmbeddingsRequest createRequest(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable String apiKey,
+        @Nullable String entraId,
+        String input,
+        @Nullable String user
+    ) {
+        var embeddingsModel = AzureOpenAiEmbeddingsModelTests.createModel(
+            resourceName,
+            deploymentId,
+            apiVersion,
+            user,
+            apiKey,
+            entraId,
+            "id"
+        );
+        var account = AzureOpenAiAccount.fromModel(embeddingsModel);
+
+        return new AzureOpenAiEmbeddingsRequest(
+            TruncatorTests.createTruncator(),
+            account,
+            new Truncator.TruncationResult(List.of(input), new boolean[] { false }),
+            embeddingsModel
+        );
+    }
+}

+ 160 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java

@@ -0,0 +1,160 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.hamcrest.CoreMatchers;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiSecretSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiSecretSettings> {
+
+    public static AzureOpenAiSecretSettings createRandom() {
+        return new AzureOpenAiSecretSettings(
+            new SecureString(randomAlphaOfLength(15).toCharArray()),
+            new SecureString(randomAlphaOfLength(15).toCharArray())
+        );
+    }
+
+    public void testFromMap_ApiKey_Only() {
+        var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, "abc")));
+        assertThat(new AzureOpenAiSecretSettings(new SecureString("abc".toCharArray()), null), is(serviceSettings));
+    }
+
+    public void testFromMap_EntraId_Only() {
+        var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, "xyz")));
+        assertThat(new AzureOpenAiSecretSettings(null, new SecureString("xyz".toCharArray())), is(serviceSettings));
+    }
+
+    public void testFromMap_ReturnsNull_WhenMapIsNull() {
+        assertNull(AzureOpenAiSecretSettings.fromMap(null));
+    }
+
+    public void testFromMap_MissingApiKeyAndEntraId_ThrowsError() {
+        var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>()));
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] must have either the [%s] or the [%s] key set",
+                    AzureOpenAiSecretSettings.API_KEY,
+                    ENTRA_ID
+                )
+            )
+        );
+    }
+
+    public void testFromMap_HasBothApiKeyAndEntraId_ThrowsError() {
+        var mapValues = getAzureOpenAiSecretSettingsMap("apikey", "entraid");
+        var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(mapValues));
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] must have only one of the [%s] or the [%s] key set",
+                    AzureOpenAiSecretSettings.API_KEY,
+                    ENTRA_ID
+                )
+            )
+        );
+    }
+
+    public void testFromMap_EmptyApiKey_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, "")))
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "[secret_settings] Invalid value empty string. [%s] must be a non-empty string",
+                    AzureOpenAiSecretSettings.API_KEY
+                )
+            )
+        );
+    }
+
+    public void testFromMap_EmptyEntraId_ThrowsError() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, "")))
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(Strings.format("[secret_settings] Invalid value empty string. [%s] must be a non-empty string", ENTRA_ID))
+        );
+    }
+
+    // test toXContent
+    public void testToXContext_WritesApiKeyOnlyWhenEntraIdIsNull() throws IOException {
+        var testSettings = new AzureOpenAiSecretSettings(new SecureString("apikey"), null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        testSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expectedResult = Strings.format("{\"%s\":\"apikey\"}", API_KEY);
+        assertThat(xContentResult, CoreMatchers.is(expectedResult));
+    }
+
+    public void testToXContext_WritesEntraIdOnlyWhenApiKeyIsNull() throws IOException {
+        var testSettings = new AzureOpenAiSecretSettings(null, new SecureString("entraid"));
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        testSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expectedResult = Strings.format("{\"%s\":\"entraid\"}", ENTRA_ID);
+        assertThat(xContentResult, CoreMatchers.is(expectedResult));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiSecretSettings> instanceReader() {
+        return AzureOpenAiSecretSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiSecretSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureOpenAiSecretSettings mutateInstance(AzureOpenAiSecretSettings instance) throws IOException {
+        return createRandom();
+    }
+
+    public static Map<String, Object> getAzureOpenAiSecretSettingsMap(@Nullable String apiKey, @Nullable String entraId) {
+        var map = new HashMap<String, Object>();
+        if (apiKey != null) {
+            map.put(AzureOpenAiSecretSettings.API_KEY, apiKey);
+        }
+        if (entraId != null) {
+            map.put(ENTRA_ID, entraId);
+        }
+        return map;
+    }
+}

+ 1180 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

@@ -0,0 +1,1180 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ *
+ * this file was contributed to by a generative AI
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel;
+import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER;
+import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays;
+import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
+import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getRequestAzureOpenAiServiceSettingsMap;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+public class AzureOpenAiServiceTests extends ESTestCase {
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+                var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+                assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+                assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+                assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+                assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+                assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            }, exception -> fail("Unexpected exception: " + exception));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                getRequestConfigMap(
+                    getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                    getAzureOpenAiRequestTaskSettingsMap("user"),
+                    getAzureOpenAiSecretSettingsMap("secret", null)
+                ),
+                Set.of(),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(exception.getMessage(), is("The [azureopenai] service does not support task type [sparse_embedding]"));
+                }
+            );
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.SPARSE_EMBEDDING,
+                getRequestConfigMap(
+                    getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                    getAzureOpenAiRequestTaskSettingsMap("user"),
+                    getAzureOpenAiSecretSettingsMap("secret", null)
+                ),
+                Set.of(),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var config = getRequestConfigMap(
+                getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+            config.put("extra_key", "value");
+
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(
+                model -> fail("Expected exception, but got model: " + model),
+                exception -> {
+                    assertThat(exception, instanceOf(ElasticsearchStatusException.class));
+                    assertThat(
+                        exception.getMessage(),
+                        is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")
+                    );
+                }
+            );
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var serviceSettings = getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null);
+            serviceSettings.put("extra_key", "value");
+
+            var config = getRequestConfigMap(
+                serviceSettings,
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
+                fail("Expected exception, but got model: " + model);
+            }, e -> {
+                assertThat(e, instanceOf(ElasticsearchStatusException.class));
+                assertThat(
+                    e.getMessage(),
+                    is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")
+                );
+            });
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user");
+            taskSettingsMap.put("extra_key", "value");
+
+            var config = getRequestConfigMap(
+                getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                taskSettingsMap,
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
+                fail("Expected exception, but got model: " + model);
+            }, e -> {
+                assertThat(e, instanceOf(ElasticsearchStatusException.class));
+                assertThat(
+                    e.getMessage(),
+                    is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")
+                );
+            });
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null);
+            secretSettingsMap.put("extra_key", "value");
+
+            var config = getRequestConfigMap(
+                getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                secretSettingsMap
+            );
+
+            ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
+                fail("Expected exception, but got model: " + model);
+            }, e -> {
+                assertThat(e, instanceOf(ElasticsearchStatusException.class));
+                assertThat(
+                    e.getMessage(),
+                    is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")
+                );
+            });
+
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener);
+        }
+    }
+
+    public void testParseRequestConfig_MovesModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+                var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+                assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+                assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+                assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+                assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+                assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            }, exception -> fail("Unexpected exception: " + exception));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                getRequestConfigMap(
+                    getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                    getAzureOpenAiRequestTaskSettingsMap("user"),
+                    getAzureOpenAiSecretSettingsMap("secret", null)
+                ),
+                Set.of(),
+                modelVerificationListener
+            );
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            var thrownException = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> service.parsePersistedConfigWithSecrets(
+                    "id",
+                    TaskType.SPARSE_EMBEDDING,
+                    persistedConfig.config(),
+                    persistedConfig.secrets()
+                )
+            );
+
+            assertThat(
+                thrownException.getMessage(),
+                is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again")
+            );
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+            persistedConfig.config().put("extra_key", "value");
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null);
+            secretSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                secretSettingsMap
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+            persistedConfig.secrets.put("extra_key", "value");
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512);
+            serviceSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(
+                serviceSettingsMap,
+                getAzureOpenAiRequestTaskSettingsMap("user"),
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user");
+            taskSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512),
+                taskSettingsMap,
+                getAzureOpenAiSecretSettingsMap("secret", null)
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100));
+            assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
+        }
+    }
+
+    public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user")
+            );
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertNull(embeddingsModel.getSecretSettings());
+        }
+    }
+
+    public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user")
+            );
+
+            var thrownException = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config())
+            );
+
+            assertThat(
+                thrownException.getMessage(),
+                is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again")
+            );
+        }
+    }
+
+    public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                getAzureOpenAiRequestTaskSettingsMap("user")
+            );
+            persistedConfig.config().put("extra_key", "value");
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertNull(embeddingsModel.getSecretSettings());
+        }
+    }
+
+    public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap(
+                "resource_name",
+                "deployment_id",
+                "api_version",
+                null,
+                null
+            );
+            serviceSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getAzureOpenAiRequestTaskSettingsMap("user"));
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertNull(embeddingsModel.getSecretSettings());
+        }
+    }
+
+    public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
+        try (var service = createAzureOpenAiService()) {
+            var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user");
+            taskSettingsMap.put("extra_key", "value");
+
+            var persistedConfig = getPersistedConfigMap(
+                getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null),
+                taskSettingsMap
+            );
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class));
+
+            var embeddingsModel = (AzureOpenAiEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name"));
+            assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id"));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version"));
+            assertThat(embeddingsModel.getTaskSettings().user(), is("user"));
+            assertNull(embeddingsModel.getSecretSettings());
+        }
+    }
+
+    public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOException {
+        var sender = mock(Sender.class);
+
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender(anyString())).thenReturn(sender);
+
+        var mockModel = getInvalidModel("model_id", "service_name");
+
+        try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) {
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                mockModel,
+                null,
+                List.of(""),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                thrownException.getMessage(),
+                is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.")
+            );
+
+            verify(factory, times(1)).createSender(anyString());
+            verify(sender, times(1)).start();
+        }
+
+        verify(sender, times(1)).close();
+        verifyNoMoreInteractions(factory);
+        verifyNoMoreInteractions(sender);
+    }
+
+    public void testInfer_SendsRequest() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                null,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT);
+
+            assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F)))));
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+            assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), Matchers.is(2));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("user"), Matchers.is("user"));
+        }
+    }
+
+    public void testCheckModelConfig_IncludesMaxTokens() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                "resource",
+                "deployment",
+                "apiversion",
+                null,
+                false,
+                100,
+                null,
+                "user",
+                "apikey",
+                null,
+                "id"
+            );
+            model.setUri(new URI(getUrl(webServer)));
+
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureOpenAiEmbeddingsModelTests.createModel(
+                        "resource",
+                        "deployment",
+                        "apiversion",
+                        2,
+                        false,
+                        100,
+                        SimilarityMeasure.DOT_PRODUCT,
+                        "user",
+                        "apikey",
+                        null,
+                        "id"
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user")));
+        }
+    }
+
+    public void testCheckModelConfig_HasSimilarity() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                "resource",
+                "deployment",
+                "apiversion",
+                null,
+                false,
+                null,
+                SimilarityMeasure.COSINE,
+                "user",
+                "apikey",
+                null,
+                "id"
+            );
+            model.setUri(new URI(getUrl(webServer)));
+
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureOpenAiEmbeddingsModelTests.createModel(
+                        "resource",
+                        "deployment",
+                        "apiversion",
+                        2,
+                        false,
+                        null,
+                        SimilarityMeasure.COSINE,
+                        "user",
+                        "apikey",
+                        null,
+                        "id"
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user")));
+        }
+    }
+
+    public void testCheckModelConfig_AddsDefaultSimilarityDotProduct() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                "resource",
+                "deployment",
+                "apiversion",
+                null,
+                false,
+                null,
+                null,
+                "user",
+                "apikey",
+                null,
+                "id"
+            );
+            model.setUri(new URI(getUrl(webServer)));
+
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureOpenAiEmbeddingsModelTests.createModel(
+                        "resource",
+                        "deployment",
+                        "apiversion",
+                        2,
+                        false,
+                        null,
+                        SimilarityMeasure.DOT_PRODUCT,
+                        "user",
+                        "apikey",
+                        null,
+                        "id"
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user")));
+        }
+    }
+
+    public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                "resource",
+                "deployment",
+                "apiversion",
+                3,
+                true,
+                100,
+                null,
+                "user",
+                "apikey",
+                null,
+                "id"
+            );
+            model.setUri(new URI(getUrl(webServer)));
+
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(
+                exception.getMessage(),
+                is(
+                    "The retrieved embeddings size [2] does not match the size specified in the settings [3]. "
+                        + "Please recreate the [id] configuration with the correct dimensions"
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user", "dimensions", 3)));
+        }
+    }
+
+    public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException,
+        URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel(
+                "resource",
+                "deployment",
+                "apiversion",
+                100,
+                false,
+                100,
+                null,
+                "user",
+                "apikey",
+                null,
+                "id"
+            );
+            model.setUri(new URI(getUrl(webServer)));
+
+            PlainActionFuture<Model> listener = new PlainActionFuture<>();
+            service.checkModelConfig(model, listener);
+
+            var result = listener.actionGet(TIMEOUT);
+            assertThat(
+                result,
+                is(
+                    AzureOpenAiEmbeddingsModelTests.createModel(
+                        "resource",
+                        "deployment",
+                        "apiversion",
+                        2,
+                        false,
+                        100,
+                        SimilarityMeasure.DOT_PRODUCT,
+                        "user",
+                        "apikey",
+                        null,
+                        "id"
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user")));
+        }
+    }
+
+    public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "error": {
+                        "message": "Incorrect API key provided:",
+                        "type": "invalid_request_error",
+                        "param": null,
+                        "code": "invalid_api_key"
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+            service.infer(
+                model,
+                null,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
+            assertThat(error.getMessage(), containsString("Received an authentication error status code for request"));
+            assertThat(error.getMessage(), containsString("Error message: [Incorrect API key provided:]"));
+            assertThat(webServer.requests(), hasSize(1));
+        }
+    }
+
+    public void testChunkedInfer_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                {
+                    "object": "list",
+                    "data": [
+                        {
+                            "object": "embedding",
+                            "index": 0,
+                            "embedding": [
+                                0.0123,
+                                -0.0123
+                            ]
+                        }
+                    ],
+                    "model": "text-embedding-ada-002-v2",
+                    "usage": {
+                        "prompt_tokens": 8,
+                        "total_tokens": 8
+                    }
+                }
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id");
+            model.setUri(new URI(getUrl(webServer)));
+            PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+            service.chunkedInfer(
+                model,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                new ChunkingOptions(null, null),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT).get(0);
+            assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class));
+
+            assertThat(
+                asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result),
+                Matchers.is(
+                    Map.of(
+                        ChunkedTextEmbeddingResults.FIELD_NAME,
+                        List.of(
+                            Map.of(
+                                ChunkedNlpInferenceResults.TEXT,
+                                "abc",
+                                ChunkedNlpInferenceResults.INFERENCE,
+                                List.of((double) 0.0123f, (double) -0.0123f)
+                            )
+                        )
+                    )
+                )
+            );
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
+            assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), Matchers.is(2));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("user"), Matchers.is("user"));
+        }
+    }
+
+    private AzureOpenAiService createAzureOpenAiService() {
+        return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
+    }
+
+    private Map<String, Object> getRequestConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> secretSettings
+    ) {
+        var builtServiceSettings = new HashMap<>();
+        builtServiceSettings.putAll(serviceSettings);
+        builtServiceSettings.putAll(secretSettings);
+
+        return new HashMap<>(
+            Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)
+        );
+    }
+
+    private PeristedConfig getPersistedConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> secretSettings
+    ) {
+
+        return new PeristedConfig(
+            new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)),
+            new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings))
+        );
+    }
+
+    private PeristedConfig getPersistedConfigMap(Map<String, Object> serviceSettings, Map<String, Object> taskSettings) {
+
+        return new PeristedConfig(
+            new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)),
+            null
+        );
+    }
+
+    private record PeristedConfig(Map<String, Object> config, Map<String, Object> secrets) {}
+}

+ 121 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java

@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
+
+public class AzureOpenAiEmbeddingsModelTests extends ESTestCase {
+
+    public void testOverrideWith_OverridesUser() {
+        var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id");
+        var requestTaskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user_override");
+
+        var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap);
+
+        assertThat(overriddenModel, is(createModel("resource", "deployment", "apiversion", "user_override", "api_key", null, "id")));
+    }
+
+    public void testOverrideWith_EmptyMap() {
+        var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id");
+
+        var requestTaskSettingsMap = Map.<String, Object>of();
+
+        var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap);
+        assertThat(overriddenModel, sameInstance(model));
+    }
+
+    public void testOverrideWith_NullMap() {
+        var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id");
+
+        var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, null);
+        assertThat(overriddenModel, sameInstance(model));
+    }
+
+    public void testCreateModel_FromUpdatedServiceSettings() {
+        var model = createModel("resource", "deployment", "apiversion", "user", "api_key", null, "id");
+        var updatedSettings = new AzureOpenAiEmbeddingsServiceSettings(
+            "resource",
+            "deployment",
+            "override_apiversion",
+            null,
+            false,
+            null,
+            null
+        );
+
+        var overridenModel = new AzureOpenAiEmbeddingsModel(model, updatedSettings);
+
+        assertThat(overridenModel, is(createModel("resource", "deployment", "override_apiversion", "user", "api_key", null, "id")));
+    }
+
+    public static AzureOpenAiEmbeddingsModel createModel(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        String user,
+        @Nullable String apiKey,
+        @Nullable String entraId,
+        String inferenceEntityId
+    ) {
+        var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null;
+        var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null;
+        return new AzureOpenAiEmbeddingsModel(
+            inferenceEntityId,
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, null, null),
+            new AzureOpenAiEmbeddingsTaskSettings(user),
+            new AzureOpenAiSecretSettings(secureApiKey, secureEntraId)
+        );
+    }
+
+    public static AzureOpenAiEmbeddingsModel createModel(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        Boolean dimensionsSetByUser,
+        @Nullable Integer maxInputTokens,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable String user,
+        @Nullable String apiKey,
+        @Nullable String entraId,
+        String inferenceEntityId
+    ) {
+        var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null;
+        var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null;
+
+        return new AzureOpenAiEmbeddingsModel(
+            inferenceEntityId,
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new AzureOpenAiEmbeddingsServiceSettings(
+                resourceName,
+                deploymentId,
+                apiVersion,
+                dimensions,
+                dimensionsSetByUser,
+                maxInputTokens,
+                similarity
+            ),
+            new AzureOpenAiEmbeddingsTaskSettings(user),
+            new AzureOpenAiSecretSettings(secureApiKey, secureEntraId)
+        );
+    }
+}

+ 56 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java

@@ -0,0 +1,56 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettings;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsRequestTaskSettingsTests extends ESTestCase {
+    public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
+        var settings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertThat(settings, is(OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS));
+    }
+
+    public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
+        var settings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model")));
+        assertNull(settings.user());
+    }
+
+    public void testFromMap_ReturnsUser() {
+        var settings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
+        assertThat(settings.user(), is("user"));
+    }
+
+    public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "")))
+        );
+
+        assertThat(exception.getMessage(), containsString("[user] must be a non-empty string"));
+    }
+
+    public static Map<String, Object> getRequestTaskSettingsMap(@Nullable String user) {
+        var map = new HashMap<String, Object>();
+
+        if (user != null) {
+            map.put(OpenAiServiceFields.USER, user);
+        }
+
+        return map;
+    }
+}

+ 389 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java

@@ -0,0 +1,389 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsServiceSettings> {
+
+    private static AzureOpenAiEmbeddingsServiceSettings createRandom() {
+        var resourceName = randomAlphaOfLength(8);
+        var deploymentId = randomAlphaOfLength(8);
+        var apiVersion = randomAlphaOfLength(8);
+        Integer dims = randomBoolean() ? 1536 : null;
+        Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
+        return new AzureOpenAiEmbeddingsServiceSettings(
+            resourceName,
+            deploymentId,
+            apiVersion,
+            dims,
+            randomBoolean(),
+            maxInputTokens,
+            null
+        );
+    }
+
+    public void testFromMap_Request_CreatesSettingsCorrectly() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    SIMILARITY,
+                    SimilarityMeasure.COSINE.toString()
+                )
+            ),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new AzureOpenAiEmbeddingsServiceSettings(
+                    resourceName,
+                    deploymentId,
+                    apiVersion,
+                    dims,
+                    true,
+                    maxInputTokens,
+                    SimilarityMeasure.COSINE
+                )
+            )
+        );
+    }
+
+    public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var maxInputTokens = 512;
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens
+                )
+            ),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(
+            serviceSettings,
+            is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, maxInputTokens, null))
+        );
+    }
+
+    public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var maxInputTokens = 512;
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        AzureOpenAiServiceFields.RESOURCE_NAME,
+                        resourceName,
+                        AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                        deploymentId,
+                        AzureOpenAiServiceFields.API_VERSION,
+                        apiVersion,
+                        ServiceFields.MAX_INPUT_TOKENS,
+                        maxInputTokens,
+                        ServiceFields.DIMENSIONS,
+                        1024,
+                        DIMENSIONS_SET_BY_USER,
+                        false
+                    )
+                ),
+                ConfigurationParseContext.REQUEST
+            )
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format("Validation Failed: 1: [service_settings] does not allow the setting [%s];", DIMENSIONS_SET_BY_USER)
+            )
+        );
+    }
+
+    public void testFromMap_Persistent_CreatesSettingsCorrectly() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var encodingFormat = "float";
+        var dims = 1536;
+        var maxInputTokens = 512;
+
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    DIMENSIONS_SET_BY_USER,
+                    false,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    SIMILARITY,
+                    SimilarityMeasure.DOT_PRODUCT.toString()
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new AzureOpenAiEmbeddingsServiceSettings(
+                    resourceName,
+                    deploymentId,
+                    apiVersion,
+                    dims,
+                    false,
+                    maxInputTokens,
+                    SimilarityMeasure.DOT_PRODUCT
+                )
+            )
+        );
+    }
+
+    public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    DIMENSIONS_SET_BY_USER,
+                    true
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(settings, is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, null)));
+    }
+
+    public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    DIMENSIONS_SET_BY_USER,
+                    true,
+                    SIMILARITY,
+                    SimilarityMeasure.COSINE.toString()
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            settings,
+            is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, SimilarityMeasure.COSINE))
+        );
+    }
+
+    public void testFromMap_PersistentContext_ThrowsException_WhenDimensionsSetByUserIsNull() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        AzureOpenAiServiceFields.RESOURCE_NAME,
+                        resourceName,
+                        AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                        deploymentId,
+                        AzureOpenAiServiceFields.API_VERSION,
+                        apiVersion,
+                        ServiceFields.DIMENSIONS,
+                        1
+                    )
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];")
+        );
+    }
+
+    public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", null, true, null, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions_set_by_user":true}"""));
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions":1024,"max_input_tokens":512,"dimensions_set_by_user":false}"""));
+    }
+
+    public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        var filteredXContent = entity.getFilteredXContentObject();
+        filteredXContent.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions":1024,"max_input_tokens":512}"""));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiEmbeddingsServiceSettings> instanceReader() {
+        return AzureOpenAiEmbeddingsServiceSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsServiceSettings mutateInstance(AzureOpenAiEmbeddingsServiceSettings instance) throws IOException {
+        return createRandom();
+    }
+
+    public static Map<String, Object> getPersistentAzureOpenAiServiceSettingsMap(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        var map = new HashMap<String, Object>();
+
+        map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName);
+        map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId);
+        map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion);
+
+        if (dimensions != null) {
+            map.put(ServiceFields.DIMENSIONS, dimensions);
+            map.put(DIMENSIONS_SET_BY_USER, true);
+        } else {
+            map.put(DIMENSIONS_SET_BY_USER, false);
+        }
+
+        if (maxInputTokens != null) {
+            map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens);
+        }
+
+        return map;
+    }
+
+    public static Map<String, Object> getRequestAzureOpenAiServiceSettingsMap(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        var map = new HashMap<String, Object>();
+
+        map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName);
+        map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId);
+        map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion);
+
+        if (dimensions != null) {
+            map.put(ServiceFields.DIMENSIONS, dimensions);
+        }
+
+        if (maxInputTokens != null) {
+            map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens);
+        }
+
+        return map;
+    }
+}

+ 107 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java

@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsTaskSettings> {
+
+    public static AzureOpenAiEmbeddingsTaskSettings createRandomWithUser() {
+        return new AzureOpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15));
+    }
+
+    /**
+     * The created settings can have the user set to null.
+     */
+    public static AzureOpenAiEmbeddingsTaskSettings createRandom() {
+        var user = randomBoolean() ? randomAlphaOfLength(15) : null;
+        return new AzureOpenAiEmbeddingsTaskSettings(user);
+    }
+
+    public void testFromMap_WithUser() {
+        assertEquals(
+            new AzureOpenAiEmbeddingsTaskSettings("user"),
+            AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")))
+        );
+    }
+
+    public void testFromMap_UserIsEmptyString() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
+        );
+    }
+
+    public void testFromMap_MissingUser_DoesNotThrowException() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(taskSettings.user());
+    }
+
+    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+
+        var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(
+            taskSettings,
+            AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS
+        );
+        MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
+    }
+
+    public void testOverrideWith_UsesOverriddenSettings() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+
+        var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(
+            new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user2"))
+        );
+
+        var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings);
+        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureOpenAiEmbeddingsTaskSettings("user2")));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiEmbeddingsTaskSettings> instanceReader() {
+        return AzureOpenAiEmbeddingsTaskSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsTaskSettings createTestInstance() {
+        return createRandomWithUser();
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsTaskSettings mutateInstance(AzureOpenAiEmbeddingsTaskSettings instance) throws IOException {
+        return createRandomWithUser();
+    }
+
+    public static Map<String, Object> getAzureOpenAiRequestTaskSettingsMap(@Nullable String user) {
+        var map = new HashMap<String, Object>();
+
+        if (user != null) {
+            map.put(AzureOpenAiServiceFields.USER, user);
+        }
+
+        return map;
+    }
+}